From 840ae23137b295a1c11baa7362bf902b929db9e0 Mon Sep 17 00:00:00 2001
From: KimSia Sim <kimsia@oppoin.com>
Date: Fri, 18 Sep 2020 15:01:17 +0800
Subject: [PATCH] Allow JSONField to be filtered as well

Inclusive of unit tests

closed https://github.com/AltSchool/dynamic-rest/issues/296
---
 dynamic_rest/filters.py | 303 +++++++++++++++-------------------------
 tests/models.py         |  61 +++-----
 tests/setup.py          | 237 ++++++++++++++-----------------
 tests/test_generic.py   | 161 +++++++++------------
 4 files changed, 304 insertions(+), 458 deletions(-)

diff --git a/dynamic_rest/filters.py b/dynamic_rest/filters.py
index 495c6d62..be5a4d52 100644
--- a/dynamic_rest/filters.py
+++ b/dynamic_rest/filters.py
@@ -1,27 +1,22 @@
 """This module contains custom filter backends."""
 
-from django.core.exceptions import ValidationError as InternalValidationError
-from django.core.exceptions import ImproperlyConfigured
-from django.db.models import Q, Prefetch, Manager
 import six
+from django.core.exceptions import ImproperlyConfigured
+from django.core.exceptions import ValidationError as InternalValidationError
+from django.db.models import Manager, Prefetch, Q
 from rest_framework import serializers
 from rest_framework.exceptions import ValidationError
-from rest_framework.fields import BooleanField, NullBooleanField
+from rest_framework.fields import BooleanField, JSONField, NullBooleanField
 from rest_framework.filters import BaseFilterBackend, OrderingFilter
 
-from dynamic_rest.utils import is_truthy
 from dynamic_rest.conf import settings
 from dynamic_rest.datastructures import TreeMap
 from dynamic_rest.fields import DynamicRelationField
-from dynamic_rest.meta import (
-    get_model_field,
-    is_field_remote,
-    is_model_field,
-    get_related_model
-)
+from dynamic_rest.meta import get_model_field, get_related_model, is_field_remote, is_model_field
 from dynamic_rest.patches import patch_prefetch_one_level
-from dynamic_rest.prefetch import FastQuery, FastPrefetch
+from dynamic_rest.prefetch import FastPrefetch, FastQuery
 from dynamic_rest.related import RelatedObject
+from dynamic_rest.utils import is_truthy
 
 patch_prefetch_one_level()
 
@@ -39,7 +34,6 @@ def has_joins(queryset):
 
 
 class FilterNode(object):
-
     def __init__(self, field, operator, value):
         """Create an object representing a filter, to be stored in a TreeMap.
 
@@ -65,10 +59,7 @@ def __init__(self, field, operator, value):
 
     @property
     def key(self):
-        return '%s%s' % (
-            '__'.join(self.field),
-            '__' + self.operator if self.operator else ''
-        )
+        return "%s%s" % ("__".join(self.field), "__" + self.operator if self.operator else "")
 
     def generate_query_key(self, serializer):
         """Get the key that can be passed to Django's filter method.
@@ -90,7 +81,13 @@ def generate_query_key(self, serializer):
         last = len(self.field) - 1
         s = serializer
         field = None
+        jsonfield_recurse = False
         for i, field_name in enumerate(self.field):
+            # Note: this is to handle jsonfield for recursive filtering
+            if jsonfield_recurse:
+                rewritten.append(field_name)
+                if i == last:
+                    break
             # Note: .fields can be empty for related serializers that aren't
             # sideloaded. Fields that are deferred also won't be present.
             # If field name isn't in serializer.fields, get full list from
@@ -98,16 +95,14 @@ def generate_query_key(self, serializer):
             # this if we have to.
             fields = s.fields
             if field_name not in fields:
-                fields = getattr(s, 'get_all_fields', lambda: {})()
+                fields = getattr(s, "get_all_fields", lambda: {})()
 
-            if field_name == 'pk':
-                rewritten.append('pk')
+            if field_name == "pk":
+                rewritten.append("pk")
                 continue
 
             if field_name not in fields:
-                raise ValidationError(
-                    "Invalid filter field: %s" % field_name
-                )
+                raise ValidationError("Invalid filter field: %s" % field_name)
 
             field = fields[field_name]
 
@@ -126,18 +121,20 @@ def generate_query_key(self, serializer):
                 break
 
             # Recurse into nested field
-            s = getattr(field, 'serializer', None)
+            s = getattr(field, "serializer", None)
             if isinstance(s, serializers.ListSerializer):
                 s = s.child
+            # Handle the field when it's a JSONField
+            elif isinstance(field, JSONField):
+                s = serializer
+                jsonfield_recurse = True
             if not s:
-                raise ValidationError(
-                    "Invalid nested filter field: %s" % field_name
-                )
+                raise ValidationError("Invalid nested filter field: %s" % field_name)
 
         if self.operator:
             rewritten.append(self.operator)
 
-        return ('__'.join(rewritten), field)
+        return ("__".join(rewritten), field)
 
 
 class DynamicFilterBackend(BaseFilterBackend):
