Skip to content

Allow providing get_schema_kwargs as a function #179

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions flask_rest_jsonapi/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,31 @@ def __new__(cls):

return super(Resource, cls).__new__(cls)

def _access_kwargs(self, name, args, kwargs):
"""
Gets the kwargs dictionary with the provided name. This can be implemented as
a dictionary *or* a function, so we have to handle both possibilities
"""
# Access the field
val = getattr(self, name, dict())

if callable(val):
# If it's a function, call it and validate its result
schema_kwargs = val(args, kwargs)
if not isinstance(schema_kwargs, dict):
raise TypeError(
'The return value of the "{}" function must be a dictionary of kwargs'
)
else:
# If it's a dictionary, use it directly
schema_kwargs = val
if not isinstance(schema_kwargs, dict):
raise TypeError(
'The value of the "{}" class variable must be a dictionary of kwargs'
)

return schema_kwargs

@jsonapi_exception_formatter
def dispatch_request(self, *args, **kwargs):
"""Logic of how to handle a request"""
Expand Down Expand Up @@ -118,7 +143,9 @@ def get(self, *args, **kwargs):

objects_count, objects = self.get_collection(qs, kwargs)

schema_kwargs = getattr(self, 'get_schema_kwargs', dict())
# get_schema_kwargs can be a class variable or a function
schema_kwargs = self._access_kwargs('get_schema_kwargs', args, kwargs)
schema_kwargs.update()
schema_kwargs.update({'many': True})

self.before_marshmallow(args, kwargs)
Expand Down Expand Up @@ -149,8 +176,9 @@ def post(self, *args, **kwargs):

qs = QSManager(request.args, self.schema)

schema_kwargs = self._access_kwargs('post_schema_kwargs', args, kwargs)
schema = compute_schema(self.schema,
getattr(self, 'post_schema_kwargs', dict()),
schema_kwargs,
qs,
qs.include)

Expand Down Expand Up @@ -230,8 +258,9 @@ def get(self, *args, **kwargs):

self.before_marshmallow(args, kwargs)

schema_kwargs = self._access_kwargs('get_schema_kwargs', args, kwargs)
schema = compute_schema(self.schema,
getattr(self, 'get_schema_kwargs', dict()),
schema_kwargs,
qs,
qs.include)

Expand All @@ -247,7 +276,7 @@ def patch(self, *args, **kwargs):
json_data = request.get_json() or {}

qs = QSManager(request.args, self.schema)
schema_kwargs = getattr(self, 'patch_schema_kwargs', dict())
schema_kwargs = self._access_kwargs('patch_schema_kwargs', args, kwargs)
schema_kwargs.update({'partial': True})

self.before_marshmallow(args, kwargs)
Expand Down
55 changes: 54 additions & 1 deletion tests/test_sqlalchemy_data_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sqlalchemy import create_engine, Column, Integer, DateTime, String, ForeignKey
from sqlalchemy.orm import sessionmaker, relationship
from sqlalchemy.ext.declarative import declarative_base
from flask import Blueprint, make_response, json
from flask import Blueprint, make_response, json, Flask
from marshmallow_jsonapi.flask import Schema, Relationship
from marshmallow import Schema as MarshmallowSchema
from marshmallow_jsonapi import fields
Expand Down Expand Up @@ -457,6 +457,7 @@ def register_routes(client, app, api_blueprint, person_list, person_detail, pers
api.route(string_json_attribute_person_detail, 'string_json_attribute_person_detail',
'/string_json_attribute_persons/<int:person_id>')
api.init_app(app)
return api


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -647,6 +648,58 @@ def test_get_list_disable_pagination(client, register_routes):
response = client.get('/persons' + '?' + querystring, content_type='application/vnd.api+json')
assert response.status_code == 200

def test_get_list_class_kwargs(session, person, person_schema, person_model, computer_list):
"""
Test a resource that defines its get_schema_kwargs as a dictionary class variable
"""
class PersonDetail(ResourceDetail):
schema = person_schema
data_layer = {
'model': person_model,
'session': session,
'url_field': 'person_id'
}

get_schema_kwargs = dict(
exclude=['name']
)

app = Flask('test')
api = Api(app=app)
api.route(PersonDetail, 'api.person_detail', '/persons/<int:person_id>')
api.route(computer_list, 'api.computer_list', '/computers', '/persons/<int:person_id>/computers')
api.init_app(app)

ret = app.test_client().get('/persons/{}'.format(person.person_id))

assert 'name' not in ret.json['data']['attributes']

def test_get_list_func_kwargs(session, person, person_schema, person_model, computer_list):
"""
Test a resource that defines its get_schema_kwargs as a function
"""
class PersonDetail(ResourceDetail):
schema = person_schema
data_layer = {
'model': person_model,
'session': session,
'url_field': 'person_id'
}

def get_schema_kwargs(self, args, kwargs):
return dict(
exclude=['name']
)

app = Flask('test')
api = Api(app=app)
api.route(PersonDetail, 'api.person_detail', '/persons/<int:person_id>')
api.route(computer_list, 'api.computer_list', '/computers', '/persons/<int:person_id>/computers')
api.init_app(app)

ret = app.test_client().get('/persons/{}'.format(person.person_id))

assert 'name' not in ret.json['data']['attributes']

def test_head_list(client, register_routes):
with client:
Expand Down