Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: Allow for sorting/filtering of JSON objects when using PostgreSQL #247 #342

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 120 additions & 24 deletions dynamic_rest/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from django.core.exceptions import ValidationError as InternalValidationError
from django.core.exceptions import ImproperlyConfigured
from django.db.models import Q, Prefetch, Manager
from django.db.models.expressions import RawSQL, OrderBy
import six
from rest_framework import serializers
from rest_framework.exceptions import ValidationError
from rest_framework.fields import BooleanField, NullBooleanField
from rest_framework.fields import BooleanField, NullBooleanField, JSONField
from rest_framework.filters import BaseFilterBackend, OrderingFilter

from dynamic_rest.utils import is_truthy
from dynamic_rest.conf import settings
from dynamic_rest.datastructures import TreeMap
Expand Down Expand Up @@ -127,6 +127,15 @@ def generate_query_key(self, serializer):

# Recurse into nested field
s = getattr(field, 'serializer', None)
if isinstance(field, JSONField):
# If a json field is found, append any terms following
j = i+1
while j < len(self.field):
rewritten.append(self.field[j])
j += 1
if self.operator:
rewritten.append(self.operator)
return ('__'.join(rewritten), field)
if isinstance(s, serializers.ListSerializer):
s = s.child
if not s:
Expand Down Expand Up @@ -294,33 +303,41 @@ def _filters_to_query(self, includes, excludes, serializer, q=None):
q: Q() object (optional)

Returns:
Q() instance or None if no inclusion or exclusion filters
were specified.
Tuple of:
* Q() instance or None if no inclusion or exclusion filters
were specified.
* dictionary of {(field,): (operator, value)} for any json fields
"""

def rewrite_filters(filters, serializer):
out = {}
json_out = {}
for k, node in six.iteritems(filters):
filter_key, field = node.generate_query_key(serializer)
if isinstance(field, (BooleanField, NullBooleanField)):
node.value = is_truthy(node.value)
out[filter_key] = node.value

return out
if isinstance(field, JSONField):
json_out[tuple(node.field)] = (node.operator, node.value)
else:
out[filter_key] = node.value
return out, json_out

q = q or Q()

json_extras = None

if not includes and not excludes:
return None
return None, None

if includes:
includes = rewrite_filters(includes, serializer)
includes, json_extras = rewrite_filters(includes, serializer)
q &= Q(**includes)
if excludes:
excludes = rewrite_filters(excludes, serializer)
excludes, json_extras = rewrite_filters(excludes, serializer)
for k, v in six.iteritems(excludes):
q &= ~Q(**{k: v})
return q
return q, json_extras

def _create_prefetch(self, source, queryset):
return Prefetch(source, queryset=queryset)
Expand Down Expand Up @@ -569,7 +586,7 @@ def _build_queryset(
queryset = queryset.only(*only)

# add request filters
query = self._filters_to_query(
query, json_extras = self._filters_to_query(
includes=filters.get('_include'),
excludes=filters.get('_exclude'),
serializer=serializer
Expand All @@ -579,12 +596,16 @@ def _build_queryset(
if extra_filters:
query = extra_filters if not query else extra_filters & query

if query:
if query or json_extras:
# Convert internal django ValidationError to
# APIException-based one in order to resolve validation error
# from 500 status code to 400.
try:
queryset = queryset.filter(query)

if json_extras:
extra_queries = self._get_json_queries(json_extras)
queryset = queryset.extra(where=extra_queries)
except InternalValidationError as e:
raise ValidationError(
dict(e) if hasattr(e, 'error_dict') else list(e)
Expand Down Expand Up @@ -620,6 +641,52 @@ def _build_queryset(
queryset._using_prefetches = prefetches
return queryset

def _get_json_queries(self, json_extras):
extra_queries = []

for json_field_names, (operator, value) in six.iteritems(json_extras):
if not operator:
query_operator = '='
value = "'{}'".format(value)
elif operator in ('startswith', 'istartswith'):
query_operator = 'ILIKE' if operator[0] == 'i' else 'LIKE'
value = "'{}%%'".format(value)
elif operator in ('endswith', 'iendswith'):
query_operator = 'ILIKE' if operator[0] == 'i' else 'LIKE'
value = "'%%{}'".format(value)
elif operator in ('contains', 'icontains'):
query_operator = 'ILIKE' if operator[0] == 'i' else 'LIKE'
value = "'%%{}%%'".format(value)

else:
raise InternalValidationError(
f"""Unsupported filter operation for nested JSON fields:
{operator}"""
)

extra_query = []

for idx, k in enumerate(json_field_names):
if idx == 0:
extra_query.append(k)
else:
extra_query.append("'{}'".format(k))

if idx == len(json_field_names) - 1:
continue
# the ->> operator returns a raw value
elif idx == len(json_field_names) - 2:
extra_query.append('->>')
# the -> operator returns JSON
else:
extra_query.append('->')

extra_query.append(query_operator)
extra_query.append(value)
extra_queries.append(' '.join(extra_query))

return extra_queries


class FastDynamicFilterBackend(DynamicFilterBackend):
def _create_prefetch(self, source, queryset):
Expand Down Expand Up @@ -665,7 +732,16 @@ def filter_queryset(self, request, queryset, view):
"""
self.ordering_param = view.SORT