@@ -152,28 +149,28 @@ class DynamicFilterBackend(BaseFilterBackend):
     """
 
     VALID_FILTER_OPERATORS = (
-        'in',
-        'any',
-        'all',
-        'icontains',
-        'contains',
-        'startswith',
-        'istartswith',
-        'endswith',
-        'iendswith',
-        'year',
-        'month',
-        'day',
-        'week_day',
-        'regex',
-        'range',
-        'gt',
-        'lt',
-        'gte',
-        'lte',
-        'isnull',
-        'eq',
-        'iexact',
+        "in",
+        "any",
+        "all",
+        "icontains",
+        "contains",
+        "startswith",
+        "istartswith",
+        "endswith",
+        "iendswith",
+        "year",
+        "month",
+        "day",
+        "week_day",
+        "regex",
+        "range",
+        "gt",
+        "lt",
+        "gte",
+        "lte",
+        "isnull",
+        "eq",
+        "iexact",
         None,
     )
 
@@ -198,15 +195,14 @@ def filter_queryset(self, request, queryset, view):
         self.DEBUG = settings.DEBUG
 
         return self._build_queryset(
-            queryset=queryset,
-            extra_filters=extra_filters,
-            disable_prefetches=disable_prefetches,
+            queryset=queryset, extra_filters=extra_filters, disable_prefetches=disable_prefetches,
         )
 
     """
     This function was renamed and broke downstream dependencies that haven't
     been updated to use the new naming convention.
     """
+
     def _extract_filters(self, **kwargs):
         return self._get_requested_filters(**kwargs)
 
@@ -221,30 +217,27 @@ def _get_requested_filters(self, **kwargs):
 
         """
 
-        filters_map = (
-            kwargs.get('filters_map') or
-            self.view.get_request_feature(self.view.FILTER)
-        )
+        filters_map = kwargs.get("filters_map") or self.view.get_request_feature(self.view.FILTER)
 
         out = TreeMap()
 
         for spec, value in six.iteritems(filters_map):
 
             # Inclusion or exclusion?
-            if spec[0] == '-':
+            if spec[0] == "-":
                 spec = spec[1:]
-                inex = '_exclude'
+                inex = "_exclude"
             else:
-                inex = '_include'
+                inex = "_include"
 
             # for relational filters, separate out relation path part
-            if '|' in spec:
-                rel, spec = spec.split('|')
-                rel = rel.split('.')
+            if "|" in spec:
+                rel, spec = spec.split("|")
+                rel = rel.split(".")
             else:
                 rel = None
 
-            parts = spec.split('.')
+            parts = spec.split(".")
 
             # Last part could be operator, e.g. "events.capacity.gte"
             if len(parts) > 1 and parts[-1] in self.VALID_FILTER_OPERATORS:
@@ -253,19 +246,16 @@ def _get_requested_filters(self, **kwargs):
                 operator = None
 
             # All operators except 'range' and 'in' should have one value
-            if operator == 'range':
+            if operator == "range":
                 value = value[:2]
-            elif operator == 'in':
+            elif operator == "in":
                 # no-op: i.e. accept `value` as an arbitrarily long list
                 pass
             elif operator in self.VALID_FILTER_OPERATORS:
                 value = value[0]
-                if (
-                    operator == 'isnull' and
-                    isinstance(value, six.string_types)
-                ):
+                if operator == "isnull" and isinstance(value, six.string_types):
                     value = is_truthy(value)
-                elif operator == 'eq':
+                elif operator == "eq":
                     operator = None
 
             node = FilterNode(parts, operator, value)
@@ -325,12 +315,7 @@ def rewrite_filters(filters, serializer):
     def _create_prefetch(self, source, queryset):
         return Prefetch(source, queryset=queryset)
 
-    def _build_implicit_prefetches(
-        self,
-        model,
-        prefetches,
-        requirements
-    ):
+    def _build_implicit_prefetches(self, model, prefetches, requirements):
         """Build a prefetch dictionary based on internal requirements."""
 
         for source, remainder in six.iteritems(requirements):
