Commit 14d9f3f6 authored by ashish's avatar ashish

Merge pull request #9 from ashish-gore/modelcache

Support for count in response, misc fixes
parents 96e1bec4 75fcac6a
...@@ -30,7 +30,6 @@ class APIModelBuilder(object): ...@@ -30,7 +30,6 @@ class APIModelBuilder(object):
router = routers.SimpleRouter() router = routers.SimpleRouter()
for model in self.models: for model in self.models:
manager = type( manager = type(
str('{}Manager'.format(model.__name__)), str('{}Manager'.format(model.__name__)),
self.base_managers + (AlchemyModelManager,), self.base_managers + (AlchemyModelManager,),
......
...@@ -28,7 +28,7 @@ class AlchemyModelManager(object): ...@@ -28,7 +28,7 @@ class AlchemyModelManager(object):
try: try:
pk = primary_key(self.cls) pk = primary_key(self.cls)
except KeyNotFoundException: except KeyNotFoundException:
return list() pk = None
filter_dict = dict() filter_dict = dict()
...@@ -48,14 +48,28 @@ class AlchemyModelManager(object): ...@@ -48,14 +48,28 @@ class AlchemyModelManager(object):
query_pks[key] = other_pks[key] query_pks[key] = other_pks[key]
query_pks.update(filter_dict) query_pks.update(filter_dict)
queryset = self.session.query(self.cls.__dict__[pk]).filter_by(
**query_pks).all() if pk:
queryset = self.session.query(self.cls.__dict__[pk]).filter_by(
**query_pks).all()
else:
queryset = self.session.query(self.cls).filter_by(
**query_pks).all()
else: else:
if filter_dict: if filter_dict:
queryset = self.session.query(self.cls.__dict__[pk]).filter_by( if pk:
**filter_dict).all() queryset = self.session.query(
self.cls.__dict__[pk]).filter_by(
**filter_dict).all()
else:
queryset = self.session.query(self.cls).filter_by(
**filter_dict).all()
else: else:
queryset = self.session.query(self.cls.__dict__[pk]).all() if pk:
queryset = self.session.query(self.cls.__dict__[pk]).all()
else:
# Limit to 1000 rows, this is worst case scenario
queryset = self.session.query(self.cls).limit(1000).all()
return queryset return queryset
......
...@@ -4,10 +4,10 @@ SQLALchemy and DRF fields to serialize/deserialize objects ...@@ -4,10 +4,10 @@ SQLALchemy and DRF fields to serialize/deserialize objects
''' '''
from rest_framework import serializers from rest_framework import serializers
from rest_framework.fields import (CharField, IntegerField, DateTimeField, from rest_framework.fields import (CharField, IntegerField, DateTimeField,
FloatField, BooleanField) FloatField, BooleanField, DecimalField)
from sqlalchemy.types import (String, INTEGER, SMALLINT, BIGINT, VARCHAR, from sqlalchemy.types import (String, INTEGER, SMALLINT, BIGINT, VARCHAR,
CHAR, TIMESTAMP, DATE, Float, BigInteger, CHAR, TIMESTAMP, DATE, Float, BigInteger,
Numeric, DateTime, Boolean, CLOB) Numeric, DateTime, Boolean, CLOB, DECIMAL)
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
from djangorest_alchemy.fields import AlchemyRelatedField, AlchemyUriField from djangorest_alchemy.fields import AlchemyRelatedField, AlchemyUriField
# inspect introduced in 0.8 # inspect introduced in 0.8
...@@ -36,12 +36,15 @@ class AlchemyModelSerializer(serializers.Serializer): ...@@ -36,12 +36,15 @@ class AlchemyModelSerializer(serializers.Serializer):
Numeric: IntegerField, Numeric: IntegerField,
DateTime: DateTimeField, DateTime: DateTimeField,
Boolean: BooleanField, Boolean: BooleanField,
CLOB: CharField CLOB: CharField,
DECIMAL: DecimalField,
} }
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
assert "model_class" in kwargs, \ assert "model_class" in kwargs, \
"model_class should be passed" "model_class should be passed"
assert 'request' in kwargs['context'], \
"Context must contain request object"
self.cls = kwargs.pop('model_class') self.cls = kwargs.pop('model_class')
super(AlchemyModelSerializer, self).__init__(*args, **kwargs) super(AlchemyModelSerializer, self).__init__(*args, **kwargs)
...@@ -52,15 +55,14 @@ class AlchemyModelSerializer(serializers.Serializer): ...@@ -52,15 +55,14 @@ class AlchemyModelSerializer(serializers.Serializer):
mapper = class_mapper(self.cls.__class__) mapper = class_mapper(self.cls.__class__)
r = self.context['request']
try: try:
# URI field for get pk field # URI field for get pk field
pk_field = primary_key(self.cls.__class__) pk_field = primary_key(self.cls.__class__)
ret['href'] = AlchemyUriField(source=pk_field,
path=r.build_absolute_uri(r.path))
except KeyNotFoundException: except KeyNotFoundException:
return ret pass
r = self.context['request']
ret['href'] = AlchemyUriField(source=pk_field,
path=r.build_absolute_uri(r.path))
# Get all the Column fields # Get all the Column fields
for col_prop in mapper.iterate_properties: for col_prop in mapper.iterate_properties:
...@@ -94,12 +96,12 @@ class AlchemyListSerializer(AlchemyModelSerializer): ...@@ -94,12 +96,12 @@ class AlchemyListSerializer(AlchemyModelSerializer):
try: try:
# URI field for get pk field # URI field for get pk field
pk_field = primary_key(self.cls.__class__) pk_field = primary_key(self.cls.__class__)
except KeyNotFoundException:
return ret
request = self.context['request'] request = self.context['request']
ret["href"] = AlchemyUriField(source=pk_field, ret["href"] = AlchemyUriField(source=pk_field,
path=request.build_absolute_uri( path=request.build_absolute_uri
request.path)) (request.path))
except KeyNotFoundException:
return super(AlchemyListSerializer, self).get_default_fields()
return ret return ret
...@@ -22,6 +22,8 @@ from rest_framework.decorators import action ...@@ -22,6 +22,8 @@ from rest_framework.decorators import action
RESULTS_KEY = "results" RESULTS_KEY = "results"
COUNT_KEY = "count"
PAGE_KEY = "page"
class PrimaryKeyMixin(object): class PrimaryKeyMixin(object):
...@@ -104,6 +106,8 @@ class TestAlchemyViewSetIntegration(TestCase): ...@@ -104,6 +106,8 @@ class TestAlchemyViewSetIntegration(TestCase):
self.assertTrue(resp.status_code is status.HTTP_200_OK) self.assertTrue(resp.status_code is status.HTTP_200_OK)
self.assertTrue(type(resp.data) is dict) self.assertTrue(type(resp.data) is dict)
self.assertTrue(len(resp.data[RESULTS_KEY]) == 1) self.assertTrue(len(resp.data[RESULTS_KEY]) == 1)
self.assertTrue(resp.data[COUNT_KEY] == 1)
self.assertTrue(resp.data[PAGE_KEY] == 25)
def test_decl_retrieve(self): def test_decl_retrieve(self):
resp = self.client.get('/api/declmodels/1/') resp = self.client.get('/api/declmodels/1/')
......
...@@ -84,6 +84,8 @@ class AlchemyModelViewSet(MultipleObjectMixin, ManagerMixin, viewsets.ViewSet): ...@@ -84,6 +84,8 @@ class AlchemyModelViewSet(MultipleObjectMixin, ManagerMixin, viewsets.ViewSet):
queryset = mgr.list(other_pks=self.get_other_pks(request), queryset = mgr.list(other_pks=self.get_other_pks(request),
filters=request.QUERY_PARAMS) filters=request.QUERY_PARAMS)
count = len(queryset)
if self.paginate_by: if self.paginate_by:
try: try:
queryset = self.get_page(queryset) queryset = self.get_page(queryset)
...@@ -94,7 +96,9 @@ class AlchemyModelViewSet(MultipleObjectMixin, ManagerMixin, viewsets.ViewSet): ...@@ -94,7 +96,9 @@ class AlchemyModelViewSet(MultipleObjectMixin, ManagerMixin, viewsets.ViewSet):
mgr.model_class(), mgr.model_class(),
{'request': request}) {'request': request})
return Response({"results": serializer.data}) return Response({"count": count,
"page": self.paginate_by,
"results": serializer.data})
def retrieve(self, request, **kwargs): def retrieve(self, request, **kwargs):
''' '''
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment