From 4ab19b4ec63476d04c7477d98fa3e90cc2009dda Mon Sep 17 00:00:00 2001 From: Ellis Percival Date: Sat, 9 Apr 2022 10:28:14 +0100 Subject: [PATCH 1/3] Enable query_params on plural queries #219 --- example/home/test/test_advert.py | 16 ++++++++++++++++ grapple/helpers.py | 4 +++- grapple/types/pages.py | 4 ++-- grapple/utils.py | 4 ++-- 4 files changed, 23 insertions(+), 5 deletions(-) diff --git a/example/home/test/test_advert.py b/example/home/test/test_advert.py index 7da6a010..09a6c252 100644 --- a/example/home/test/test_advert.py +++ b/example/home/test/test_advert.py @@ -31,6 +31,22 @@ def test_advert_all_query(self): # Check all the fields self.validate_advert(advert) + def test_advert_all_query_fields(self): + query = """ + query($url: String) { + adverts(url: $url) { + id + url + text + } + } + """ + executed = self.client.execute(query, variables={"url": self.advert.url}) + advert = executed["data"]["adverts"][0] + + # Check all the fields + self.validate_advert(advert) + def test_advert_single_query(self): query = """ query($url: String) { diff --git a/grapple/helpers.py b/grapple/helpers.py index b3ced54f..d3e4bc0a 100644 --- a/grapple/helpers.py +++ b/grapple/helpers.py @@ -141,7 +141,9 @@ def resolve_plural(self, _, info, **kwargs): setattr( schema, plural_field_name, - QuerySetList(plural_field_type, required=plural_required), + QuerySetList( + plural_field_type, required=plural_required, **field_query_params + ), ) setattr( diff --git a/grapple/types/pages.py b/grapple/types/pages.py index 65ee6638..dfba079a 100644 --- a/grapple/types/pages.py +++ b/grapple/types/pages.py @@ -268,12 +268,12 @@ class Mixin: ) # Return all pages in site, ideally specific. - def resolve_pages(self, info, **kwargs): + def resolve_pages(self, info, in_site=False, **kwargs): pages = ( WagtailPage.objects.live().public().filter(depth__gt=1).specific() ) # no need to the root page - if kwargs.get("in_site", False): + if in_site: site = Site.find_for_request(info.context) pages = pages.in_site(site) diff --git a/grapple/utils.py b/grapple/utils.py index fc97ea58..0e37b96c 100644 --- a/grapple/utils.py +++ b/grapple/utils.py @@ -58,10 +58,10 @@ def resolve_queryset( :type collection: int """ + if kwargs: + qs = qs.filter(**kwargs) if id is not None: qs = qs.filter(pk=id) - else: - qs = qs.all() if id is None and search_query: # Check if the queryset is searchable using Wagtail search. From 290e5b2f11cbbb3771c831c895d7420c9bea438c Mon Sep 17 00:00:00 2001 From: Ellis Percival Date: Sat, 9 Apr 2022 10:42:30 +0100 Subject: [PATCH 2/3] Add content_type to explicit kwargs --- grapple/types/pages.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/grapple/types/pages.py b/grapple/types/pages.py index dfba079a..e2c7856c 100644 --- a/grapple/types/pages.py +++ b/grapple/types/pages.py @@ -268,7 +268,7 @@ class Mixin: ) # Return all pages in site, ideally specific. - def resolve_pages(self, info, in_site=False, **kwargs): + def resolve_pages(self, info, *, in_site=False, content_type=None, **kwargs): pages = ( WagtailPage.objects.live().public().filter(depth__gt=1).specific() ) # no need to the root page @@ -277,7 +277,6 @@ def resolve_pages(self, info, in_site=False, **kwargs): site = Site.find_for_request(info.context) pages = pages.in_site(site) - content_type = kwargs.pop("content_type", None) if content_type: app_label, model = content_type.strip().lower().split(".") try: From 3f4af6283dee35bef2ce243515314ce4698b32fa Mon Sep 17 00:00:00 2001 From: Ellis Percival Date: Sat, 9 Apr 2022 13:20:41 +0100 Subject: [PATCH 3/3] Add register_plural_query_field helper. Make default query params more consistent. --- example/home/models.py | 6 ++ example/home/test/test_advert.py | 16 ---- example/home/test/test_person.py | 112 +++++++++++++++++++++++++++ grapple/helpers.py | 128 ++++++++++++++++++++++++++----- grapple/utils.py | 19 ++--- 5 files changed, 238 insertions(+), 43 deletions(-) create mode 100644 example/home/test/test_person.py diff --git a/example/home/models.py b/example/home/models.py index 9a3029f0..0b9e083e 100644 --- a/example/home/models.py +++ b/example/home/models.py @@ -18,6 +18,7 @@ from grapple.helpers import ( register_paginated_query_field, + register_plural_query_field, register_query_field, register_singular_query_field, ) @@ -162,6 +163,11 @@ class BlogPageRelatedLink(Orderable): @register_snippet +@register_singular_query_field("person", {"name": graphene.String()}) +@register_plural_query_field("people", {"job": graphene.String()}) +@register_plural_query_field( + "people_paginated", {"job": graphene.String()}, paginated=True +) class Person(models.Model): name = models.CharField(max_length=255) job = models.CharField(max_length=255) diff --git a/example/home/test/test_advert.py b/example/home/test/test_advert.py index 09a6c252..7da6a010 100644 --- a/example/home/test/test_advert.py +++ b/example/home/test/test_advert.py @@ -31,22 +31,6 @@ def test_advert_all_query(self): # Check all the fields self.validate_advert(advert) - def test_advert_all_query_fields(self): - query = """ - query($url: String) { - adverts(url: $url) { - id - url - text - } - } - """ - executed = self.client.execute(query, variables={"url": self.advert.url}) - advert = executed["data"]["adverts"][0] - - # Check all the fields - self.validate_advert(advert) - def test_advert_single_query(self): query = """ query($url: String) { diff --git a/example/home/test/test_person.py b/example/home/test/test_person.py new file mode 100644 index 00000000..c2143a33 --- /dev/null +++ b/example/home/test/test_person.py @@ -0,0 +1,112 @@ +from home.factories import PersonFactory + +from example.tests.test_grapple import BaseGrappleTest + + +class PersonTest(BaseGrappleTest): + def setUp(self): + super().setUp() + # Create person + self.person1 = PersonFactory(name="Chuck Norris", job="Roundhouse Kicker") + self.person2 = PersonFactory(name="Rory", job="Dog") + + def validate_person(self, person): + # Check all the fields + self.assertTrue(isinstance(person["id"], str)) + self.assertTrue(isinstance(person["name"], str)) + self.assertTrue(isinstance(person["job"], str)) + + def test_people_query(self): + query = """ + { + people { + id + name + job + } + } + """ + executed = self.client.execute(query) + person = executed["data"]["people"][0] + + # Check all the fields + self.validate_person(person) + + def test_people_paginated_query(self): + query = """ + { + peoplePaginated { + items { + id + name + job + } + pagination { + total + count + } + } + } + """ + executed = self.client.execute(query) + person = executed["data"]["peoplePaginated"]["items"][0] + + # Check all the fields + self.validate_person(person) + + def test_people_query_fields(self): + query = """ + query($job: String) { + people(job: $job) { + id + name + job + } + } + """ + executed = self.client.execute(query, variables={"job": self.person1.job}) + person = executed["data"]["people"][0] + + # Check all the fields + self.validate_person(person) + self.assertEqual(person["name"], self.person1.name) + + def test_people_paginated_query_fields(self): + query = """ + query($job: String) { + peoplePaginated(job: $job) { + items { + id + name + job + } + pagination { + total + count + } + } + } + """ + executed = self.client.execute(query, variables={"job": self.person2.job}) + person = executed["data"]["peoplePaginated"]["items"][0] + + # Check all the fields + self.validate_person(person) + self.assertEqual(person["name"], self.person2.name) + + def test_person_single_query(self): + query = """ + query($name: String) { + person(name: $name) { + id + name + job + } + } + """ + executed = self.client.execute(query, variables={"name": self.person1.name}) + person = executed["data"]["person"] + + # Check all the fields + self.validate_person(person) + self.assertEqual(person["name"], self.person1.name) diff --git a/grapple/helpers.py b/grapple/helpers.py index d3e4bc0a..7511ca87 100644 --- a/grapple/helpers.py +++ b/grapple/helpers.py @@ -14,6 +14,32 @@ field_middlewares = {} +def _add_default_query_params(cls, query_params): + query_params.update( + { + "id": graphene.Int(), + "order": graphene.Argument( + graphene.String, + description=_("Use the Django QuerySet order_by format."), + ), + } + ) + if issubclass(cls, Page): + query_params.update( + { + "slug": graphene.Argument( + graphene.String, description=_("The page slug.") + ), + "url_path": graphene.Argument( + graphene.String, description=_("The url path.") + ), + "token": graphene.Argument( + graphene.String, description=_("The preview token.") + ), + } + ) + + def register_streamfield_block(cls): base_block = None for block_class in inspect.getmro(cls): @@ -141,9 +167,7 @@ def resolve_plural(self, _, info, **kwargs): setattr( schema, plural_field_name, - QuerySetList( - plural_field_type, required=plural_required, **field_query_params - ), + QuerySetList(plural_field_type, required=plural_required), ) setattr( @@ -282,23 +306,19 @@ def resolve_plural(self, _, info, **kwargs): def register_singular_query_field( - field_name, query_params=None, required=False, middleware=None + field_name, + query_params=None, + required=False, + middleware=None, + keep_default_query_params=False, ): def inner(cls): - field_type = lambda: registry.models[cls] # noqa: E731 - field_query_params = query_params + nonlocal query_params + if query_params is None or keep_default_query_params: + query_params = query_params.copy() if query_params else {} + _add_default_query_params(cls, query_params) - if field_query_params is None: - field_query_params = { - "order": graphene.Argument( - graphene.String, - description=_("Use the Django QuerySet order_by format."), - ), - } - if issubclass(cls, Page): - field_query_params["token"] = graphene.Argument( - graphene.String, description=_("The preview token.") - ) + field_type = lambda: registry.models[cls] # noqa: E731 def Mixin(): # Generic methods to get all and query one model instance. @@ -333,7 +353,7 @@ def resolve_singular(self, _, info, **kwargs): setattr( schema, field_name, - graphene.Field(singular_field_type, **field_query_params), + graphene.Field(singular_field_type, **query_params), ) setattr( @@ -349,3 +369,75 @@ def resolve_singular(self, _, info, **kwargs): register_field_middleware(field_name, middleware) return inner + + +def register_plural_query_field( + plural_field_name, + query_params=None, + required=False, + item_required=False, + middleware=None, + paginated=False, + keep_default_query_params=False, +): + if paginated: + from .types.structures import PaginatedQuerySet + from .utils import resolve_paginated_queryset + else: + from .types.structures import QuerySetList + from .utils import resolve_queryset + + def inner(cls): + nonlocal query_params + if query_params is None or keep_default_query_params: + query_params = query_params.copy() if query_params else {} + _add_default_query_params(cls, query_params) + + field_type = lambda: registry.models[cls] # noqa: E731 + + def Mixin(): + # Generic methods to get all model instances. + def resolve_plural(self, _, info, **kwargs): + qs = cls.objects + if issubclass(cls, Page): + qs = qs.live().public() + if "order" not in kwargs: + kwargs["order"] = "-first_published_at" + elif "order" not in kwargs: + kwargs["order"] = "pk" + + if paginated: + return resolve_paginated_queryset(qs.all(), info, **kwargs) + else: + return resolve_queryset(qs.all(), info, **kwargs) + + # Create schema and add resolve methods + schema = type(cls.__name__ + "Query", (), {}) + + plural_field_type = field_type + if item_required: + plural_field_type = graphene.NonNull(field_type) + + if paginated: + qsl = PaginatedQuerySet( + plural_field_type, cls, required=required, **query_params + ) + else: + qsl = QuerySetList(plural_field_type, required=required, **query_params) + setattr(schema, plural_field_name, qsl) + + setattr( + schema, + "resolve_" + plural_field_name, + MethodType(resolve_plural, schema), + ) + return schema + + # Send schema to Grapple schema. + register_graphql_schema(Mixin()) + return cls + + if middleware is not None: + register_field_middleware(plural_field_name, middleware) + + return inner diff --git a/grapple/utils.py b/grapple/utils.py index 0e37b96c..0de888b1 100644 --- a/grapple/utils.py +++ b/grapple/utils.py @@ -57,11 +57,11 @@ def resolve_queryset( :param collection: Use Wagtail's collection id to filter images or documents :type collection: int """ - - if kwargs: - qs = qs.filter(**kwargs) + filters = kwargs.copy() if id is not None: - qs = qs.filter(pk=id) + filters["pk"] = id + + qs = qs.filter(**filters) if filters else qs.all() if id is None and search_query: # Check if the queryset is searchable using Wagtail search. @@ -77,7 +77,7 @@ def resolve_queryset( return _sliced_queryset(qs, limit, offset) if order is not None: - qs = qs.order_by(*map(lambda x: x.strip(), order.split(","))) + qs = qs.order_by(*[x.strip() for x in order.split(",")]) if collection is not None: try: @@ -148,10 +148,11 @@ def resolve_paginated_queryset( int(per_page or grapple_settings.PAGE_SIZE), grapple_settings.MAX_PAGE_SIZE ) + filters = kwargs.copy() if id is not None: - qs = qs.filter(pk=id) - else: - qs = qs.all() + filters["pk"] = id + + qs = qs.filter(**filters) if filters else qs.all() if id is None and search_query: # Check if the queryset is searchable using Wagtail search. @@ -167,7 +168,7 @@ def resolve_paginated_queryset( return get_paginated_result(results, page, per_page) if order is not None: - qs = qs.order_by(*map(lambda x: x.strip(), order.split(","))) + qs = qs.order_by(*[x.strip() for x in order.split(",")]) return get_paginated_result(qs, page, per_page)