drf3 changes with test case changes

parent 6fe5dd34
{
"groups": [
{
"type": "stdlib"
},
{
"type": "remainder"
},
{
"type": "packages",
"packages": [
"djangorest_alchemy"
]
},
{
"type": "local"
}
]
}
......@@ -2,9 +2,9 @@
API Builder
Build dynamic API based on the provided SQLAlchemy model
"""
from managers import AlchemyModelManager
from viewsets import AlchemyModelViewSet
from routers import ReadOnlyRouter
from .managers import AlchemyModelManager
from .routers import ReadOnlyRouter
from .viewsets import AlchemyModelViewSet
class APIModelBuilder(object):
......
......@@ -2,7 +2,8 @@
Relationship field
'''
from rest_framework.relations import RelatedField
from djangorest_alchemy.inspector import primary_key, KeyNotFoundException
from djangorest_alchemy.inspector import KeyNotFoundException, primary_key
class AlchemyRelatedField(RelatedField):
......@@ -11,16 +12,18 @@ class AlchemyRelatedField(RelatedField):
self.parent_path = kwargs.pop('path')
super(AlchemyRelatedField, self).__init__(*args, **kwargs)
def to_native(self, obj):
model_name = obj.__class__.__name__.lower()
def to_representation(self, value):
model_name = value.__class__.__name__.lower()
# Try to get pk field
# if not found, it's a child model
# dependent on parent keys
try:
pk_field = primary_key(obj.__class__)
pk_val = getattr(obj, pk_field, None)
return self.parent_path + model_name + 's/' + str(pk_val) + '/'
pk_field = primary_key(value.__class__)
pk_val = getattr(value, pk_field, None)
return ('{parent}{model}s/{pk}/'
''.format(parent=self.parent_path, model=model_name, pk=pk_val))
except KeyNotFoundException:
# Use actual model name
return self.parent_path + model_name + 's/'
......@@ -31,5 +34,5 @@ class AlchemyUriField(RelatedField):
self.parent_path = kwargs.pop('path')
super(AlchemyUriField, self).__init__(*args, **kwargs)
def to_native(self, obj):
return self.parent_path + str(obj) + '/'
def to_representation(self, value):
return '{parent}{pk}/'.format(parent=self.parent_path, pk=value)
......@@ -8,6 +8,10 @@ class KeyNotFoundException(Exception):
"""Primary key not found exception"""
def public_vars(cls):
return {k: v for k, v in vars(cls).items() if not k.startswith('_')}
def class_keys(cls):
"""This is a utility function to get the attribute names for
the primary keys of a class
......@@ -16,7 +20,7 @@ def class_keys(cls):
# >>> ('dealer_code', 'deal_jacket_id', 'deal_id')
"""
reverse_map = {}
for name, attr in cls.__dict__.items():
for name, attr in public_vars(cls).items():
try:
reverse_map[attr.property.columns[0].name] = name
except:
......@@ -31,17 +35,15 @@ def primary_key(cls):
of the class. In case of multiple primary keys,
use the <classname>_id convention
"""
has_multiple_pk = len(class_keys(cls)) > 1
keys = class_keys(cls)
if has_multiple_pk:
if len(keys) > 1:
# guess the pk
pk = cls.__name__.lower() + '_id'
else:
for key in class_keys(cls):
pk = key
break
pk = next(iter(keys), None)
if not pk in cls.__dict__:
if pk not in cls.__dict__:
# could not find pk field in class, now check
# whether it has been explicitly specified
if 'pk_field' in cls.__dict__:
......
......@@ -3,7 +3,7 @@ Base for interfacing with SQLAlchemy
Provides the necessary plumbing for CRUD
using SA session
'''
from inspector import class_keys, primary_key, KeyNotFoundException
from .inspector import KeyNotFoundException, class_keys, primary_key
class AlchemyModelManager(object):
......@@ -33,7 +33,7 @@ class AlchemyModelManager(object):
filter_dict = dict()
if filters:
filter_dict = {k: v for k, v in filters.iteritems()}
filter_dict = {k: v for k, v in filters.items()}
filter_dict.pop('format', None)
filter_dict.pop('page', None)
filter_dict.pop('sort_by', None)
......@@ -85,14 +85,13 @@ class AlchemyModelManager(object):
if not other_pks:
newargs = list(pks)
else:
newargs = list()
for key in class_keys(self.cls):
if other_pks and key in other_pks:
newargs.append(other_pks[key])
newargs = [
other_pks[key]
for key in class_keys(self.cls) if key in other_pks
]
# Confirm this logic works!!!
# will the order be correct if we just append?
for pk in reversed(pks):
newargs.append(pk)
newargs.extend(list(reversed(pks)))
return self.session.query(self.cls).get(newargs)
# -*- coding: utf-8 -*-
from django.core.paginator import Paginator, InvalidPage, Page
from rest_framework.response import Response
from rest_framework import status
import six
from django.core.paginator import InvalidPage, Page, Paginator
from rest_framework import status
from rest_framework.response import Response
STATUS_CODES = {
'created': status.HTTP_201_CREATED,
......@@ -82,7 +83,7 @@ class MultipleObjectMixin(object):
def make_action_method(name, methods, **kwargs):
def func(self, request, pk=None, **kwargs):
assert hasattr(request, 'DATA'), 'request object must have DATA'
assert hasattr(request, 'data'), 'request object must have data'
' attribute'
assert hasattr(self, 'manager_class'), 'viewset must have'
' manager_class defined'
......@@ -92,7 +93,7 @@ def make_action_method(name, methods, **kwargs):
mgr = self.manager_factory(context={'request': request})
mgr_method = getattr(mgr, name)
resp = mgr_method(request.DATA, pk, **kwargs)
resp = mgr_method(request.data, pk, **kwargs)
# no response returned back, assume everything is fine
if not resp:
......@@ -117,7 +118,7 @@ class ManagerMeta(type):
if 'manager_class' in attrs:
mgr_class = attrs['manager_class']
if hasattr(mgr_class, 'action_methods'):
for mname, methods in mgr_class.action_methods.iteritems():
for mname, methods in mgr_class.action_methods.items():
attrs[mname] = make_action_method(mname.lower(), methods)
return super(ManagerMeta, cls).__new__(cls, name, bases, attrs)
......
......@@ -2,8 +2,8 @@ import importlib
import inspect
import itertools
import os
import six
import six
from django.conf import settings
......
from rest_framework.routers import DefaultRouter
from rest_framework.routers import Route
from rest_framework.routers import DefaultRouter, Route
class ReadOnlyRouter(DefaultRouter):
......
......@@ -2,19 +2,39 @@
Base AlchemyModelSerializer which provides the mapping between
SQLALchemy and DRF fields to serialize/deserialize objects
'''
from rest_framework import serializers
from rest_framework.fields import (CharField, IntegerField, DateTimeField,
FloatField, BooleanField, DecimalField)
from sqlalchemy.types import (String, INTEGER, SMALLINT, BIGINT, VARCHAR,
CHAR, TIMESTAMP, DATE, Float, BigInteger,
Numeric, DateTime, Boolean, CLOB, DECIMAL)
from django.utils.datastructures import SortedDict
from djangorest_alchemy.fields import AlchemyRelatedField, AlchemyUriField
# inspect introduced in 0.8
#from sqlalchemy import inspect
from rest_framework import serializers
from rest_framework.fields import (
BooleanField,
CharField,
DateTimeField,
DecimalField,
FloatField,
IntegerField,
)
from sqlalchemy.orm import class_mapper
from inspector import primary_key, KeyNotFoundException
from sqlalchemy.orm.properties import RelationshipProperty, ColumnProperty
from sqlalchemy.orm.properties import ColumnProperty, RelationshipProperty
from sqlalchemy.types import (
BIGINT,
CHAR,
CLOB,
DATE,
DECIMAL,
INTEGER,
SMALLINT,
TIMESTAMP,
VARCHAR,
BigInteger,
Boolean,
DateTime,
Float,
Numeric,
String,
)
from djangorest_alchemy.fields import AlchemyRelatedField, AlchemyUriField
from .inspector import KeyNotFoundException, primary_key
class AlchemyModelSerializer(serializers.Serializer):
......@@ -60,7 +80,8 @@ class AlchemyModelSerializer(serializers.Serializer):
# URI field for get pk field
pk_field = primary_key(self.cls.__class__)
ret['href'] = AlchemyUriField(source=pk_field,
path=r.build_absolute_uri(r.path))
path=r.build_absolute_uri(r.path),
read_only=True)
except KeyNotFoundException:
pass
......@@ -81,10 +102,13 @@ class AlchemyModelSerializer(serializers.Serializer):
field_nm = str(rel_prop).split('.')[1]
# many becomes same as uselist so that
# RelatedField can iterate over the queryset
ret[field_nm] = AlchemyRelatedField(source=field_nm,
many=rel_prop.uselist,
path=r.build_absolute_uri(
r.path))
kwargs = dict(
path=r.build_absolute_uri(r.path),
read_only=True
)
if rel_prop.uselist:
kwargs['many'] = True
ret[field_nm] = AlchemyRelatedField(**kwargs)
return ret
......@@ -98,9 +122,11 @@ class AlchemyListSerializer(AlchemyModelSerializer):
pk_field = primary_key(self.cls.__class__)
request = self.context['request']
ret["href"] = AlchemyUriField(source=pk_field,
path=request.build_absolute_uri
(request.path))
ret["href"] = AlchemyUriField(
source=pk_field,
path=request.build_absolute_uri(request.path),
read_only=True,
)
except KeyNotFoundException:
return super(AlchemyListSerializer, self).get_fields()
......
import unittest
from djangorest_alchemy.apibuilder import APIModelBuilder
import mock
from djangorest_alchemy.apibuilder import APIModelBuilder
class TestAPIBuilder(unittest.TestCase):
......
......@@ -3,8 +3,10 @@
Unit test cases for AlchemyModelManager
'''
import unittest
from djangorest_alchemy.managers import AlchemyModelManager
from utils import SessionMixin, DeclarativeModel
from .utils import DeclarativeModel, SessionMixin
class ModelManager(SessionMixin, AlchemyModelManager):
......
......@@ -2,23 +2,29 @@
Integration test cases for AlchemyModelViewSet
Uses Django test client
'''
from utils import SessionMixin, DeclarativeModel, ClassicalModel
from utils import CompositeKeysModel, ChildModel
from djangorest_alchemy.managers import AlchemyModelManager
from djangorest_alchemy.viewsets import AlchemyModelViewSet
from djangorest_alchemy.mixins import ManagerMixin
from django.test import TestCase
from django.conf.urls import patterns, include, url
import datetime
import mock
import unittest
from rest_framework_nested import routers
from rest_framework import status
from rest_framework import viewsets
import mock
import six
from django.conf.urls import include, patterns, url
from django.test import TestCase
from rest_framework import status, viewsets
from rest_framework.decorators import detail_route
from rest_framework.response import Response
from rest_framework.decorators import list_route
from rest_framework_nested import routers
from djangorest_alchemy.managers import AlchemyModelManager
from djangorest_alchemy.mixins import ManagerMixin
from djangorest_alchemy.viewsets import AlchemyModelViewSet
from .utils import (
ChildModel,
ClassicalModel,
CompositeKeysModel,
DeclarativeModel,
SessionMixin,
)
RESULTS_KEY = "results"
......@@ -27,7 +33,6 @@ PAGE_KEY = "page"
class PrimaryKeyMixin(object):
def get_other_pks(self, request):
pks = {
'pk1': request.META.get('PK1'),
......@@ -48,11 +53,14 @@ class DeclModelViewSet(AlchemyModelViewSet):
manager_class = DeclarativeModelManager
paginate_by = 25
@list_route(methods=['POST'])
def list(self, request, **kwargs):
return super(DeclModelViewSet, self).list(request, **kwargs)
@detail_route(methods=['POST'])
def do_something(self, request, pk=None, **kwargs):
mgr = self.manager_factory()
# Delegate to manager method
mgr.do_something(request.DATA, pk=pk, **kwargs)
mgr.do_something(request.data, pk=pk, **kwargs)
return Response({'status': 'did_something'}, status=status.HTTP_200_OK)
......@@ -79,6 +87,7 @@ class ChildModelManager(SessionMixin, AlchemyModelManager):
class ChildModelViewSet(AlchemyModelViewSet):
manager_class = ChildModelManager
viewset_router = routers.SimpleRouter()
viewset_router.register(r'api/declmodels', DeclModelViewSet,
base_name='test-decl')
......@@ -97,38 +106,37 @@ urlpatterns = patterns('',
url(r'^', include(viewset_router.urls)),
url(r'^', include(child_router.urls)),
)
print viewset_router.urls
class TestAlchemyViewSetIntegration(TestCase):
class TestAlchemyViewSetIntegration(TestCase):
def test_decl_list(self):
resp = self.client.get('/api/declmodels/')
self.assertTrue(resp.status_code is status.HTTP_200_OK)
self.assertTrue(type(resp.data) is dict)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertIsInstance(resp.data, dict)
self.assertTrue(len(resp.data[RESULTS_KEY]) == 1)
self.assertTrue(resp.data[COUNT_KEY] == 1)
self.assertTrue(resp.data[PAGE_KEY] == 25)
def test_decl_retrieve(self):
resp = self.client.get('/api/declmodels/1/')
self.assertTrue(resp.status_code is status.HTTP_200_OK)
self.assertTrue(not type(resp.data) is list)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertIsInstance(resp.data, dict)
self.assertEqual(resp.data['declarativemodel_id'], 1)
self.assertEqual(resp.data['field'], 'test')
self.assertIsInstance(resp.data['datetime'], datetime.datetime)
self.assertIsInstance(resp.data['datetime'], six.string_types)
self.assertIsInstance(resp.data['floatfield'], float)
self.assertTrue(isinstance(resp.data['bigintfield'], (int, long)))
self.assertIsInstance(resp.data['bigintfield'], six.integer_types)
def test_classical_list(self):
resp = self.client.get('/api/clsmodels/?field=test')
self.assertTrue(resp.status_code is status.HTTP_200_OK)
self.assertTrue(type(resp.data) is dict)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertIsInstance(resp.data, dict)
self.assertTrue(len(resp.data[RESULTS_KEY]) == 1)
def test_classical_retrieve(self):
resp = self.client.get('/api/clsmodels/1/')
self.assertTrue(resp.status_code is status.HTTP_200_OK)
self.assertTrue(not type(resp.data) is list)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertIsInstance(resp.data, dict)
self.assertEqual(resp.data['classicalmodel_id'], 1)
self.assertEqual(resp.data['field'], 'test')
......@@ -139,15 +147,15 @@ class TestAlchemyViewSetIntegration(TestCase):
def test_with_multiple_pk_retrieve(self):
resp = self.client.get('/api/compositemodels/1/',
PK1='ABCD', PK2='WXYZ')
self.assertTrue(resp.status_code is status.HTTP_200_OK)
self.assertTrue(not type(resp.data) is list)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertIsInstance(resp.data, dict)
self.assertEqual(resp.data['compositekeysmodel_id'], 1)
self.assertEqual(resp.data['pk1'], 'ABCD')
self.assertEqual(resp.data['pk2'], 'WXYZ')
def test_hierarchical_multiple_pk_retrieve(self):
resp = self.client.get('/api/declmodels/1/childmodels/2/')
self.assertTrue(resp.status_code is status.HTTP_200_OK)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(resp.data['childmodel_id'], 2)
self.assertEqual(resp.data['parent_id'], 1)
......@@ -157,49 +165,48 @@ class TestAlchemyViewSetIntegration(TestCase):
def test_basic_filter(self):
resp = self.client.get('/api/declmodels/?field=test')
print resp.content
self.assertTrue(resp.status_code is status.HTTP_200_OK)
self.assertTrue(type(resp.data) is dict)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertIsInstance(resp.data, dict)
self.assertTrue(len(resp.data[RESULTS_KEY]) == 1)
def test_invalid_filter(self):
resp = self.client.get('/api/declmodels/?field=invalid')
self.assertTrue(resp.status_code is status.HTTP_200_OK)
self.assertTrue(type(resp.data) is dict)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertIsInstance(resp.data, dict)
self.assertTrue(len(resp.data[RESULTS_KEY]) == 0)
def test_basic_pagination(self):
resp = self.client.get('/api/declmodels/?page=1')
self.assertTrue(resp.status_code is status.HTTP_200_OK)
self.assertTrue(type(resp.data) is dict)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertIsInstance(resp.data, dict)
self.assertTrue(len(resp.data[RESULTS_KEY]) == 1)
resp = self.client.get('/api/declmodels/?page=last')
self.assertTrue(resp.status_code is status.HTTP_200_OK)
self.assertTrue(type(resp.data) is dict)
self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertIsInstance(resp.data, dict)
self.assertTrue(len(resp.data[RESULTS_KEY]) == 1)
def test_invalid_pagination(self):
resp = self.client.get('/api/declmodels/?page=foo')
self.assertTrue(resp.status_code is status.HTTP_400_BAD_REQUEST)
self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
#
# Action methods
#
#def test_action_method(self):
# resp = self.client.post('/api/declmodels/1/do_something/')
# self.assertTrue(resp.status_code is status.HTTP_200_OK)
def test_action_method(self):
resp = self.client.post('/api/declmodels/1/do_something/')
self.assertEqual(resp.status_code, status.HTTP_200_OK)
class TestAlchemyViewSetUnit(unittest.TestCase):
def test_manager_factory(self):
'''
Test if manager_factory returns back appropriate instance
This shows how you can override manager_factory
and instantiate your own manager
'''
class MockManager(AlchemyModelManager):
model_class = mock.Mock()
......@@ -236,7 +243,7 @@ class TestAlchemyViewSetUnit(unittest.TestCase):
viewset = MockViewSet()
pks = viewset.get_other_pks(mock.Mock())
self.assertIsNotNone(pks)
self.assertTrue(isinstance(pks, dict))
self.assertIsInstance(pks, dict)
def test_action_methods_manager_mixin(self):
'''
......@@ -295,7 +302,7 @@ class TestAlchemyViewSetUnit(unittest.TestCase):
manager_class = MockManager
mock_request = mock.Mock()
mock_request.DATA = {}
mock_request.data = {}
viewset = MockViewSet()
r = viewset.action_method(mock_request)
......@@ -324,7 +331,7 @@ class TestAlchemyViewSetUnit(unittest.TestCase):
manager_class = MockManager
mock_request = mock.Mock()
mock_request.DATA = {}
mock_request.data = {}
viewset = MockViewSet()
self.assertRaises(ValueError, viewset.method_name, mock_request)
......@@ -2,15 +2,14 @@
Model and manager test dummies
'''
from sqlalchemy import create_engine
from sqlalchemy import MetaData, Table, Column, ForeignKey
from sqlalchemy.types import INTEGER, String, DateTime, Float, BigInteger
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import mapper
from sqlalchemy.orm import relationship
import datetime
from sqlalchemy import Column, ForeignKey, MetaData, Table, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import mapper, relationship, sessionmaker
from sqlalchemy.types import INTEGER, BigInteger, DateTime, Float, String
engine = create_engine('sqlite://', echo=False)
Base = declarative_base()
......
......@@ -3,15 +3,15 @@ Base AlchemyViewSet which provides
the necessary plumbing to interface with
AlchemyModelSerializer and AlchemyModelManager
'''
from rest_framework import viewsets
from django.core.paginator import InvalidPage
from rest_framework import status, viewsets
from rest_framework.response import Response
from rest_framework import status
from djangorest_alchemy.serializers import AlchemyModelSerializer
from djangorest_alchemy.serializers import AlchemyListSerializer
from djangorest_alchemy.mixins import MultipleObjectMixin
from djangorest_alchemy.mixins import ManagerMixin
from django.core.paginator import InvalidPage
from djangorest_alchemy.mixins import ManagerMixin, MultipleObjectMixin
from djangorest_alchemy.serializers import (
AlchemyListSerializer,
AlchemyModelSerializer,
)
class AlchemyModelViewSet(MultipleObjectMixin, ManagerMixin, viewsets.ViewSet):
......@@ -29,14 +29,12 @@ class AlchemyModelViewSet(MultipleObjectMixin, ManagerMixin, viewsets.ViewSet):
if multiple:
s = AlchemyListSerializer(instance=queryset,
model_class=model_class,
context=context)
context=context,
many=True)
else:
s = AlchemyModelSerializer(instance=queryset,
model_class=model_class,
context=context)
#s.is_valid()
#print s.errors
return s