diff --git a/example/home/models.py b/example/home/models.py index 9a3029f0..0b9e083e 100644 --- a/example/home/models.py +++ b/example/home/models.py @@ -18,6 +18,7 @@ from grapple.helpers import ( register_paginated_query_field, + register_plural_query_field, register_query_field, register_singular_query_field, ) @@ -162,6 +163,11 @@ class BlogPageRelatedLink(Orderable): @register_snippet +@register_singular_query_field("person", {"name": graphene.String()}) +@register_plural_query_field("people", {"job": graphene.String()}) +@register_plural_query_field( + "people_paginated", {"job": graphene.String()}, paginated=True +) class Person(models.Model): name = models.CharField(max_length=255) job = models.CharField(max_length=255) diff --git a/example/home/test/test_person.py b/example/home/test/test_person.py new file mode 100644 index 00000000..c2143a33 --- /dev/null +++ b/example/home/test/test_person.py @@ -0,0 +1,112 @@ +from home.factories import PersonFactory + +from example.tests.test_grapple import BaseGrappleTest + + +class PersonTest(BaseGrappleTest): + def setUp(self): + super().setUp() + # Create person + self.person1 = PersonFactory(name="Chuck Norris", job="Roundhouse Kicker") + self.person2 = PersonFactory(name="Rory", job="Dog") + + def validate_person(self, person): + # Check all the fields + self.assertTrue(isinstance(person["id"], str)) + self.assertTrue(isinstance(person["name"], str)) + self.assertTrue(isinstance(person["job"], str)) + + def test_people_query(self): + query = """ + { + people { + id + name + job + } + } + """ + executed = self.client.execute(query) + person = executed["data"]["people"][0] + + # Check all the fields + self.validate_person(person) + + def test_people_paginated_query(self): + query = """ + { + peoplePaginated { + items { + id + name + job + } + pagination { + total + count + } + } + } + """ + executed = self.client.execute(query) + person = executed["data"]["peoplePaginated"]["items"][0] + + # Check all the fields + self.validate_person(person) + + def test_people_query_fields(self): + query = """ + query($job: String) { + people(job: $job) { + id + name + job + } + } + """ + executed = self.client.execute(query, variables={"job": self.person1.job}) + person = executed["data"]["people"][0] + + # Check all the fields + self.validate_person(person) + self.assertEqual(person["name"], self.person1.name) + + def test_people_paginated_query_fields(self): + query = """ + query($job: String) { + peoplePaginated(job: $job) { + items { + id + name + job + } + pagination { + total + count + } + } + } + """ + executed = self.client.execute(query, variables={"job": self.person2.job}) + person = executed["data"]["peoplePaginated"]["items"][0] + + # Check all the fields + self.validate_person(person) + self.assertEqual(person["name"], self.person2.name) + + def test_person_single_query(self): + query = """ + query($name: String) { + person(name: $name) { + id + name + job + } + } + """ + executed = self.client.execute(query, variables={"name": self.person1.name}) + person = executed["data"]["person"] + + # Check all the fields + self.validate_person(person) + self.assertEqual(person["name"], self.person1.name) diff --git a/grapple/helpers.py b/grapple/helpers.py index b3ced54f..7511ca87 100644 --- a/grapple/helpers.py +++ b/grapple/helpers.py @@ -14,6 +14,32 @@ field_middlewares = {} +def _add_default_query_params(cls, query_params): + query_params.update( + { + "id": graphene.Int(), + "order": graphene.Argument( + graphene.String, + description=_("Use the Django QuerySet order_by format."), + ), + } + ) + if issubclass(cls, Page): + query_params.update( + { + "slug": graphene.Argument( + graphene.String, description=_("The page slug.") + ), + "url_path": graphene.Argument( + graphene.String, description=_("The url path.") + ), + "token": graphene.Argument( + graphene.String, description=_("The preview token.") + ), + } + ) + + def register_streamfield_block(cls): base_block = None for block_class in inspect.getmro(cls): @@ -280,23 +306,19 @@ def resolve_plural(self, _, info, **kwargs): def register_singular_query_field( - field_name, query_params=None, required=False, middleware=None + field_name, + query_params=None, + required=False, + middleware=None, + keep_default_query_params=False, ): def inner(cls): - field_type = lambda: registry.models[cls] # noqa: E731 - field_query_params = query_params + nonlocal query_params + if query_params is None or keep_default_query_params: + query_params = query_params.copy() if query_params else {} + _add_default_query_params(cls, query_params) - if field_query_params is None: - field_query_params = { - "order": graphene.Argument( - graphene.String, - description=_("Use the Django QuerySet order_by format."), - ), - } - if issubclass(cls, Page): - field_query_params["token"] = graphene.Argument( - graphene.String, description=_("The preview token.") - ) + field_type = lambda: registry.models[cls] # noqa: E731 def Mixin(): # Generic methods to get all and query one model instance. @@ -331,7 +353,7 @@ def resolve_singular(self, _, info, **kwargs): setattr( schema, field_name, - graphene.Field(singular_field_type, **field_query_params), + graphene.Field(singular_field_type, **query_params), ) setattr( @@ -347,3 +369,75 @@ def resolve_singular(self, _, info, **kwargs): register_field_middleware(field_name, middleware) return inner + + +def register_plural_query_field( + plural_field_name, + query_params=None, + required=False, + item_required=False, + middleware=None, + paginated=False, + keep_default_query_params=False, +): + if paginated: + from .types.structures import PaginatedQuerySet + from .utils import resolve_paginated_queryset + else: + from .types.structures import QuerySetList + from .utils import resolve_queryset + + def inner(cls): + nonlocal query_params + if query_params is None or keep_default_query_params: + query_params = query_params.copy() if query_params else {} + _add_default_query_params(cls, query_params) + + field_type = lambda: registry.models[cls] # noqa: E731 + + def Mixin(): + # Generic methods to get all model instances. + def resolve_plural(self, _, info, **kwargs): + qs = cls.objects + if issubclass(cls, Page): + qs = qs.live().public() + if "order" not in kwargs: + kwargs["order"] = "-first_published_at" + elif "order" not in kwargs: + kwargs["order"] = "pk" + + if paginated: + return resolve_paginated_queryset(qs.all(), info, **kwargs) + else: + return resolve_queryset(qs.all(), info, **kwargs) + + # Create schema and add resolve methods + schema = type(cls.__name__ + "Query", (), {}) + + plural_field_type = field_type + if item_required: + plural_field_type = graphene.NonNull(field_type) + + if paginated: + qsl = PaginatedQuerySet( + plural_field_type, cls, required=required, **query_params + ) + else: + qsl = QuerySetList(plural_field_type, required=required, **query_params) + setattr(schema, plural_field_name, qsl) + + setattr( + schema, + "resolve_" + plural_field_name, + MethodType(resolve_plural, schema), + ) + return schema + + # Send schema to Grapple schema. + register_graphql_schema(Mixin()) + return cls + + if middleware is not None: + register_field_middleware(plural_field_name, middleware) + + return inner diff --git a/grapple/types/pages.py b/grapple/types/pages.py index 65ee6638..e2c7856c 100644 --- a/grapple/types/pages.py +++ b/grapple/types/pages.py @@ -268,16 +268,15 @@ class Mixin: ) # Return all pages in site, ideally specific. - def resolve_pages(self, info, **kwargs): + def resolve_pages(self, info, *, in_site=False, content_type=None, **kwargs): pages = ( WagtailPage.objects.live().public().filter(depth__gt=1).specific() ) # no need to the root page - if kwargs.get("in_site", False): + if in_site: site = Site.find_for_request(info.context) pages = pages.in_site(site) - content_type = kwargs.pop("content_type", None) if content_type: app_label, model = content_type.strip().lower().split(".") try: diff --git a/grapple/utils.py b/grapple/utils.py index fc97ea58..0de888b1 100644 --- a/grapple/utils.py +++ b/grapple/utils.py @@ -57,11 +57,11 @@ def resolve_queryset( :param collection: Use Wagtail's collection id to filter images or documents :type collection: int """ - + filters = kwargs.copy() if id is not None: - qs = qs.filter(pk=id) - else: - qs = qs.all() + filters["pk"] = id + + qs = qs.filter(**filters) if filters else qs.all() if id is None and search_query: # Check if the queryset is searchable using Wagtail search. @@ -77,7 +77,7 @@ def resolve_queryset( return _sliced_queryset(qs, limit, offset) if order is not None: - qs = qs.order_by(*map(lambda x: x.strip(), order.split(","))) + qs = qs.order_by(*[x.strip() for x in order.split(",")]) if collection is not None: try: @@ -148,10 +148,11 @@ def resolve_paginated_queryset( int(per_page or grapple_settings.PAGE_SIZE), grapple_settings.MAX_PAGE_SIZE ) + filters = kwargs.copy() if id is not None: - qs = qs.filter(pk=id) - else: - qs = qs.all() + filters["pk"] = id + + qs = qs.filter(**filters) if filters else qs.all() if id is None and search_query: # Check if the queryset is searchable using Wagtail search. @@ -167,7 +168,7 @@ def resolve_paginated_queryset( return get_paginated_result(results, page, per_page) if order is not None: - qs = qs.order_by(*map(lambda x: x.strip(), order.split(","))) + qs = qs.order_by(*[x.strip() for x in order.split(",")]) return get_paginated_result(qs, page, per_page)