ordering = self.get_ordering(request, queryset, view)
ordering, nested = self.get_ordering(request, queryset, view)
if ordering and nested:
ordering_str = ''.join(ordering)
if ordering_str.startswith('-'):
return queryset.order_by(
OrderBy(RawSQL('LOWER( %s )' % (ordering_str[1:]), nested),
descending=True))
return queryset.order_by(
OrderBy(RawSQL('LOWER(%s)' % (ordering_str), nested),
descending=False))
if ordering:
queryset = queryset.order_by(*ordering)
if any(['__' in o for o in ordering]):
Expand All @@ -681,11 +757,13 @@ def get_ordering(self, request, queryset, view):
This method overwrites the DRF default so it can parse the array.
"""
params = view.get_request_feature(view.SORT)
nested = []
if params:
fields = [param.strip() for param in params]
valid_ordering, invalid_ordering = self.remove_invalid_fields(
queryset, fields, view
)
valid_ordering, invalid_ordering, nested = \
self.remove_invalid_fields(
queryset, fields, view
)

# if any of the sort fields are invalid, throw an error.
# else return the ordering
Expand All @@ -694,10 +772,10 @@ def get_ordering(self, request, queryset, view):
"Invalid filter field: %s" % invalid_ordering
)
else:
return valid_ordering
return valid_ordering, nested

# No sorting was included
return self.get_default_ordering(view)
return self.get_default_ordering(view), nested

def remove_invalid_fields(self, queryset, fields, view):
"""Remove invalid fields from an ordering.
Expand All @@ -715,14 +793,14 @@ def remove_invalid_fields(self, queryset, fields, view):
stripped_term = term.lstrip('-')
# add back the '-' add the end if necessary
reverse_sort_term = '' if len(stripped_term) is len(term) else '-'
ordering = self.ordering_for(stripped_term, view)
ordering, nested = self.ordering_for(stripped_term, view)

if ordering:
valid_orderings.append(reverse_sort_term + ordering)
else:
invalid_orderings.append(term)

return valid_orderings, invalid_orderings
return valid_orderings, invalid_orderings, nested

def ordering_for(self, term, view):
"""
Expand All @@ -732,7 +810,7 @@ def ordering_for(self, term, view):
Raise ImproperlyConfigured if serializer_class not set on view
"""
if not self._is_allowed_term(term, view):
return None
return None, None

serializer = self._get_serializer_class(view)()
serializer_chain = term.split('.')
Expand All @@ -742,9 +820,27 @@ def ordering_for(self, term, view):
for segment in serializer_chain[:-1]:
field = serializer.get_all_fields().get(segment)

# If its a JSONField, construct a RawSQL command in the form
# of 'jsonField->{}'.format('nestedField')' or
# 'jsonField->>{}->{}'.format('nested','doubleNested')
if field and isinstance(field, JSONField):
json_chain_start = str(segment)
json_chain = ''
nested = []
first = True
for nterm in serializer_chain[1:]:
if first:
json_chain += '->>%s'
first = False
else:
json_chain = '->%s' + json_chain
nested.append(nterm)
json_chain = json_chain_start + json_chain
return json_chain, nested

if not (field and field.source != '*' and
isinstance(field, DynamicRelationField)):
return None
return None, None

model_chain.append(field.source or segment)

Expand All @@ -754,11 +850,11 @@ def ordering_for(self, term, view):
last_field = serializer.get_all_fields().get(last_segment)

if not last_field or last_field.source == '*':
return None
return None, None

model_chain.append(last_field.source or last_segment)

return '__'.join(model_chain)
return '__'.join(model_chain), None

def _is_allowed_term(self, term, view):
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
Expand Down
22 changes: 22 additions & 0 deletions tests/migrations/0007_recipe_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals

from django.db import migrations, models
from django.contrib.postgres.fields import JSONField


class Migration(migrations.Migration):

dependencies = [
('tests', '0006_auto_20210921_1026'),
]

operations = [
migrations.CreateModel(
name='recipe',
fields=[
('name', models.CharField(max_length=60)),
('ingredients', JSONField(null=True))
]
),
]
6 changes: 6 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType
from django.contrib.postgres.fields import JSONField
from django.db import models


Expand Down Expand Up @@ -137,3 +138,8 @@ class Part(models.Model):
car = models.ForeignKey(Car, on_delete=models.CASCADE)
name = models.CharField(max_length=60)
country = models.ForeignKey(Country, on_delete=models.CASCADE)


class Recipe(models.Model):
name = models.CharField(max_length=60)
ingredients = JSONField(null=True)
7 changes: 7 additions & 0 deletions tests/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Part,
Permission,
Profile,
Recipe,
User,
Zebra,
)
Expand Down Expand Up @@ -323,3 +324,9 @@ class Meta:
model = Car
fields = ('id', 'name', 'country', 'parts')
deferred_fields = ('name', 'country', 'parts')


class RecipeSerializer(DynamicModelSerializer):
class Meta:
model = Recipe
fields = ('name', 'ingredients')
Loading