diff --git a/grapple/types/structures.py b/grapple/types/structures.py index 9c107ddb..1bf8f04b 100644 --- a/grapple/types/structures.py +++ b/grapple/types/structures.py @@ -21,6 +21,21 @@ def parse_literal(ast, _variables=None): return return_value +class SearchOperatorEnum(graphene.Enum): + """ + Enum for search operator. + """ + + AND = "and" + OR = "or" + + def __str__(self): + # the core search parser expects the operator to be a string. + # the default __str__ returns SearchOperatorEnum.AND/OR, + # this __str__ returns the value and/or for compatibility. + return self.value + + class QuerySetList(graphene.List): """ List type with arguments used by Django's query sets. @@ -32,6 +47,8 @@ class QuerySetList(graphene.List): * ``limit`` * ``offset`` * ``search_query`` + * ``search_operator`` + * ``search_fields`` * ``order`` :param enable_in_menu: Enable in_menu filter. @@ -42,6 +59,10 @@ class QuerySetList(graphene.List): :type enable_offset: bool :param enable_search: Enable search query argument. :type enable_search: bool + :param enable_search_fields: Enable search fields argument, enable_search must also be True + :type enable_search_fields: bool + :param enable_search_operator: Enable search operator argument, enable_search must also be True + :type enable_search_operator: bool :param enable_order: Enable ordering via query argument. :type enable_order: bool """ @@ -50,8 +71,10 @@ def __init__(self, of_type, *args, **kwargs): enable_in_menu = kwargs.pop("enable_in_menu", False) enable_limit = kwargs.pop("enable_limit", True) enable_offset = kwargs.pop("enable_offset", True) - enable_search = kwargs.pop("enable_search", True) enable_order = kwargs.pop("enable_order", True) + enable_search = kwargs.pop("enable_search", True) + enable_search_fields = kwargs.pop("enable_search_fields", True) + enable_search_operator = kwargs.pop("enable_search_operator", True) # Check if the type is a Django model type. Do not perform the # check if value is lazy. @@ -106,6 +129,22 @@ def __init__(self, of_type, *args, **kwargs): graphene.String, description=_("Filter the results using Wagtail's search."), ) + if enable_search_operator: + kwargs["search_operator"] = graphene.Argument( + SearchOperatorEnum, + description=_( + "Specify search operator (and/or), see: https://docs.wagtail.org/en/stable/topics/search/searching.html#search-operator" + ), + default_value="and", + ) + + if enable_search_fields: + kwargs["search_fields"] = graphene.Argument( + graphene.List(graphene.String), + description=_( + "A list of fields to search in. see: https://docs.wagtail.org/en/stable/topics/search/searching.html#specifying-the-fields-to-search" + ), + ) if "id" not in kwargs: kwargs["id"] = graphene.Argument(graphene.ID, description=_("Filter by ID")) @@ -152,23 +191,31 @@ def PaginatedQuerySet(of_type, type_class, **kwargs): """ Paginated QuerySet type with arguments used by Django's query sets. - This type setts the following arguments on itself: + This type sets the following arguments on itself: * ``id`` * ``in_menu`` * ``page`` * ``per_page`` * ``search_query`` + * ``search_operator`` + * ``search_fields`` * ``order`` :param enable_search: Enable search query argument. :type enable_search: bool + :param enable_search_fields: Enable search fields argument, enable_search must also be True + :type enable_search_fields: bool + :param enable_search_operator: Enable search operator argument, enable_search must also be True + :type enable_search_operator: bool :param enable_order: Enable ordering via query argument. :type enable_order: bool """ enable_in_menu = kwargs.pop("enable_in_menu", False) enable_search = kwargs.pop("enable_search", True) + enable_search_fields = kwargs.pop("enable_search_fields", True) + enable_search_operator = kwargs.pop("enable_search_operator", True) enable_order = kwargs.pop("enable_order", True) required = kwargs.get("required", False) type_name = type_class if isinstance(type_class, str) else type_class.__name__ @@ -225,6 +272,22 @@ def PaginatedQuerySet(of_type, type_class, **kwargs): kwargs["search_query"] = graphene.Argument( graphene.String, description=_("Filter the results using Wagtail's search.") ) + if enable_search_operator: + kwargs["search_operator"] = graphene.Argument( + SearchOperatorEnum, + description=_( + "Specify search operator (and/or), see: https://docs.wagtail.org/en/stable/topics/search/searching.html#search-operator" + ), + default_value="and", + ) + + if enable_search_fields: + kwargs["search_fields"] = graphene.Argument( + graphene.List(graphene.String), + description=_( + "A comma-separated list of fields to search in. see: https://docs.wagtail.org/en/stable/topics/search/searching.html#specifying-the-fields-to-search" + ), + ) if "id" not in kwargs: kwargs["id"] = graphene.Argument(graphene.ID, description=_("Filter by ID")) diff --git a/grapple/utils.py b/grapple/utils.py index b901c11f..a40a92f8 100644 --- a/grapple/utils.py +++ b/grapple/utils.py @@ -8,6 +8,7 @@ from wagtail import VERSION as WAGTAIL_VERSION from wagtail.models import Site from wagtail.search.index import class_is_indexed +from wagtail.search.utils import parse_query_string from .settings import grapple_settings from .types.structures import BasePaginatedType, PaginationType @@ -101,6 +102,8 @@ def resolve_queryset( order=None, collection=None, in_menu=None, + search_operator="and", + search_fields=None, **kwargs, ): """ @@ -122,6 +125,11 @@ def resolve_queryset( :type order: str :param collection: Use Wagtail's collection id to filter images or documents :type collection: int + :param search_operator: The operator to use when combining search terms. + Defaults to "and". + :type search_operator: "and" | "or" + :param search_fields: A list of fields to search. Defaults to all fields. + :type search_fields: list """ qs = qs.all() if id is None else qs.filter(pk=id) @@ -152,7 +160,18 @@ def resolve_queryset( query = Query.get(search_query) query.add_hit() - qs = qs.search(search_query, order_by_relevance=order_by_relevance) + filters, parsed_query = parse_query_string(search_query, str(search_operator)) + + # check if search_fields is provided in the query string if it isn't provided as a graphQL argument + if search_fields is None: + search_fields = filters.getlist("fields", None) + + qs = qs.search( + parsed_query, + order_by_relevance=order_by_relevance, + operator=search_operator, + fields=search_fields, + ) if connection.vendor != "sqlite": qs = qs.annotate_score("search_score") @@ -183,9 +202,9 @@ def get_paginated_result(qs, page, per_page): count=len(page_obj.object_list), per_page=per_page, current_page=page_obj.number, - prev_page=page_obj.previous_page_number() - if page_obj.has_previous() - else None, + prev_page=( + page_obj.previous_page_number() if page_obj.has_previous() else None + ), next_page=page_obj.next_page_number() if page_obj.has_next() else None, total_pages=paginator.num_pages, ), @@ -193,7 +212,16 @@ def get_paginated_result(qs, page, per_page): def resolve_paginated_queryset( - qs, info, page=None, per_page=None, search_query=None, id=None, order=None, **kwargs + qs, + info, + page=None, + per_page=None, + id=None, + order=None, + search_query=None, + search_operator="and", + search_fields=None, + **kwargs, ): """ Add page, per_page and search capabilities to the query. This contains @@ -207,11 +235,16 @@ def resolve_paginated_queryset( :type id: int :param per_page: The maximum number of items to include on a page. :type per_page: int + :param order: Order the query set using the Django QuerySet order_by format. + :type order: str :param search_query: Using Wagtail search, exclude objects that do not match the search query. :type search_query: str - :param order: Order the query set using the Django QuerySet order_by format. - :type order: str + :param search_operator: The operator to use when combining search terms. + Defaults to "and". + :type search_operator: "and" | "or" + :param search_fields: A list of fields to search. Defaults to all fields. + :type search_fields: list """ page = int(page or 1) per_page = min( @@ -236,7 +269,18 @@ def resolve_paginated_queryset( query = Query.get(search_query) query.add_hit() - qs = qs.search(search_query, order_by_relevance=order_by_relevance) + filters, parsed_query = parse_query_string(search_query, search_operator) + + # check if search_fields is provided in the query string if it isn't provided as a graphQL argument + if search_fields is None: + search_fields = filters.getlist("fields", None) + + qs = qs.search( + parsed_query, + order_by_relevance=order_by_relevance, + operator=search_operator, + fields=search_fields, + ) if connection.vendor != "sqlite": qs = qs.annotate_score("search_score") diff --git a/tests/test_grapple.py b/tests/test_grapple.py index fe23abea..961cc78f 100644 --- a/tests/test_grapple.py +++ b/tests/test_grapple.py @@ -474,18 +474,41 @@ class PagesSearchTest(BaseGrappleTest): @classmethod def setUpTestData(cls): cls.home = HomePage.objects.first() - BlogPageFactory(title="Alpha", parent=cls.home, show_in_menus=True) - BlogPageFactory(title="Alpha Alpha", parent=cls.home) - BlogPageFactory(title="Alpha Beta", parent=cls.home) - BlogPageFactory(title="Alpha Gamma", parent=cls.home) - BlogPageFactory(title="Beta", parent=cls.home) - BlogPageFactory(title="Beta Alpha", parent=cls.home) - BlogPageFactory(title="Beta Beta", parent=cls.home) - BlogPageFactory(title="Beta Gamma", parent=cls.home) - BlogPageFactory(title="Gamma", parent=cls.home) - BlogPageFactory(title="Gamma Alpha", parent=cls.home) - BlogPageFactory(title="Gamma Beta", parent=cls.home) - BlogPageFactory(title="Gamma Gamma", parent=cls.home) + BlogPageFactory( + title="Alpha", + body=[("heading", "Sigma")], + parent=cls.home, + show_in_menus=True, + ) + BlogPageFactory( + title="Alpha Alpha", body=[("heading", "Sigma Sigma")], parent=cls.home + ) + BlogPageFactory( + title="Alpha Beta", body=[("heading", "Sigma Theta")], parent=cls.home + ) + BlogPageFactory( + title="Alpha Gamma", body=[("heading", "Sigma Delta")], parent=cls.home + ) + BlogPageFactory(title="Beta", body=[("heading", "Theta")], parent=cls.home) + BlogPageFactory( + title="Beta Alpha", body=[("heading", "Theta Sigma")], parent=cls.home + ) + BlogPageFactory( + title="Beta Beta", body=[("heading", "Theta Theta")], parent=cls.home + ) + BlogPageFactory( + title="Beta Gamma", body=[("heading", "Theta Delta")], parent=cls.home + ) + BlogPageFactory(title="Gamma", body=[("heading", "Delta")], parent=cls.home) + BlogPageFactory( + title="Gamma Alpha", body=[("heading", "Delta Sigma")], parent=cls.home + ) + BlogPageFactory( + title="Gamma Beta", body=[("heading", "Delta Theta")], parent=cls.home + ) + BlogPageFactory( + title="Gamma Gamma", body=[("heading", "Delta Delta")], parent=cls.home + ) @unittest.skipIf( connection.vendor != "sqlite", @@ -530,7 +553,6 @@ def test_searchQuery_order_by_relevance(self): } } """ - executed = self.client.execute(query, variables={"searchQuery": "Alpha"}) page_data = executed["data"].get("pages") self.assertEqual(len(page_data), 6) @@ -559,7 +581,6 @@ def test_explicit_order(self): query, variables={"searchQuery": "Gamma", "order": "-title"} ) page_data = executed["data"].get("pages") - self.assertEqual(len(page_data), 6) self.assertEqual(page_data[0]["title"], "Gamma Gamma") self.assertEqual(page_data[1]["title"], "Gamma Beta") @@ -593,6 +614,110 @@ def test_search_not_in_menus(self): page_data = executed["data"].get("pages") self.assertEqual(len(page_data), 12) # 11 blog pages + home page + def test_search_operator_default(self): + """default operator is and""" + query = """ + query($searchQuery: String) { + pages(searchQuery: $searchQuery) { + title + searchScore + } + } + """ + executed = self.client.execute(query, variables={"searchQuery": "Alpha Beta"}) + page_data = executed["data"].get("pages") + self.assertEqual(len(page_data), 2) + self.assertEqual(page_data[0]["title"], "Alpha Beta") + self.assertEqual(page_data[1]["title"], "Beta Alpha") + + def test_search_operator_and(self): + query = """ + query($searchQuery: String, $searchOperator: SearchOperatorEnum) { + pages(searchQuery: $searchQuery, searchOperator: $searchOperator) { + title + searchScore + } + } + """ + executed = self.client.execute( + query, variables={"searchQuery": "Alpha Beta", "searchOperator": "AND"} + ) + page_data = executed["data"].get("pages") + self.assertEqual(len(page_data), 2) + self.assertEqual(page_data[0]["title"], "Alpha Beta") + self.assertEqual(page_data[1]["title"], "Beta Alpha") + + def test_search_operator_or(self): + query = """ + query($searchQuery: String, $searchOperator: SearchOperatorEnum) { + pages(searchQuery: $searchQuery, searchOperator: $searchOperator) { + title + searchScore + } + } + """ + executed = self.client.execute( + query, variables={"searchQuery": "Alpha Beta", "searchOperator": "OR"} + ) + page_data = executed["data"].get("pages") + self.assertEqual(len(page_data), 10) + self.assertEqual(page_data[0]["title"], "Alpha") + self.assertEqual(page_data[1]["title"], "Alpha Alpha") + self.assertEqual(page_data[2]["title"], "Alpha Beta") + self.assertEqual(page_data[3]["title"], "Alpha Gamma") + self.assertEqual(page_data[4]["title"], "Beta") + self.assertEqual(page_data[5]["title"], "Beta Alpha") + self.assertEqual(page_data[6]["title"], "Beta Beta") + self.assertEqual(page_data[7]["title"], "Beta Gamma") + self.assertEqual(page_data[8]["title"], "Gamma Alpha") + self.assertEqual(page_data[9]["title"], "Gamma Beta") + + def test_search_fields_unset(self): + query = """ + query { + pages(searchQuery: "Sigma") { + title + } + } + """ + executed = self.client.execute(query) + page_data = executed["data"].get("pages") + self.assertEqual(len(page_data), 6) + self.assertEqual(page_data[0]["title"], "Alpha") + self.assertEqual(page_data[1]["title"], "Alpha Alpha") + self.assertEqual(page_data[2]["title"], "Alpha Beta") + self.assertEqual(page_data[3]["title"], "Alpha Gamma") + self.assertEqual(page_data[4]["title"], "Beta Alpha") + self.assertEqual(page_data[5]["title"], "Gamma Alpha") + + def test_search_fields_graphql_arg(self): + query = """ + query { + pages(searchQuery: "Sigma", searchFields: "title") { + title + } + } + """ + executed = self.client.execute(query) + page_data = executed["data"].get("pages") + self.assertEqual(len(page_data), 0) + + def test_search_fields_filter(self): + query = """ + query($searchQuery: String) { + pages(searchQuery: $searchQuery) { + title + searchScore + } + } + """ + executed = self.client.execute( + query, + variables={"searchQuery": "Sigma fields:title"}, + ) + page_data = executed["data"].get("pages") + self.assertEqual(len(page_data), 0) + class PageUrlPathTest(BaseGrappleTest): def _query_by_path(self, path, *, in_site=False): diff --git a/tests/testapp/models/core.py b/tests/testapp/models/core.py index e010333f..6b7d6041 100644 --- a/tests/testapp/models/core.py +++ b/tests/testapp/models/core.py @@ -14,6 +14,7 @@ ) from wagtail.fields import RichTextField, StreamField from wagtail.models import Orderable, Page +from wagtail.search import index from wagtail.snippets.models import register_snippet from wagtail_headless_preview.models import HeadlessPreviewMixin from wagtailmedia.edit_handlers import MediaChooserPanel @@ -163,6 +164,8 @@ def custom_property(self): "author": self.author.name if self.author else "Unknown", } + search_fields = Page.search_fields + [index.SearchField("body")] + graphql_fields = [ GraphQLString("date", required=True), GraphQLRichText("summary"),