Skip to content

Commit

Permalink
allow for sorting/filtering of JSON objects when using PostgreSQL #247
Browse files Browse the repository at this point in the history
  • Loading branch information
sa-mmendivil committed Dec 5, 2022
1 parent cdf804f commit 04a9832
Show file tree
Hide file tree
Showing 8 changed files with 276 additions and 30 deletions.
144 changes: 120 additions & 24 deletions dynamic_rest/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from django.core.exceptions import ValidationError as InternalValidationError
from django.core.exceptions import ImproperlyConfigured
from django.db.models import Q, Prefetch, Manager
from django.db.models.expressions import RawSQL, OrderBy
import six
from rest_framework import serializers
from rest_framework.exceptions import ValidationError
from rest_framework.fields import BooleanField, NullBooleanField
from rest_framework.fields import BooleanField, NullBooleanField, JSONField
from rest_framework.filters import BaseFilterBackend, OrderingFilter

from dynamic_rest.utils import is_truthy
from dynamic_rest.conf import settings
from dynamic_rest.datastructures import TreeMap
Expand Down Expand Up @@ -127,6 +127,15 @@ def generate_query_key(self, serializer):

# Recurse into nested field
s = getattr(field, 'serializer', None)
if isinstance(field, JSONField):
# If a json field is found, append any terms following
j = i+1
while j < len(self.field):
rewritten.append(self.field[j])
j += 1
if self.operator:
rewritten.append(self.operator)
return ('__'.join(rewritten), field)
if isinstance(s, serializers.ListSerializer):
s = s.child
if not s:
Expand Down Expand Up @@ -294,33 +303,41 @@ def _filters_to_query(self, includes, excludes, serializer, q=None):
q: Q() object (optional)
Returns:
Q() instance or None if no inclusion or exclusion filters
were specified.
Tuple of:
* Q() instance or None if no inclusion or exclusion filters
were specified.
* dictionary of {(field,): (operator, value)} for any json fields
"""

def rewrite_filters(filters, serializer):
out = {}
json_out = {}
for k, node in six.iteritems(filters):
filter_key, field = node.generate_query_key(serializer)
if isinstance(field, (BooleanField, NullBooleanField)):
node.value = is_truthy(node.value)
out[filter_key] = node.value

return out
if isinstance(field, JSONField):
json_out[tuple(node.field)] = (node.operator, node.value)
else:
out[filter_key] = node.value
return out, json_out

q = q or Q()

json_extras = None

if not includes and not excludes:
return None
return None, None

if includes:
includes = rewrite_filters(includes, serializer)
includes, json_extras = rewrite_filters(includes, serializer)
q &= Q(**includes)
if excludes:
excludes = rewrite_filters(excludes, serializer)
excludes, json_extras = rewrite_filters(excludes, serializer)
for k, v in six.iteritems(excludes):
q &= ~Q(**{k: v})
return q
return q, json_extras

def _create_prefetch(self, source, queryset):
return Prefetch(source, queryset=queryset)
Expand Down Expand Up @@ -569,7 +586,7 @@ def _build_queryset(
queryset = queryset.only(*only)

# add request filters
query = self._filters_to_query(
query, json_extras = self._filters_to_query(
includes=filters.get('_include'),
excludes=filters.get('_exclude'),
serializer=serializer
Expand All @@ -579,12 +596,16 @@ def _build_queryset(
if extra_filters:
query = extra_filters if not query else extra_filters & query

if query:
if query or json_extras:
# Convert internal django ValidationError to
# APIException-based one in order to resolve validation error
# from 500 status code to 400.
try:
queryset = queryset.filter(query)

if json_extras:
extra_queries = self._get_json_queries(json_extras)
queryset = queryset.extra(where=extra_queries)
except InternalValidationError as e:
raise ValidationError(
dict(e) if hasattr(e, 'error_dict') else list(e)
Expand Down Expand Up @@ -620,6 +641,52 @@ def _build_queryset(
queryset._using_prefetches = prefetches
return queryset

def _get_json_queries(self, json_extras):
extra_queries = []

for json_field_names, (operator, value) in six.iteritems(json_extras):
if not operator:
query_operator = '='
value = "'{}'".format(value)
elif operator in ('startswith', 'istartswith'):
query_operator = 'ILIKE' if operator[0] == 'i' else 'LIKE'
value = "'{}%%'".format(value)
elif operator in ('endswith', 'iendswith'):
query_operator = 'ILIKE' if operator[0] == 'i' else 'LIKE'
value = "'%%{}'".format(value)
elif operator in ('contains', 'icontains'):
query_operator = 'ILIKE' if operator[0] == 'i' else 'LIKE'
value = "'%%{}%%'".format(value)

else:
raise InternalValidationError(
f"""Unsupported filter operation for nested JSON fields:
{operator}"""
)

extra_query = []

for idx, k in enumerate(json_field_names):
if idx == 0:
extra_query.append(k)
else:
extra_query.append("'{}'".format(k))

if idx == len(json_field_names) - 1:
continue
# the ->> operator returns a raw value
elif idx == len(json_field_names) - 2:
extra_query.append('->>')
# the -> operator returns JSON
else:
extra_query.append('->')

extra_query.append(query_operator)
extra_query.append(value)
extra_queries.append(' '.join(extra_query))

return extra_queries


class FastDynamicFilterBackend(DynamicFilterBackend):
def _create_prefetch(self, source, queryset):
Expand Down Expand Up @@ -665,7 +732,16 @@ def filter_queryset(self, request, queryset, view):
"""
self.ordering_param = view.SORT