@@ -341,16 +326,12 @@ def _build_implicit_prefetches(
             related_field = get_model_field(model, source)
             related_model = get_related_model(related_field)
 
-            queryset = self._build_implicit_queryset(
-                related_model,
-                remainder
-            ) if related_model else None
-
-            prefetches[source] = self._create_prefetch(
-                source,
-                queryset
+            queryset = (
+                self._build_implicit_queryset(related_model, remainder) if related_model else None
             )
 
+            prefetches[source] = self._create_prefetch(source, queryset)
+
         return prefetches
 
     def _make_model_queryset(self, model):
@@ -361,25 +342,14 @@ def _build_implicit_queryset(self, model, requirements):
 
         queryset = self._make_model_queryset(model)
         prefetches = {}
-        self._build_implicit_prefetches(
-            model,
-            prefetches,
-            requirements
-        )
+        self._build_implicit_prefetches(model, prefetches, requirements)
         prefetch = prefetches.values()
         queryset = queryset.prefetch_related(*prefetch).distinct()
         if self.DEBUG:
             queryset._using_prefetches = prefetches
         return queryset
 
-    def _build_requested_prefetches(
-        self,
-        prefetches,
-        requirements,
-        model,
-        fields,
-        filters
-    ):
+    def _build_requested_prefetches(self, prefetches, requirements, model, fields, filters):
         """Build a prefetch dictionary based on request requirements."""
 
         for name, field in six.iteritems(fields):
@@ -392,22 +362,19 @@ def _build_requested_prefetches(
                 continue
 
             source = field.source or name
-            if '.' in source:
-                raise ValidationError(
-                    'nested relationship values '
-                    'are not supported'
-                )
+            if "." in source:
+                raise ValidationError("nested relationship values " "are not supported")
 
             if source in prefetches:
                 # ignore duplicated sources
                 continue
 
             is_remote = is_field_remote(model, source)
-            is_id_only = getattr(field, 'id_only', lambda: False)()
+            is_id_only = getattr(field, "id_only", lambda: False)()
             if is_id_only and not is_remote:
                 continue
 
-            related_queryset = getattr(original_field, 'queryset', None)
+            related_queryset = getattr(original_field, "queryset", None)
 
             if callable(related_queryset):
                 related_queryset = related_queryset(field)
@@ -422,7 +389,7 @@ def _build_requested_prefetches(
                 serializer=field,
                 filters=filters.get(name, {}),
                 queryset=related_queryset,
-                requirements=required
+                requirements=required,
             )
 
             # Note: There can only be one prefetch per source, even
@@ -430,34 +397,27 @@ def _build_requested_prefetches(
             #       the same source. This could break in some cases,
             #       but is mostly an issue on writes when we use all
             #       fields by default.
-            prefetches[source] = self._create_prefetch(
-                source,
-                prefetch_queryset
-            )
+            prefetches[source] = self._create_prefetch(source, prefetch_queryset)
 
         return prefetches
 
-    def _get_implicit_requirements(
-        self,
-        fields,
-        requirements
-    ):
+    def _get_implicit_requirements(self, fields, requirements):
         """Extract internal prefetch requirements from serializer fields."""
         for name, field in six.iteritems(fields):
             source = field.source
             # Requires may be manually set on the field -- if not,
             # assume the field requires only its source.
-            requires = getattr(field, 'requires', None) or [source]
+            requires = getattr(field, "requires", None) or [source]
             for require in requires:
                 if not require:
                     # ignore fields with empty source
                     continue
 
-                requirement = require.split('.')
-                if requirement[-1] == '':
+                requirement = require.split(".")
+                if requirement[-1] == "":
                     # Change 'a.b.' -> 'a.b.*',
                     # supporting 'a.b.' for backwards compatibility.
-                    requirement[-1] = '*'
+                    requirement[-1] = "*"
                 requirements.insert(requirement, TreeMap(), update=True)
 
     def _get_queryset(self, queryset=None, serializer=None):
@@ -499,7 +459,7 @@ def _build_queryset(
 
         queryset = self._get_queryset(queryset=queryset, serializer=serializer)
 
-        model = getattr(serializer.Meta, 'model', None)
+        model = getattr(serializer.Meta, "model", None)
 
         if not model:
             return queryset
@@ -513,10 +473,7 @@ def _build_queryset(
         if requirements is None:
             requirements = TreeMap()
 
-        self._get_implicit_requirements(
-            fields,
-            requirements
-        )
+        self._get_implicit_requirements(fields, requirements)
 
         # Implicit requirements (i.e. via `requires`) can potentially
         # include fields that haven't been explicitly included.
@@ -524,55 +481,38 @@ def _build_queryset(
         implicitly_included = set(requirements.keys()) - set(fields.keys())
         if implicitly_included:
             all_fields = serializer.get_all_fields()
-            fields.update({
-                field: all_fields[field]
-                for field in implicitly_included
-                if field in all_fields
-            })
+            fields.update(
+                {field: all_fields[field] for field in implicitly_included if field in all_fields}
+            )
 
         if filters is None:
             filters = self._get_requested_filters()
 
         # build nested Prefetch queryset
-        self._build_requested_prefetches(
-            prefetches,
-            requirements,
-            model,
-            fields,
-            filters
-        )
+        self._build_requested_prefetches(prefetches, requirements, model, fields, filters)
 
         # build remaining prefetches out of internal requirements
         # that are not already covered by request requirements
-        self._build_implicit_prefetches(
-            model,
-            prefetches,
-            requirements
-        )
+        self._build_implicit_prefetches(model, prefetches, requirements)
 
         # use requirements at this level to limit fields selected
         # only do this for GET requests where we are not requesting the
         # entire fieldset
-        if (
-            '*' not in requirements and
-            not self.view.is_update() and
-            not self.view.is_delete()
-        ):
-            id_fields = getattr(serializer, 'get_id_fields', lambda: [])()
+        if "*" not in requirements and not self.view.is_update() and not self.view.is_delete():
+            id_fields = getattr(serializer, "get_id_fields", lambda: [])()
             # only include local model fields
             only = [
-                field for field in set(
-                    id_fields + list(requirements.keys())
-                ) if is_model_field(model, field) and
-                not is_field_remote(model, field)
+                field
+                for field in set(id_fields + list(requirements.keys()))
+                if is_model_field(model, field) and not is_field_remote(model, field)
             ]
             queryset = queryset.only(*only)
 
         # add request filters
         query = self._filters_to_query(
-            includes=filters.get('_include'),
-            excludes=filters.get('_exclude'),
-            serializer=serializer
+            includes=filters.get("_include"),
+            excludes=filters.get("_exclude"),
+            serializer=serializer,
         )
 
         # add additional filters specified by calling view
@@ -586,13 +526,11 @@ def _build_queryset(
             try:
                 queryset = queryset.filter(query)
             except InternalValidationError as e:
-                raise ValidationError(
-                    dict(e) if hasattr(e, 'error_dict') else list(e)
-                )
+                raise ValidationError(dict(e) if hasattr(e, "error_dict") else list(e))
             except Exception as e:
                 # Some other Django error in parsing the filter.
                 # Very likely a bad query, so throw a ValidationError.
-                err_msg = getattr(e, 'message', '')
+                err_msg = getattr(e, "message", "")
                 raise ValidationError(err_msg)
 
         # A serializer can have this optional function
@@ -601,11 +539,8 @@ def _build_queryset(
         # You could use this to have (for example) different
         # serializers for different subsets of a model or to
         # implement permissions which work even in sideloads
-        if hasattr(serializer, 'filter_queryset'):
-            queryset = self._serializer_filter(
-                serializer=serializer,
-                queryset=queryset
-            )
+        if hasattr(serializer, "filter_queryset"):
+            queryset = self._serializer_filter(serializer=serializer, queryset=queryset)
 
         # add prefetches and remove duplicates if necessary
         prefetch = prefetches.values()
@@ -627,8 +562,7 @@ def _create_prefetch(self, source, queryset):
 
     def _get_queryset(self, queryset=None, serializer=None):
         queryset = super(FastDynamicFilterBackend, self)._get_queryset(
-            queryset=queryset,
-            serializer=serializer
+            queryset=queryset, serializer=serializer
         )
 
         if not isinstance(queryset, FastQuery):
@@ -637,15 +571,11 @@ def _get_queryset(self, queryset=None, serializer=None):
         return queryset
 
     def _make_model_queryset(self, model):
-        queryset = super(FastDynamicFilterBackend, self)._make_model_queryset(
-            model
-        )
+        queryset = super(FastDynamicFilterBackend, self)._make_model_queryset(model)
         return FastQuery(queryset)
 
     def _serializer_filter(self, serializer=None, queryset=None):
-        queryset.queryset = serializer.filter_queryset(
-            queryset.queryset
-        )
+        queryset.queryset = serializer.filter_queryset(queryset.queryset)
         return queryset
 
 
@@ -668,7 +598,7 @@ def filter_queryset(self, request, queryset, view):
         ordering = self.get_ordering(request, queryset, view)
         if ordering:
             queryset = queryset.order_by(*ordering)
-            if any(['__' in o for o in ordering]):
+            if any(["__" in o for o in ordering]):
                 # add distinct() to remove duplicates
                 # in case of order-by-related
                 queryset = queryset.distinct()
@@ -683,16 +613,12 @@ def get_ordering(self, request, queryset, view):
         params = view.get_request_feature(view.SORT)
         if params:
             fields = [param.strip() for param in params]
-            valid_ordering, invalid_ordering = self.remove_invalid_fields(
-                queryset, fields, view
-            )
+            valid_ordering, invalid_ordering = self.remove_invalid_fields(queryset, fields, view)
 
             # if any of the sort fields are invalid, throw an error.
             # else return the ordering
             if invalid_ordering:
-                raise ValidationError(
-                    "Invalid filter field: %s" % invalid_ordering
-                )
+                raise ValidationError("Invalid filter field: %s" % invalid_ordering)
             else:
                 return valid_ordering
 
@@ -712,9 +638,9 @@ def remove_invalid_fields(self, queryset, fields, view):
         # for each field sent down from the query param,
         # determine if its valid or invalid
         for term in fields:
-            stripped_term = term.lstrip('-')
+            stripped_term = term.lstrip("-")
             # add back the '-' add the end if necessary
-            reverse_sort_term = '' if len(stripped_term) is len(term) else '-'
+            reverse_sort_term = "" if len(stripped_term) is len(term) else "-"
             ordering = self.ordering_for(stripped_term, view)
 
             if ordering:
@@ -735,15 +661,14 @@ def ordering_for(self, term, view):
             return None
 
         serializer = self._get_serializer_class(view)()
-        serializer_chain = term.split('.')
+        serializer_chain = term.split(".")
 
         model_chain = []
 
         for segment in serializer_chain[:-1]:
             field = serializer.get_all_fields().get(segment)
 
-            if not (field and field.source != '*' and
-                    isinstance(field, DynamicRelationField)):
+            if not (field and field.source != "*" and isinstance(field, DynamicRelationField)):
                 return None
 
             model_chain.append(field.source or segment)
@@ -753,22 +678,22 @@ def ordering_for(self, term, view):
         last_segment = serializer_chain[-1]
         last_field = serializer.get_all_fields().get(last_segment)
 
-        if not last_field or last_field.source == '*':
+        if not last_field or last_field.source == "*":
             return None
 
         model_chain.append(last_field.source or last_segment)
 
-        return '__'.join(model_chain)
+        return "__".join(model_chain)
 
     def _is_allowed_term(self, term, view):
-        valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
-        all_fields_allowed = valid_fields is None or valid_fields == '__all__'
+        valid_fields = getattr(view, "ordering_fields", self.ordering_fields)
+        all_fields_allowed = valid_fields is None or valid_fields == "__all__"
 
         return all_fields_allowed or term in valid_fields
 
     def _get_serializer_class(self, view):
         # prefer the overriding method
-        if hasattr(view, 'get_serializer_class'):
+        if hasattr(view, "get_serializer_class"):
             try:
                 serializer_class = view.get_serializer_class()
             except AssertionError:
@@ -777,7 +702,7 @@ def _get_serializer_class(self, view):
                 serializer_class = None
         # use the attribute
         else:
-            serializer_class = getattr(view, 'serializer_class', None)
+            serializer_class = getattr(view, "serializer_class", None)
 
         # neither a method nor an attribute has been specified
         if serializer_class is None:
diff --git a/tests/models.py b/tests/models.py
index f2b495f5..da5d36cd 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -1,32 +1,22 @@
 from django.contrib.contenttypes.fields import GenericForeignKey
 from django.contrib.contenttypes.models import ContentType
+from django.contrib.postgres.fields import JSONField
 from django.db import models
 
 
 class User(models.Model):
     name = models.TextField()
     last_name = models.TextField()
-    groups = models.ManyToManyField('Group', related_name='users')
-    permissions = models.ManyToManyField('Permission', related_name='users')
+    groups = models.ManyToManyField("Group", related_name="users")
+    permissions = models.ManyToManyField("Permission", related_name="users")
     date_of_birth = models.DateField(null=True, blank=True)
     # 'related_name' intentionally left unset in location field below:
-    location = models.ForeignKey(
-        'Location',
-        null=True,
-        blank=True,
-        on_delete=models.CASCADE
-    )
+    location = models.ForeignKey("Location", null=True, blank=True, on_delete=models.CASCADE)
     favorite_pet_type = models.ForeignKey(
-        ContentType,
-        null=True,
-        blank=True,
-        on_delete=models.CASCADE
+        ContentType, null=True, blank=True, on_delete=models.CASCADE
     )
     favorite_pet_id = models.TextField(null=True, blank=True)
-    favorite_pet = GenericForeignKey(
-        'favorite_pet_type',
-        'favorite_pet_id',
-    )
+    favorite_pet = GenericForeignKey("favorite_pet_type", "favorite_pet_id",)
     is_dead = models.NullBooleanField(default=False)
 
 
@@ -38,23 +28,15 @@ class Profile(models.Model):
 
 class Cat(models.Model):
     name = models.TextField()
-    home = models.ForeignKey('Location', on_delete=models.CASCADE)
+    home = models.ForeignKey("Location", on_delete=models.CASCADE)
     backup_home = models.ForeignKey(
-        'Location',
-        related_name='friendly_cats',
-        on_delete=models.CASCADE
+        "Location", related_name="friendly_cats", on_delete=models.CASCADE
     )
     hunting_grounds = models.ManyToManyField(
-        'Location',
-        related_name='annoying_cats',
-        related_query_name='getoffmylawn'
+        "Location", related_name="annoying_cats", related_query_name="getoffmylawn"
     )
     parent = models.ForeignKey(
-        'Cat',
-        null=True,
-        blank=True,
-        related_name='kittens',
-        on_delete=models.CASCADE
+        "Cat", null=True, blank=True, related_name="kittens", on_delete=models.CASCADE
     )
 
 
@@ -76,7 +58,7 @@ class Zebra(models.Model):
 
 class Group(models.Model):
     name = models.TextField(unique=True)
-    permissions = models.ManyToManyField('Permission', related_name='groups')
+    permissions = models.ManyToManyField("Permission", related_name="groups")
 
 
 class Permission(models.Model):
@@ -94,15 +76,11 @@ class Event(models.Model):
     Event model -- Intentionally missing serializer and viewset, so they
     can be added as part of a codelab.
     """
+
     name = models.TextField()
     status = models.TextField(default="current")
-    location = models.ForeignKey(
-        'Location',
-        null=True,
-        blank=True,
-        on_delete=models.CASCADE
-    )
-    users = models.ManyToManyField('User')
+    location = models.ForeignKey("Location", null=True, blank=True, on_delete=models.CASCADE)
+    users = models.ManyToManyField("User")
 
 
 class A(models.Model):
@@ -110,12 +88,12 @@ class A(models.Model):
 
 
 class B(models.Model):
-    a = models.OneToOneField('A', related_name='b', on_delete=models.CASCADE)
+    a = models.OneToOneField("A", related_name="b", on_delete=models.CASCADE)
 
 
 class C(models.Model):
-    b = models.ForeignKey('B', related_name='cs', on_delete=models.CASCADE)
-    d = models.ForeignKey('D', on_delete=models.CASCADE)
+    b = models.ForeignKey("B", related_name="cs", on_delete=models.CASCADE)
+    d = models.ForeignKey("D", on_delete=models.CASCADE)
 
 
 class D(models.Model):
@@ -136,3 +114,8 @@ class Part(models.Model):
     car = models.ForeignKey(Car, on_delete=models.CASCADE)
     name = models.CharField(max_length=60)
     country = models.ForeignKey(Country, on_delete=models.CASCADE)
+
+
+class JsonFieldModel(models.Model):
+    name = models.CharField(max_length=60)
+    some_jsonfield = JSONField(default=dict)
diff --git a/tests/setup.py b/tests/setup.py
index c7e69abb..2ba9e5a5 100644
--- a/tests/setup.py
+++ b/tests/setup.py
@@ -8,12 +8,13 @@
     Event,
     Group,
     Horse,
+    JsonFieldModel,
     Location,
     Part,
     Permission,
     User,
-    Zebra
-    )
+    Zebra,
+)
 
 
 def create_fixture():
@@ -25,31 +26,46 @@ def create_fixture():
     # Create 4 dogs.
     # Create 2 Country
     # Create 1 Car has 2 Parts each from different Country
+    # Create 1 JsonFieldModel
 
     types = [
-        'users', 'groups', 'locations', 'permissions',
-        'events', 'cats', 'dogs', 'horses', 'zebras',
-        'cars', 'countries', 'parts',
+        "users",
+        "groups",
+        "locations",
+        "permissions",
+        "events",
+        "cats",
+        "dogs",
+        "horses",
+        "zebras",
+        "cars",
+        "countries",
+        "parts",
+        "json_field_models",
     ]
-    Fixture = namedtuple('Fixture', types)
+    Fixture = namedtuple("Fixture", types)
 
     fixture = Fixture(
-        users=[], groups=[], locations=[], permissions=[],
-        events=[], cats=[], dogs=[], horses=[], zebras=[],
-        cars=[], countries=[], parts=[]
+        users=[],
+        groups=[],
+        locations=[],
+        permissions=[],
+        events=[],
+        cats=[],
+        dogs=[],
+        horses=[],
+        zebras=[],
+        cars=[],
+        countries=[],
+        parts=[],
+        json_field_models=[],
     )
 
     for i in range(0, 4):
-        fixture.users.append(
-            User.objects.create(
-                name=str(i),
-                last_name=str(i)))
+        fixture.users.append(User.objects.create(name=str(i), last_name=str(i)))
 
     for i in range(0, 4):
-        fixture.permissions.append(
-            Permission.objects.create(
-                name=str(i),
-                code=i))
+        fixture.permissions.append(Permission.objects.create(name=str(i), code=i))
 
     for i in range(0, 2):
         fixture.groups.append(Group.objects.create(name=str(i)))
@@ -58,97 +74,63 @@ def create_fixture():
         fixture.locations.append(Location.objects.create(name=str(i)))
 
     for i in range(0, 2):
-        fixture.cats.append(Cat.objects.create(
-            name=str(i),
-            home_id=fixture.locations[i].id,
-            backup_home_id=(
-                fixture.locations[len(fixture.locations) - 1 - i].id)))
-
-    dogs = [{
-        'name': 'Clifford',
-        'fur_color': 'red',
-        'origin': 'Clifford the big red dog'
-    }, {
-        'name': 'Air-Bud',
-        'fur_color': 'gold',
-        'origin': 'Air Bud 4: Seventh Inning Fetch'
-    }, {
-        'name': 'Spike',
-        'fur_color': 'brown',
-        'origin': 'Rugrats'
-    }, {
-        'name': 'Pluto',
-        'fur_color': 'brown and white',
-        'origin': 'Mickey Mouse'
-    }, {
-        'name': 'Spike',
-        'fur_color': 'light-brown',
-        'origin': 'Tom and Jerry'
-    }]
-
-    horses = [{
-        'name': 'Seabiscuit',
-        'origin': 'LA'
-    }, {
-        'name': 'Secretariat',
-        'origin': 'Kentucky'
-    }]
-
-    zebras = [{
-        'name': 'Ralph',
-        'origin': 'new york'
-    }, {
-        'name': 'Ted',
-        'origin': 'africa'
-    }]
-
-    events = [{
-        'name': 'Event 1',
-        'status': 'archived',
-        'location': 2
-    }, {
-        'name': 'Event 2',
-        'status': 'current',
-        'location': 1
-    }, {
-        'name': 'Event 3',
-        'status': 'current',
-        'location': 1
-    }, {
-        'name': 'Event 4',
-        'status': 'archived',
-        'location': 2
-    }, {
-        'name': 'Event 5',
-        'status': 'current',
-        'location': 2
-    }]
+        fixture.cats.append(
+            Cat.objects.create(
+                name=str(i),
+                home_id=fixture.locations[i].id,
+                backup_home_id=(fixture.locations[len(fixture.locations) - 1 - i].id),
+            )
+        )
+
+    fixture.json_field_models.append(
+        JsonFieldModel.objects.create(
+            name=str(i), some_jsonfield={"value": "string value for icontains testing"}
+        )
+    )
+
+    dogs = [
+        {"name": "Clifford", "fur_color": "red", "origin": "Clifford the big red dog"},
+        {"name": "Air-Bud", "fur_color": "gold", "origin": "Air Bud 4: Seventh Inning Fetch"},
+        {"name": "Spike", "fur_color": "brown", "origin": "Rugrats"},
+        {"name": "Pluto", "fur_color": "brown and white", "origin": "Mickey Mouse"},
+        {"name": "Spike", "fur_color": "light-brown", "origin": "Tom and Jerry"},
+    ]
+
+    horses = [{"name": "Seabiscuit", "origin": "LA"}, {"name": "Secretariat", "origin": "Kentucky"}]
+
+    zebras = [{"name": "Ralph", "origin": "new york"}, {"name": "Ted", "origin": "africa"}]
+
+    events = [
+        {"name": "Event 1", "status": "archived", "location": 2},
+        {"name": "Event 2", "status": "current", "location": 1},
+        {"name": "Event 3", "status": "current", "location": 1},
+        {"name": "Event 4", "status": "archived", "location": 2},
+        {"name": "Event 5", "status": "current", "location": 2},
+    ]
 
     for dog in dogs:
-        fixture.dogs.append(Dog.objects.create(
-            name=dog.get('name'),
-            fur_color=dog.get('fur_color'),
-            origin=dog.get('origin')
-        ))
+        fixture.dogs.append(
+            Dog.objects.create(
+                name=dog.get("name"), fur_color=dog.get("fur_color"), origin=dog.get("origin")
+            )
+        )
 
     for horse in horses:
-        fixture.horses.append(Horse.objects.create(
-            name=horse.get('name'),
-            origin=horse.get('origin')
-        ))
+        fixture.horses.append(
+            Horse.objects.create(name=horse.get("name"), origin=horse.get("origin"))
+        )
 
     for zebra in zebras:
-        fixture.zebras.append(Zebra.objects.create(
-            name=zebra.get('name'),
-            origin=zebra.get('origin')
-        ))
+        fixture.zebras.append(
+            Zebra.objects.create(name=zebra.get("name"), origin=zebra.get("origin"))
+        )
 
     for event in events:
-        fixture.events.append(Event.objects.create(
-            name=event['name'],
-            status=event['status'],
-            location_id=event['location']
-        ))
+        fixture.events.append(
+            Event.objects.create(
+                name=event["name"], status=event["status"], location_id=event["location"]
+            )
+        )
     fixture.events[1].users.add(fixture.users[0])
     fixture.events[1].users.add(fixture.users[1])
     fixture.events[2].users.add(fixture.users[0])
@@ -158,7 +140,7 @@ def create_fixture():
     fixture.events[4].users.add(fixture.users[1])
     fixture.events[4].users.add(fixture.users[2])
 
-    fixture.locations[0].blob = 'here'
+    fixture.locations[0].blob = "here"
     fixture.locations[0].save()
 
     fixture.users[0].location = fixture.locations[0]
@@ -192,47 +174,30 @@ def create_fixture():
     fixture.groups[0].permissions.add(fixture.permissions[0])
     fixture.groups[1].permissions.add(fixture.permissions[1])
 
-    countries = [{
-        'id': 1,
-        'name': 'United States',
-        'short_name': 'US',
-    }, {
-        'id': 2,
-        'name': 'China',
-        'short_name': 'CN',
-    }]
-
-    cars = [{
-        'id': 1,
-        'name': 'Porshe',
-        'country': 1
-    }]
-
-    parts = [{
-        'car': 1,
-        'name': 'wheel',
-        'country': 1
-    }, {
-        'car': 1,
-        'name': 'tire',
-        'country': 2
-    }]
+    countries = [
+        {"id": 1, "name": "United States", "short_name": "US",},
+        {"id": 2, "name": "China", "short_name": "CN",},
+    ]
+
+    cars = [{"id": 1, "name": "Porshe", "country": 1}]
+
+    parts = [{"car": 1, "name": "wheel", "country": 1}, {"car": 1, "name": "tire", "country": 2}]
 
     for country in countries:
         fixture.countries.append(Country.objects.create(**country))
 
     for car in cars:
-        fixture.cars.append(Car.objects.create(
-            id=car.get('id'),
-            name=car.get('name'),
-            country_id=car.get('country')
-        ))
+        fixture.cars.append(
+            Car.objects.create(
+                id=car.get("id"), name=car.get("name"), country_id=car.get("country")
+            )
+        )
 
     for part in parts:
-        fixture.parts.append(Part.objects.create(
-            car_id=part.get('car'),
-            name=part.get('name'),
-            country_id=part.get('country')
-        ))
+        fixture.parts.append(
+            Part.objects.create(
+                car_id=part.get("car"), name=part.get("name"), country_id=part.get("country")
+            )
+        )
 
     return fixture
diff --git a/tests/test_generic.py b/tests/test_generic.py
index 05e74043..ab0b402f 100644
--- a/tests/test_generic.py
+++ b/tests/test_generic.py
@@ -1,16 +1,14 @@
 import json
 
-from rest_framework.test import APITestCase
-
 from dynamic_rest.fields import DynamicGenericRelationField
 from dynamic_rest.routers import DynamicRouter
-from tests.models import User, Zebra
+from rest_framework.test import APITestCase
+from tests.models import JsonFieldModel, User, Zebra
 from tests.serializers import UserSerializer
 from tests.setup import create_fixture
 
 
 class TestGenericRelationFieldAPI(APITestCase):
-
     def setUp(self):
         self.fixture = create_fixture()
         f = self.fixture
@@ -34,61 +32,47 @@ def test_id_only(self):
             }
         ```
         """
-        url = (
-            '/users/?include[]=favorite_pet'
-            '&filter{favorite_pet_id.isnull}=false'
-        )
+        url = "/users/?include[]=favorite_pet" "&filter{favorite_pet_id.isnull}=false"
         response = self.client.get(url)
         self.assertEqual(200, response.status_code)
-        content = json.loads(response.content.decode('utf-8'))
-        self.assertTrue(
-            all(
-                [_['favorite_pet'] for _ in content['users']]
-            )
-        )
-        self.assertFalse('cats' in content)
-        self.assertFalse('dogs' in content)
-        self.assertTrue('type' in content['users'][0]['favorite_pet'])
-        self.assertTrue('id' in content['users'][0]['favorite_pet'])
+        content = json.loads(response.content.decode("utf-8"))
+        self.assertTrue(all([_["favorite_pet"] for _ in content["users"]]))
+        self.assertFalse("cats" in content)
+        self.assertFalse("dogs" in content)
+        self.assertTrue("type" in content["users"][0]["favorite_pet"])
+        self.assertTrue("id" in content["users"][0]["favorite_pet"])
 
     def test_sideload(self):
-        url = (
-            '/users/?include[]=favorite_pet.'
-            '&filter{favorite_pet_id.isnull}=false'
-        )
+        url = "/users/?include[]=favorite_pet." "&filter{favorite_pet_id.isnull}=false"
         response = self.client.get(url)
         self.assertEqual(200, response.status_code)
-        content = json.loads(response.content.decode('utf-8'))
-        self.assertTrue(
-            all(
-                [_['favorite_pet'] for _ in content['users']]
-            )
-        )
-        self.assertTrue('cats' in content)
-        self.assertEqual(2, len(content['cats']))
-        self.assertTrue('dogs' in content)
-        self.assertEqual(1, len(content['dogs']))
-        self.assertTrue('type' in content['users'][0]['favorite_pet'])
-        self.assertTrue('id' in content['users'][0]['favorite_pet'])
+        content = json.loads(response.content.decode("utf-8"))
+        self.assertTrue(all([_["favorite_pet"] for _ in content["users"]]))
+        self.assertTrue("cats" in content)
+        self.assertEqual(2, len(content["cats"]))
+        self.assertTrue("dogs" in content)
+        self.assertEqual(1, len(content["dogs"]))
+        self.assertTrue("type" in content["users"][0]["favorite_pet"])
+        self.assertTrue("id" in content["users"][0]["favorite_pet"])
 
     def test_multi_sideload_include(self):
         url = (
-            '/cars/1/?include[]=name&include[]=country.short_name'
-            '&include[]=parts.name&include[]=parts.country.name'
+            "/cars/1/?include[]=name&include[]=country.short_name"
+            "&include[]=parts.name&include[]=parts.country.name"
         )
         response = self.client.get(url)
         self.assertEqual(200, response.status_code)
-        content = json.loads(response.content.decode('utf-8'))
-        self.assertTrue('countries' in content)
+        content = json.loads(response.content.decode("utf-8"))
+        self.assertTrue("countries" in content)
 
         country = None
-        for _ in content['countries']:
-            if _['id'] == 1:
+        for _ in content["countries"]:
+            if _["id"] == 1:
                 country = _
 
         self.assertTrue(country)
-        self.assertTrue('short_name' in country)
-        self.assertTrue('name' in country)
+        self.assertTrue("short_name" in country)
+        self.assertTrue("name" in country)
 
     def test_query_counts(self):
         # NOTE: Django doesn't seem to prefetch ContentType objects
@@ -96,15 +80,12 @@ def test_query_counts(self):
         #       this call could do 5 SQL queries if the Cat and Dog
         #       ContentType objects haven't been cached.
         with self.assertNumQueries(3):
-            url = (
-                '/users/?include[]=favorite_pet.'
-                '&filter{favorite_pet_id.isnull}=false'
-            )
+            url = "/users/?include[]=favorite_pet." "&filter{favorite_pet_id.isnull}=false"
             response = self.client.get(url)
             self.assertEqual(200, response.status_code)
 
         with self.assertNumQueries(3):
-            url = '/users/?include[]=favorite_pet.'
+            url = "/users/?include[]=favorite_pet."
             response = self.client.get(url)
             self.assertEqual(200, response.status_code)
 
@@ -113,10 +94,7 @@ def test_unknown_resource(self):
         which there is no known canonical serializer.
         """
 
-        zork = Zebra.objects.create(
-            name='Zork',
-            origin='San Francisco Zoo'
-        )
+        zork = Zebra.objects.create(name="Zork", origin="San Francisco Zoo")
 
         user = self.fixture.users[0]
         user.favorite_pet = zork
@@ -124,26 +102,23 @@ def test_unknown_resource(self):
 
         self.assertIsNone(DynamicRouter.get_canonical_serializer(Zebra))
 
-        url = '/users/%s/?include[]=favorite_pet' % user.pk
+        url = "/users/%s/?include[]=favorite_pet" % user.pk
         response = self.client.get(url)
         self.assertEqual(200, response.status_code)
-        content = json.loads(response.content.decode('utf-8'))
-        self.assertTrue('user' in content)
-        self.assertFalse('zebras' in content)  # Not sideloaded
-        user_obj = content['user']
-        self.assertTrue('favorite_pet' in user_obj)
-        self.assertEqual('Zebra', user_obj['favorite_pet']['type'])
-        self.assertEqual(zork.pk, user_obj['favorite_pet']['id'])
+        content = json.loads(response.content.decode("utf-8"))
+        self.assertTrue("user" in content)
+        self.assertFalse("zebras" in content)  # Not sideloaded
+        user_obj = content["user"]
+        self.assertTrue("favorite_pet" in user_obj)
+        self.assertEqual("Zebra", user_obj["favorite_pet"]["type"])
+        self.assertEqual(zork.pk, user_obj["favorite_pet"]["id"])
 
     def test_dgrf_with_requires_raises(self):
         with self.assertRaises(Exception):
-            DynamicGenericRelationField(requires=['foo', 'bar'])
+            DynamicGenericRelationField(requires=["foo", "bar"])
 
     def test_if_field_inclusion_then_error(self):
-        url = (
-            '/users/?include[]=favorite_pet.name'
-            '&filter{favorite_pet_id.isnull}=false'
-        )
+        url = "/users/?include[]=favorite_pet.name" "&filter{favorite_pet_id.isnull}=false"
         response = self.client.get(url)
         self.assertEqual(400, response.status_code)
 
@@ -154,47 +129,45 @@ def test_patch_resource(self):
         """
         user = self.fixture.users[0]
 
-        url = '/users/%s/?include[]=favorite_pet.' % user.pk
+        url = "/users/%s/?include[]=favorite_pet." % user.pk
         response = self.client.patch(
             url,
-            json.dumps({
-                'id': user.id,
-                'favorite_pet': {
-                    'type': 'dog',
-                    'id': 1
-                }
-            }),
-            content_type='application/json'
+            json.dumps({"id": user.id, "favorite_pet": {"type": "dog", "id": 1}}),
+            content_type="application/json",
         )
         self.assertEqual(200, response.status_code)
-        content = json.loads(response.content.decode('utf-8'))
-        self.assertTrue('user' in content)
-        self.assertFalse('cats' in content)
-        self.assertTrue('dogs' in content)
-        self.assertEqual(1, content['dogs'][0]['id'])
+        content = json.loads(response.content.decode("utf-8"))
+        self.assertTrue("user" in content)
+        self.assertFalse("cats" in content)
+        self.assertTrue("dogs" in content)
+        self.assertEqual(1, content["dogs"][0]["id"])
 
     def test_non_deferred_generic_field(self):
         class FooUserSerializer(UserSerializer):
-
             class Meta:
                 model = User
-                name = 'user'
+                name = "user"
                 fields = (
-                    'id',
-                    'favorite_pet',
+                    "id",
+                    "favorite_pet",
                 )
 
-        user = User.objects.filter(
-            favorite_pet_id__isnull=False
-        ).prefetch_related(
-            'favorite_pet'
-        ).first()
+        user = (
+            User.objects.filter(favorite_pet_id__isnull=False)
+            .prefetch_related("favorite_pet")
+            .first()
+        )
 
-        data = FooUserSerializer(user, envelope=True).data['user']
+        data = FooUserSerializer(user, envelope=True).data["user"]
         self.assertIsNotNone(data)
-        self.assertTrue('favorite_pet' in data)
-        self.assertTrue(isinstance(data['favorite_pet'], dict))
-        self.assertEqual(
-            set(['id', 'type']),
-            set(data['favorite_pet'].keys())
-        )
+        self.assertTrue("favorite_pet" in data)
+        self.assertTrue(isinstance(data["favorite_pet"], dict))
+        self.assertEqual(set(["id", "type"]), set(data["favorite_pet"].keys()))
+
+    def test_jsonfield_filter_recurse(self):
+        """
+        make the filter work with JSONField
+        """
+        url = "/json_field_models/?&filter{some_jsonfield.value.icontains}=icontains"
+        response = self.client.get(url)
+        self.assertEqual(200, response.status_code)