ordering = self.get_ordering(request, queryset, view)
ordering, nested = self.get_ordering(request, queryset, view)
if ordering and nested:
ordering_str = ''.join(ordering)
if ordering_str.startswith('-'):
return queryset.order_by(
OrderBy(RawSQL('LOWER( %s )' % (ordering_str[1:]), nested),
descending=True))
return queryset.order_by(
OrderBy(RawSQL('LOWER(%s)' % (ordering_str), nested),
descending=False))
if ordering:
queryset = queryset.order_by(*ordering)
if any(['__' in o for o in ordering]):
Expand All @@ -681,11 +757,13 @@ def get_ordering(self, request, queryset, view):
This method overwrites the DRF default so it can parse the array.
"""
params = view.get_request_feature(view.SORT)
nested = []
if params:
fields = [param.strip() for param in params]
valid_ordering, invalid_ordering = self.remove_invalid_fields(
queryset, fields, view
)
valid_ordering, invalid_ordering, nested = \
self.remove_invalid_fields(
queryset, fields, view
)

# if any of the sort fields are invalid, throw an error.
# else return the ordering
Expand All @@ -694,10 +772,10 @@ def get_ordering(self, request, queryset, view):
"Invalid filter field: %s" % invalid_ordering
)
else:
return valid_ordering
return valid_ordering, nested

# No sorting was included
return self.get_default_ordering(view)
return self.get_default_ordering(view), nested

def remove_invalid_fields(self, queryset, fields, view):
"""Remove invalid fields from an ordering.
Expand All @@ -715,14 +793,14 @@ def remove_invalid_fields(self, queryset, fields, view):
stripped_term = term.lstrip('-')
# add back the '-' add the end if necessary
reverse_sort_term = '' if len(stripped_term) is len(term) else '-'
ordering = self.ordering_for(stripped_term, view)
ordering, nested = self.ordering_for(stripped_term, view)

if ordering:
valid_orderings.append(reverse_sort_term + ordering)
else:
invalid_orderings.append(term)

return valid_orderings, invalid_orderings
return valid_orderings, invalid_orderings, nested

def ordering_for(self, term, view):
"""
Expand All @@ -732,7 +810,7 @@ def ordering_for(self, term, view):
Raise ImproperlyConfigured if serializer_class not set on view
"""
if not self._is_allowed_term(term, view):
return None
return None, None

serializer = self._get_serializer_class(view)()
serializer_chain = term.split('.')
Expand All @@ -742,9 +820,27 @@ def ordering_for(self, term, view):
for segment in serializer_chain[:-1]:
field = serializer.get_all_fields().get(segment)

# If its a JSONField, construct a RawSQL command in the form
# of 'jsonField->{}'.format('nestedField')' or
# 'jsonField->>{}->{}'.format('nested','doubleNested')
if field and isinstance(field, JSONField):
json_chain_start = str(segment)
json_chain = ''
nested = []
first = True
for nterm in serializer_chain[1:]:
if first:
json_chain += '->>%s'
first = False
else:
json_chain = '->%s' + json_chain
nested.append(nterm)
json_chain = json_chain_start + json_chain
return json_chain, nested

if not (field and field.source != '*' and
isinstance(field, DynamicRelationField)):
return None
return None, None

model_chain.append(field.source or segment)

Expand All @@ -754,11 +850,11 @@ def ordering_for(self, term, view):
last_field = serializer.get_all_fields().get(last_segment)

if not last_field or last_field.source == '*':
return None
return None, None

model_chain.append(last_field.source or last_segment)

return '__'.join(model_chain)
return '__'.join(model_chain), None

def _is_allowed_term(self, term, view):
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
Expand Down
22 changes: 22 additions & 0 deletions tests/migrations/0007_recipe_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals

from django.db import migrations, models
from django.contrib.postgres.fields import JSONField


class Migration(migrations.Migration):

dependencies = [
('tests', '0006_auto_20210921_1026'),
]

operations = [
migrations.CreateModel(
name='recipe',
fields=[
('name', models.CharField(max_length=60)),
('ingredients', JSONField(null=True))
]
),
]
6 changes: 6 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType
from django.contrib.postgres.fields import JSONField
from django.db import models


Expand Down Expand Up @@ -137,3 +138,8 @@ class Part(models.Model):
car = models.ForeignKey(Car, on_delete=models.CASCADE)
name = models.CharField(max_length=60)
country = models.ForeignKey(Country, on_delete=models.CASCADE)


class Recipe(models.Model):
name = models.CharField(max_length=60)
ingredients = JSONField(null=True)
7 changes: 7 additions & 0 deletions tests/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Part,
Permission,
Profile,
Recipe,
User,
Zebra,
)
Expand Down Expand Up @@ -323,3 +324,9 @@ class Meta:
model = Car
fields = ('id', 'name', 'country', 'parts')
deferred_fields = ('name', 'country', 'parts')


class RecipeSerializer(DynamicModelSerializer):
class Meta:
model = Recipe
fields = ('name', 'ingredients')
Loading

0 comments on commit 04a9832

Please sign in to comment.