From 1838811fd8abf59889d64ce841036b0c07b7e14d Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Wed, 25 Jun 2025 17:37:24 -0400 Subject: [PATCH 01/37] Create django_mongodb_backend.expressions package --- django_mongodb_backend/__init__.py | 2 +- django_mongodb_backend/expressions/__init__.py | 0 .../{expressions.py => expressions/builtins.py} | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 django_mongodb_backend/expressions/__init__.py rename django_mongodb_backend/{expressions.py => expressions/builtins.py} (99%) diff --git a/django_mongodb_backend/__init__.py b/django_mongodb_backend/__init__.py index 00700421a..d21566d9c 100644 --- a/django_mongodb_backend/__init__.py +++ b/django_mongodb_backend/__init__.py @@ -8,7 +8,7 @@ from .aggregates import register_aggregates # noqa: E402 from .checks import register_checks # noqa: E402 -from .expressions import register_expressions # noqa: E402 +from .expressions.builtins import register_expressions # noqa: E402 from .fields import register_fields # noqa: E402 from .functions import register_functions # noqa: E402 from .indexes import register_indexes # noqa: E402 diff --git a/django_mongodb_backend/expressions/__init__.py b/django_mongodb_backend/expressions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/django_mongodb_backend/expressions.py b/django_mongodb_backend/expressions/builtins.py similarity index 99% rename from django_mongodb_backend/expressions.py rename to django_mongodb_backend/expressions/builtins.py index 46eef56da..4f6575052 100644 --- a/django_mongodb_backend/expressions.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -25,7 +25,7 @@ ) from django.db.models.sql import Query -from .query_utils import process_lhs +from ..query_utils import process_lhs def case(self, compiler, connection): From a8b1c0339d1e9e14768bf460500d048e3f1a8d7e Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Mon, 16 Jun 2025 00:18:30 -0300 Subject: [PATCH 02/37] First 
approach. --- django_mongodb_backend/compiler.py | 99 ++++++++++++++++++++------- django_mongodb_backend/functions.py | 31 ++++++++- django_mongodb_backend/query.py | 3 + django_mongodb_backend/query_utils.py | 2 +- tests/queries_/models.py | 10 +++ 5 files changed, 118 insertions(+), 27 deletions(-) diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 1c727039d..49221a5cc 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -17,6 +17,7 @@ from django.utils.functional import cached_property from pymongo import ASCENDING, DESCENDING +from .functions import SearchScore from .query import MongoQuery, wrap_database_errors @@ -34,6 +35,8 @@ def __init__(self, *args, **kwargs): # A list of OrderBy objects for this query. self.order_by_objs = None self.subqueries = [] + # Atlas search calls + self.search_pipeline = [] def _get_group_alias_column(self, expr, annotation_group_idx): """Generate a dummy field for use in the ids fields in $group.""" @@ -57,6 +60,29 @@ def _get_column_from_expression(self, expr, alias): column_target.set_attributes_from_name(alias) return Col(self.collection_name, column_target) + def _get_replace_expr(self, sub_expr, group, alias): + column_target = sub_expr.output_field.clone() + column_target.db_column = alias + column_target.set_attributes_from_name(alias) + inner_column = Col(self.collection_name, column_target) + if getattr(sub_expr, "distinct", False): + # If the expression should return distinct values, use + # $addToSet to deduplicate. + rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True) + group[alias] = {"$addToSet": rhs} + replacing_expr = sub_expr.copy() + replacing_expr.set_source_expressions([inner_column, None]) + else: + group[alias] = sub_expr.as_mql(self, self.connection) + replacing_expr = inner_column + # Count must return 0 rather than null. 
+ if isinstance(sub_expr, Count): + replacing_expr = Coalesce(replacing_expr, 0) + # Variance = StdDev^2 + if isinstance(sub_expr, Variance): + replacing_expr = Power(replacing_expr, 2) + return replacing_expr + def _prepare_expressions_for_pipeline(self, expression, target, annotation_group_idx): """ Prepare expressions for the aggregation pipeline. @@ -80,29 +106,42 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group alias = ( f"__aggregation{next(annotation_group_idx)}" if sub_expr != expression else target ) - column_target = sub_expr.output_field.clone() - column_target.db_column = alias - column_target.set_attributes_from_name(alias) - inner_column = Col(self.collection_name, column_target) - if sub_expr.distinct: - # If the expression should return distinct values, use - # $addToSet to deduplicate. - rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True) - group[alias] = {"$addToSet": rhs} - replacing_expr = sub_expr.copy() - replacing_expr.set_source_expressions([inner_column, None]) - else: - group[alias] = sub_expr.as_mql(self, self.connection) - replacing_expr = inner_column - # Count must return 0 rather than null. 
- if isinstance(sub_expr, Count): - replacing_expr = Coalesce(replacing_expr, 0) - # Variance = StdDev^2 - if isinstance(sub_expr, Variance): - replacing_expr = Power(replacing_expr, 2) - replacements[sub_expr] = replacing_expr + replacements[sub_expr] = self._get_replace_expr(sub_expr, group, alias) return replacements, group + def _prepare_search_expressions_for_pipeline(self, expression, target, search_idx): + searches = {} + replacements = {} + for sub_expr in self._get_search_expressions(expression): + alias = f"__search_expr.search{next(search_idx)}" + replacements[sub_expr] = self._get_replace_expr(sub_expr, searches, alias) + return replacements, searches + + def _prepare_search_query_for_aggregation_pipeline(self, order_by): + replacements = {} + searches = {} + search_idx = itertools.count(start=1) + for target, expr in self.query.annotation_select.items(): + new_replacements, expr_searches = self._prepare_search_expressions_for_pipeline( + expr, target, search_idx + ) + replacements.update(new_replacements) + searches.update(expr_searches) + + for expr, _ in order_by: + new_replacements, expr_searches = self._prepare_search_expressions_for_pipeline( + expr, None, search_idx + ) + replacements.update(new_replacements) + searches.update(expr_searches) + + having_replacements, having_group = self._prepare_search_expressions_for_pipeline( + self.having, None, search_idx + ) + replacements.update(having_replacements) + searches.update(having_group) + return searches, replacements + def _prepare_annotations_for_aggregation_pipeline(self, order_by): """Prepare annotations for the aggregation pipeline.""" replacements = {} @@ -179,6 +218,9 @@ def _get_group_id_expressions(self, order_by): ids = self.get_project_fields(tuple(columns), force_expression=True) return ids, replacements + def _build_search_pipeline(self, search_queries): + pass + def _build_aggregation_pipeline(self, ids, group): """Build the aggregation pipeline for grouping.""" pipeline = [] @@ 
-209,7 +251,12 @@ def _build_aggregation_pipeline(self, ids, group): def pre_sql_setup(self, with_col_aliases=False): extra_select, order_by, group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases) - group, all_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by) + searches, search_replacements = self._prepare_search_query_for_aggregation_pipeline( + order_by + ) + group, group_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by) + all_replacements = {**search_replacements, **group_replacements} + self.search_pipeline = searches # query.group_by is either: # - None: no GROUP BY # - True: group by select fields @@ -557,10 +604,16 @@ def get_lookup_pipeline(self): return result def _get_aggregate_expressions(self, expr): + return self._get_all_expressions_of_type(expr, Aggregate) + + def _get_search_expressions(self, expr): + return self._get_all_expressions_of_type(expr, SearchScore) + + def _get_all_expressions_of_type(self, expr, target_type): stack = [expr] while stack: expr = stack.pop() - if isinstance(expr, Aggregate): + if isinstance(expr, target_type): yield expr elif hasattr(expr, "get_source_expressions"): stack.extend(expr.get_source_expressions()) diff --git a/django_mongodb_backend/functions.py b/django_mongodb_backend/functions.py index 492316709..cc904fc12 100644 --- a/django_mongodb_backend/functions.py +++ b/django_mongodb_backend/functions.py @@ -2,8 +2,8 @@ from django.conf import settings from django.db import NotSupportedError -from django.db.models import DateField, DateTimeField, TimeField -from django.db.models.expressions import Func +from django.db.models import DateField, DateTimeField, Expression, FloatField, TimeField +from django.db.models.expressions import F, Func, Value from django.db.models.functions import JSONArray from django.db.models.functions.comparison import Cast, Coalesce, Greatest, Least, NullIf from django.db.models.functions.datetime import ( @@ -38,8 +38,9 @@ 
Trim, Upper, ) +from django.utils.deconstruct import deconstructible -from .query_utils import process_lhs +from .query_utils import process_lhs, process_rhs MONGO_OPERATORS = { Ceil: "ceil", @@ -268,6 +269,30 @@ def trunc_time(self, compiler, connection): } +@deconstructible(path="django_mongodb_backend.functions.SearchScore") +class SearchScore(Expression): + def __init__(self, path, value, operation="equals", **kwargs): + self.extra_params = kwargs + self.lhs = path if hasattr(path, "resolve_expression") else F(path) + if not isinstance(value, str): + # TODO HANDLE VALUES LIKE Value("some string") + raise ValueError("STRING NEEDED") + self.rhs = Value(value) + self.operation = operation + super().__init__(output_field=FloatField()) + + def __repr__(self): + return f"search {self.field} = {self.value} | {self.extra_params}" + + def as_mql(self, compiler, connection): + lhs = process_lhs(self, compiler, connection) + rhs = process_rhs(self, compiler, connection) + return {"$search": {self.operation: {"path": lhs[:1], "query": rhs, **self.extra_params}}} + + def as_sql(self, compiler, connection): + return "", [] + + def register_functions(): Cast.as_mql = cast Concat.as_mql = concat diff --git a/django_mongodb_backend/query.py b/django_mongodb_backend/query.py index d59bc1631..e6290ead4 100644 --- a/django_mongodb_backend/query.py +++ b/django_mongodb_backend/query.py @@ -49,6 +49,7 @@ def __init__(self, compiler): self.lookup_pipeline = None self.project_fields = None self.aggregation_pipeline = compiler.aggregation_pipeline + self.search_pipeline = compiler.search_pipeline self.extra_fields = None self.combinator_pipeline = None # $lookup stage that encapsulates the pipeline for performing a nested @@ -81,6 +82,8 @@ def get_cursor(self): def get_pipeline(self): pipeline = [] + if self.search_pipeline: + pipeline.extend(self.search_pipeline) if self.lookup_pipeline: pipeline.extend(self.lookup_pipeline) for query in self.subqueries or (): diff --git 
a/django_mongodb_backend/query_utils.py b/django_mongodb_backend/query_utils.py index 0bb292995..c03a0f7ab 100644 --- a/django_mongodb_backend/query_utils.py +++ b/django_mongodb_backend/query_utils.py @@ -4,7 +4,7 @@ def is_direct_value(node): - return not hasattr(node, "as_sql") + return not hasattr(node, "as_sql") and not hasattr(node, "as_mql") def process_lhs(node, compiler, connection): diff --git a/tests/queries_/models.py b/tests/queries_/models.py index 015102248..7bc21c540 100644 --- a/tests/queries_/models.py +++ b/tests/queries_/models.py @@ -1,6 +1,7 @@ from django.db import models from django_mongodb_backend.fields import ObjectIdAutoField, ObjectIdField +from django_mongodb_backend.indexes import SearchIndex class Author(models.Model): @@ -53,3 +54,12 @@ class Meta: def __str__(self): return str(self.pk) + + +class Article(models.Model): + headline = models.CharField(max_length=100) + number = models.IntegerField() + body = models.TextField() + + class Meta: + indexes = [SearchIndex(fields=["headline"])] From e8dce305dfb0269d24077adcd9dc98f0fa764fc1 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sun, 22 Jun 2025 16:20:12 -0300 Subject: [PATCH 03/37] Add SearchExpressions --- .pre-commit-config.yaml | 2 +- django_mongodb_backend/compiler.py | 31 ++++-- django_mongodb_backend/functions.py | 165 +++++++++++++++++++++++++--- tests/queries_/test_search.py | 12 ++ 4 files changed, 180 insertions(+), 30 deletions(-) create mode 100644 tests/queries_/test_search.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7a3301328..f19280a09 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -81,4 +81,4 @@ repos: rev: "v2.2.6" hooks: - id: codespell - args: ["-L", "nin"] + args: ["-L", "nin", "SearchIn", "searchin"] diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 49221a5cc..91da88659 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -17,7 
+17,7 @@ from django.utils.functional import cached_property from pymongo import ASCENDING, DESCENDING -from .functions import SearchScore +from .functions import SearchExpression from .query import MongoQuery, wrap_database_errors @@ -115,31 +115,31 @@ def _prepare_search_expressions_for_pipeline(self, expression, target, search_id for sub_expr in self._get_search_expressions(expression): alias = f"__search_expr.search{next(search_idx)}" replacements[sub_expr] = self._get_replace_expr(sub_expr, searches, alias) - return replacements, searches + return replacements, list(searches.values()) def _prepare_search_query_for_aggregation_pipeline(self, order_by): replacements = {} - searches = {} - search_idx = itertools.count(start=1) + searches = [] + annotation_group_idx = itertools.count(start=1) for target, expr in self.query.annotation_select.items(): new_replacements, expr_searches = self._prepare_search_expressions_for_pipeline( - expr, target, search_idx + expr, target, annotation_group_idx ) replacements.update(new_replacements) - searches.update(expr_searches) + searches += expr_searches for expr, _ in order_by: new_replacements, expr_searches = self._prepare_search_expressions_for_pipeline( - expr, None, search_idx + expr, None, annotation_group_idx ) replacements.update(new_replacements) - searches.update(expr_searches) + searches += expr_searches having_replacements, having_group = self._prepare_search_expressions_for_pipeline( - self.having, None, search_idx + self.having, None, annotation_group_idx ) replacements.update(having_replacements) - searches.update(having_group) + searches += having_group return searches, replacements def _prepare_annotations_for_aggregation_pipeline(self, order_by): @@ -249,6 +249,13 @@ def _build_aggregation_pipeline(self, ids, group): pipeline.append({"$unset": "_id"}) return pipeline + def _compound_searches_queries(self, searches): + if not searches: + return [] + if len(searches) > 1: + raise ValueError("Cannot perform more 
than one search operation.") + return [searches[0], {"$addFields": {"__search_expr.search1": {"$meta": "searchScore"}}}] + def pre_sql_setup(self, with_col_aliases=False): extra_select, order_by, group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases) searches, search_replacements = self._prepare_search_query_for_aggregation_pipeline( @@ -256,7 +263,7 @@ def pre_sql_setup(self, with_col_aliases=False): ) group, group_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by) all_replacements = {**search_replacements, **group_replacements} - self.search_pipeline = searches + self.search_pipeline = self._compound_searches_queries(searches) # query.group_by is either: # - None: no GROUP BY # - True: group by select fields @@ -607,7 +614,7 @@ def _get_aggregate_expressions(self, expr): return self._get_all_expressions_of_type(expr, Aggregate) def _get_search_expressions(self, expr): - return self._get_all_expressions_of_type(expr, SearchScore) + return self._get_all_expressions_of_type(expr, SearchExpression) def _get_all_expressions_of_type(self, expr, target_type): stack = [expr] diff --git a/django_mongodb_backend/functions.py b/django_mongodb_backend/functions.py index cc904fc12..fd528aa1a 100644 --- a/django_mongodb_backend/functions.py +++ b/django_mongodb_backend/functions.py @@ -3,7 +3,7 @@ from django.conf import settings from django.db import NotSupportedError from django.db.models import DateField, DateTimeField, Expression, FloatField, TimeField -from django.db.models.expressions import F, Func, Value +from django.db.models.expressions import Func from django.db.models.functions import JSONArray from django.db.models.functions.comparison import Cast, Coalesce, Greatest, Least, NullIf from django.db.models.functions.datetime import ( @@ -38,9 +38,8 @@ Trim, Upper, ) -from django.utils.deconstruct import deconstructible -from .query_utils import process_lhs, process_rhs +from .query_utils import process_lhs MONGO_OPERATORS = { 
Ceil: "ceil", @@ -269,28 +268,160 @@ def trunc_time(self, compiler, connection): } -@deconstructible(path="django_mongodb_backend.functions.SearchScore") -class SearchScore(Expression): - def __init__(self, path, value, operation="equals", **kwargs): - self.extra_params = kwargs - self.lhs = path if hasattr(path, "resolve_expression") else F(path) - if not isinstance(value, str): - # TODO HANDLE VALUES LIKE Value("some string") - raise ValueError("STRING NEEDED") - self.rhs = Value(value) - self.operation = operation +class SearchExpression(Expression): + optional_arguments = [] + + def __init__(self, *args, score=None, **kwargs): + self.score = score + # Support positional arguments first + if args and len(args) > len(self.expected_arguments) + len(self.optional_arguments): + raise ValueError( + f"Too many positional arguments: expected {len(self.expected_arguments)}" + ) + # TODO: REFACTOR. + for arg_name, arg_type in self.expected_arguments: + if args: + value = args.pop(0) + if arg_name in kwargs: + raise ValueError( + f"Argument '{arg_name}' was provided both positionally and as keyword" + ) + elif arg_name in kwargs: + value = kwargs.pop(arg_name) + else: + raise ValueError(f"Missing required argument '{arg_name}'") + if not isinstance(value, arg_type): + raise ValueError(f"Argument '{arg_name}' must be of type {arg_type.__name__}") + setattr(self, arg_name, value) + + for arg_name, arg_type in self.optional_arguments: + if args: + if arg_name in kwargs: + raise ValueError( + f"Argument '{arg_name}' was provided both positionally and as keyword" + ) + value = args.pop(0) + elif arg_name in kwargs: + value = kwargs.pop(arg_name) + else: + value = None + if value is not None and not isinstance(value, arg_type): + raise ValueError(f"Argument '{arg_name}' must be of type {arg_type.__name__}") + setattr(self, arg_name, value) + if kwargs: + raise ValueError(f"Unexpected keyword arguments: {list(kwargs.keys())}") super().__init__(output_field=FloatField()) + def 
get_source_expressions(self): + return [] + + def __str__(self): + args = ", ".join(map(str, self.get_source_expressions())) + return f"{self.search_type}({args})" + def __repr__(self): - return f"search {self.field} = {self.value} | {self.extra_params}" + return str(self) + + def as_sql(self, compiler, connection): + return "", [] + + def _get_query_index(self, field, compiler): + for search_indexes in compiler.collection.list_search_indexes(): + mappings = search_indexes["latestDefinition"]["mappings"] + if mappings["dynamic"] or field in mappings["fields"]: + return search_indexes["name"] + return "default" def as_mql(self, compiler, connection): - lhs = process_lhs(self, compiler, connection) - rhs = process_rhs(self, compiler, connection) - return {"$search": {self.operation: {"path": lhs[:1], "query": rhs, **self.extra_params}}} + params = {} + for arg_name, _ in self.expected_arguments: + params[arg_name] = getattr(self, arg_name) + if self.score: + params["score"] = self.score.as_mql(compiler, connection) + index = self._get_query_index(params.get("path"), compiler) + return {"$search": {self.search_type: params, "index": index}} + + +class SearchAutocomplete(SearchExpression): + search_type = "autocomplete" + expected_arguments = [("path", str), ("query", str)] + + +class SearchEquals(SearchExpression): + search_type = "equals" + expected_arguments = [("path", str), ("value", str)] + + +class SearchExists(SearchExpression): + search_type = "equals" + expected_arguments = [("path", str)] + + +class SearchIn(SearchExpression): + search_type = "equals" + expected_arguments = [("path", str), ("value", str | list)] + + +class SearchPhrase(SearchExpression): + search_type = "equals" + expected_arguments = [("path", str), ("value", str | list)] + optional_arguments = [("slop", int), ("synonyms", str)] + + +""" +IT IS BEING REFACTORED +class SearchOperator(SearchExpression): + _operation_params = { + "autocomplete": ("path", {"query"}), + "equals": ("path", 
{"value"}), + "exists": ("path", {}), + "in": ("path", {"value"}), + "phrase": ("path", {"query"}), + "queryString": ("defaultPath", {"query"}), + "range": ("path", {("lt", "lte"), ("gt", "gte")}), + "regex": ("path", {"query"}), + "text": ("path", {"query"}), + "wildcard": ("path", {"query"}), + "geoShape": ("path", {"query", "relation", "geometry"}), + "geoWithin": ("path", {("box", "circle", "geometry")}), + "moreLikeThis": (None, {"like"}), + "near": ("path", {"origin", "pivot"}), + } + + def __init__(self, operation, **kwargs): + self.lhs = path if path is None or hasattr(path, "resolve_expression") else F(path) + self.operation = operation + self.lhs_field, needed_params = self._operation_params[self.operation] + rhs_values = {} + for param in needed_params: + if isinstance(param, str): + rhs_values[param] = kwargs.pop(param) + else: + for key in param: + if key in kwargs: + rhs_values[param] = kwargs.pop(key) + break + else: + raise ValueError(f"Not found either {', '.join(param)}") + + self.rhs_values = rhs_values + self.extra_params = kwargs + super().__init__(output_field=FloatField()) + + def as_mql(self, compiler, connection): + params = {**self.rhs_values, **self.extra_params} + if self.lhs: + lhs_mql = process_lhs(self, compiler, connection) + params[self.lhs_field] = lhs_mql[1:] + index = self._get_query_index(compiler, connection) + return {"$search": {self.operation: params, "index": index}} + + def get_source_expressions(self): + return [self.lhs, self.rhs, self.extra_params] def as_sql(self, compiler, connection): return "", [] +""" def register_functions(): diff --git a/tests/queries_/test_search.py b/tests/queries_/test_search.py new file mode 100644 index 000000000..ef6502958 --- /dev/null +++ b/tests/queries_/test_search.py @@ -0,0 +1,12 @@ +from django.test import TestCase + +from .models import Article + +from django_mongodb_backend.functions import SearchEquals + + +class SearchTests(TestCase): + def test_1(self): + 
Article.objects.create(headline="cross", number=1, body="body") + aa = Article.objects.annotate(score=SearchEquals(path="headline", value="cross")).all() + print(aa) From 7c40032e7dbf8a295e284041fe6b269fc8f4d125 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Tue, 24 Jun 2025 11:23:25 -0300 Subject: [PATCH 04/37] Add test --- tests/queries_/test_search.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/queries_/test_search.py b/tests/queries_/test_search.py index ef6502958..c6ec37a9c 100644 --- a/tests/queries_/test_search.py +++ b/tests/queries_/test_search.py @@ -1,12 +1,13 @@ from django.test import TestCase -from .models import Article - from django_mongodb_backend.functions import SearchEquals +from .models import Article + class SearchTests(TestCase): def test_1(self): Article.objects.create(headline="cross", number=1, body="body") aa = Article.objects.annotate(score=SearchEquals(path="headline", value="cross")).all() - print(aa) + self.assertEqual(aa.score == 1) + # print(aa) From 6be0799377e7123027a61c3a080b60f89af8ba0a Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Wed, 25 Jun 2025 23:54:53 -0300 Subject: [PATCH 05/37] Refactor. 
--- .pre-commit-config.yaml | 2 +- .../expressions/builtins.py | 288 ++++++++++++++++++ django_mongodb_backend/functions.py | 158 +--------- 3 files changed, 290 insertions(+), 158 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f19280a09..188c8f3cf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -81,4 +81,4 @@ repos: rev: "v2.2.6" hooks: - id: codespell - args: ["-L", "nin", "SearchIn", "searchin"] + args: ["-L", "nin", "-L", "searchin"] diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index 4f6575052..64ce7cb61 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -5,6 +5,7 @@ from bson import Decimal128 from django.core.exceptions import EmptyResultSet, FullResultSet from django.db import NotSupportedError +from django.db.models import Expression, FloatField from django.db.models.expressions import ( Case, Col, @@ -207,6 +208,293 @@ def value(self, compiler, connection): # noqa: ARG001 return value +class SearchExpression(Expression): + def __init__(self): + super().__init__(output_field=FloatField()) + + def get_source_expressions(self): + return [] + + def __str__(self): + args = ", ".join(map(str, self.get_source_expressions())) + return f"{self.search_type}({args})" + + def __repr__(self): + return str(self) + + def as_sql(self, compiler, connection): + return "", [] + + def _get_query_index(self, fields, compiler): + fields = set(fields) + for search_indexes in compiler.collection.list_search_indexes(): + mappings = search_indexes["latestDefinition"]["mappings"] + if mappings["dynamic"] or fields.issubset(set(mappings["fields"])): + return search_indexes["name"] + return "default" + + +class SearchAutocomplete(SearchExpression): + def __init__(self, path, query, score=None): + self.path = F(path) + self.query = Value(query) + self.score = score + super().__init__() + + def 
as_mql(self, compiler, connection): + params = { + "path": self.path.as_mql(compiler, connection)[1:], + "query": self.query.as_mql(compiler, connection), + } + if self.score is not None: + params["score"] = self.score + index = self._get_query_index([self.path], compiler) + return {"$search": {"autocomplete": params, "index": index}} + + +class SearchEquals(SearchExpression): + def __init__(self, path, value, score=None): + self.path = F(path) + self.value = Value(query) + self.score = score + super().__init__() + + def as_mql(self, compiler, connection): + params = { + "path": self.path.as_mql(compiler, connection)[1:], + "value": self.value.as_mql(compiler, connection), + } + if self.score is not None: + params["score"] = self.score + index = self._get_query_index([self.path], compiler) + return {"$search": {"equals": params, "index": index}} + + +class SearchExists(SearchExpression): + def __init__(self, path, score=None): + self.path = F(path) + self.score = score + super().__init__() + + def as_mql(self, compiler, connection): + params = { + "path": self.path.as_mql(compiler, connection)[1:], + } + if self.score is not None: + params["score"] = self.score + index = self._get_query_index([self.path], compiler) + return {"$search": {"exists": params, "index": index}} + + +class SearchIn(SearchExpression): + def __init__(self, path, value, score=None): + self.path = F(path) + self.value = Value(value) + self.score = score + super().__init__() + + def as_mql(self, compiler, connection): + params = { + "path": self.path.as_mql(compiler, connection)[1:], + "value": self.value.as_mql(compiler, connection), + } + if self.score is not None: + params["score"] = self.score + index = self._get_query_index([self.path], compiler) + return {"$search": {"in": params, "index": index}} + + +class SearchPhrase(SearchExpression): + def __init__(self, path, value, slop=None, synonyms=None, score=None): + self.path = F(path) + self.value = Value(value) + self.score = score + 
self.slop = slop + self.synonyms = synonyms + super().__init__() + + def as_mql(self, compiler, connection): + params = { + "path": self.path.as_mql(compiler, connection)[1:], + "value": self.value.as_mql(compiler, connection), + } + if self.score is not None: + params["score"] = self.score + if self.slop is not None: + params["slop"] = self.slop + if self.synonyms is not None: + params["synonyms"] = self.synonyms + index = self._get_query_index([self.path], compiler) + return {"$search": {"phrase": params, "index": index}} + + +class SearchQueryString(SearchExpression): + def __init__(self, path, query, score=None): + self.path = F(path) + self.query = Value(query) + self.score = score + super().__init__() + + def as_mql(self, compiler, connection): + params = { + "defaultPath": self.path.as_mql(compiler, connection)[1:], + "query": self.query.as_mql(compiler, connection), + } + if self.score is not None: + params["score"] = self.score + index = self._get_query_index([self.path], compiler) + return {"$search": {"queryString": params, "index": index}} + + +class SearchRange(SearchExpression): + def __init__(self, path, lt=None, lte=None, gt=None, gte=None, score=None): + self.path = F(path) + self.lt = Value(lt) + self.lte = Value(lte) + self.gt = Value(gt) + self.gte = Value(gte) + self.score = score + super().__init__() + + def as_mql(self, compiler, connection): + params = { + "path": self.path.as_mql(compiler, connection)[1:], + } + if self.score is not None: + params["score"] = self.score + if self.lt is not None: + params["lt"] = self.lt.as_mql(compiler, connection) + if self.lte is not None: + params["lte"] = self.lte.as_mql(compiler, connection) + if self.gt is not None: + params["gt"] = self.gt.as_mql(compiler, connection) + if self.gte is not None: + params["gte"] = self.gte.as_mql(compiler, connection) + index = self._get_query_index([self.path], compiler) + return {"$search": {"range": params, "index": index}} + + +class SearchRegex(SearchExpression): + 
def __init__(self, path, query, allow_analyzed_field=None, score=None): + self.path = F(path) + self.allow_analyzed_field = Value(allow_analyzed_field) + self.score = score + super().__init__() + + def as_mql(self, compiler, connection): + params = { + "path": self.path.as_mql(compiler, connection)[1:], + } + if self.score: + params["score"] = self.score + if self.allow_analyzed_field is not None: + params["allowAnalyzedField"] = self.allow_analyzed_field.as_mql(compiler, connection) + index = self._get_query_index([self.path], compiler) + return {"$search": {"regex": params, "index": index}} + + +class SearchText(SearchExpression): + def __init__(self, path, query, fuzzy=None, match_criteria=None, synonyms=None, score=None): + self.path = F(path) + self.fuzzy = Value(fuzzy) + self.match_criteria = Value(match_criteria) + self.synonyms = Value(synonyms) + self.score = score + super().__init__() + + def as_mql(self, compiler, connection): + params = { + "path": self.path.as_mql(compiler, connection)[1:], + } + if self.score: + params["score"] = self.score + if self.fuzzy is not None: + params["fuzzy"] = self.fuzzy.as_mql(compiler, connection) + if self.match_criteria is not None: + params["matchCriteria"] = self.match_criteria.as_mql(compiler, connection) + if self.synonyms is not None: + params["synonyms"] = self.synonyms.as_mql(compiler, connection) + index = self._get_query_index([self.path], compiler) + return {"$search": {"text": params, "index": index}} + + +class SearchWildcard(SearchExpression): + def __init__(self, path, query, allow_analyzed_field=None, score=None): + self.path = F(path) + self.allow_analyzed_field = Value(allow_analyzed_field) + self.score = score + super().__init__() + + def as_mql(self, compiler, connection): + params = { + "path": self.path.as_mql(compiler, connection)[1:], + } + if self.score: + params["score"] = self.score + if self.allow_analyzed_field is not None: + params["allowAnalyzedField"] = 
self.allow_analyzed_field.as_mql(compiler, connection) + index = self._get_query_index([self.path], compiler) + return {"$search": {"wildcard": params, "index": index}} + + +class SearchGeoShape(SearchExpression): + def __init__(self, path, relation, geometry, score=None): + self.path = F(path) + self.relation = relation + self.geometry = geometry + self.score = score + super().__init__() + + def as_mql(self, compiler, connection): + params = { + "path": self.path.as_mql(compiler, connection)[1:], + "relation": self.relation, + "geometry": self.geometry, + } + if self.score: + params["score"] = self.score + index = self._get_query_index([self.path], compiler) + return {"$search": {"wildcard": params, "index": index}} + + +class SearchGeoWithin(SearchExpression): + def __init__(self, path, kind, geo_object, geometry, score=None): + self.path = F(path) + self.kind = kind + self.geo_object = geo_object + self.score = score + super().__init__() + + def as_mql(self, compiler, connection): + params = { + "path": self.path.as_mql(compiler, connection)[1:], + self.kind: self.geo_object, + } + if self.score: + params["score"] = self.score + index = self._get_query_index([self.path], compiler) + return {"$search": {"wildcard": params, "index": index}} + + +class SearchMoreLikeThis(SearchExpression): + def __init__(self, documents, score=None): + self.documents = documents + self.score = score + super().__init__() + + def as_mql(self, compiler, connection): + params = { + "like": self.documents, + } + if self.score: + params["score"] = self.score + needed_fields = [] + for doc in self.documents: + needed_fields += list(doc.keys()) + index = self._get_query_index(needed_fields, compiler) + return {"$search": {"wildcard": params, "index": index}} + + def register_expressions(): Case.as_mql = case Col.as_mql = col diff --git a/django_mongodb_backend/functions.py b/django_mongodb_backend/functions.py index fd528aa1a..492316709 100644 --- a/django_mongodb_backend/functions.py +++ 
b/django_mongodb_backend/functions.py @@ -2,7 +2,7 @@ from django.conf import settings from django.db import NotSupportedError -from django.db.models import DateField, DateTimeField, Expression, FloatField, TimeField +from django.db.models import DateField, DateTimeField, TimeField from django.db.models.expressions import Func from django.db.models.functions import JSONArray from django.db.models.functions.comparison import Cast, Coalesce, Greatest, Least, NullIf @@ -268,162 +268,6 @@ def trunc_time(self, compiler, connection): } -class SearchExpression(Expression): - optional_arguments = [] - - def __init__(self, *args, score=None, **kwargs): - self.score = score - # Support positional arguments first - if args and len(args) > len(self.expected_arguments) + len(self.optional_arguments): - raise ValueError( - f"Too many positional arguments: expected {len(self.expected_arguments)}" - ) - # TODO: REFACTOR. - for arg_name, arg_type in self.expected_arguments: - if args: - value = args.pop(0) - if arg_name in kwargs: - raise ValueError( - f"Argument '{arg_name}' was provided both positionally and as keyword" - ) - elif arg_name in kwargs: - value = kwargs.pop(arg_name) - else: - raise ValueError(f"Missing required argument '{arg_name}'") - if not isinstance(value, arg_type): - raise ValueError(f"Argument '{arg_name}' must be of type {arg_type.__name__}") - setattr(self, arg_name, value) - - for arg_name, arg_type in self.optional_arguments: - if args: - if arg_name in kwargs: - raise ValueError( - f"Argument '{arg_name}' was provided both positionally and as keyword" - ) - value = args.pop(0) - elif arg_name in kwargs: - value = kwargs.pop(arg_name) - else: - value = None - if value is not None and not isinstance(value, arg_type): - raise ValueError(f"Argument '{arg_name}' must be of type {arg_type.__name__}") - setattr(self, arg_name, value) - if kwargs: - raise ValueError(f"Unexpected keyword arguments: {list(kwargs.keys())}") - 
super().__init__(output_field=FloatField()) - - def get_source_expressions(self): - return [] - - def __str__(self): - args = ", ".join(map(str, self.get_source_expressions())) - return f"{self.search_type}({args})" - - def __repr__(self): - return str(self) - - def as_sql(self, compiler, connection): - return "", [] - - def _get_query_index(self, field, compiler): - for search_indexes in compiler.collection.list_search_indexes(): - mappings = search_indexes["latestDefinition"]["mappings"] - if mappings["dynamic"] or field in mappings["fields"]: - return search_indexes["name"] - return "default" - - def as_mql(self, compiler, connection): - params = {} - for arg_name, _ in self.expected_arguments: - params[arg_name] = getattr(self, arg_name) - if self.score: - params["score"] = self.score.as_mql(compiler, connection) - index = self._get_query_index(params.get("path"), compiler) - return {"$search": {self.search_type: params, "index": index}} - - -class SearchAutocomplete(SearchExpression): - search_type = "autocomplete" - expected_arguments = [("path", str), ("query", str)] - - -class SearchEquals(SearchExpression): - search_type = "equals" - expected_arguments = [("path", str), ("value", str)] - - -class SearchExists(SearchExpression): - search_type = "equals" - expected_arguments = [("path", str)] - - -class SearchIn(SearchExpression): - search_type = "equals" - expected_arguments = [("path", str), ("value", str | list)] - - -class SearchPhrase(SearchExpression): - search_type = "equals" - expected_arguments = [("path", str), ("value", str | list)] - optional_arguments = [("slop", int), ("synonyms", str)] - - -""" -IT IS BEING REFACTORED -class SearchOperator(SearchExpression): - _operation_params = { - "autocomplete": ("path", {"query"}), - "equals": ("path", {"value"}), - "exists": ("path", {}), - "in": ("path", {"value"}), - "phrase": ("path", {"query"}), - "queryString": ("defaultPath", {"query"}), - "range": ("path", {("lt", "lte"), ("gt", "gte")}), - 
"regex": ("path", {"query"}), - "text": ("path", {"query"}), - "wildcard": ("path", {"query"}), - "geoShape": ("path", {"query", "relation", "geometry"}), - "geoWithin": ("path", {("box", "circle", "geometry")}), - "moreLikeThis": (None, {"like"}), - "near": ("path", {"origin", "pivot"}), - } - - def __init__(self, operation, **kwargs): - self.lhs = path if path is None or hasattr(path, "resolve_expression") else F(path) - self.operation = operation - self.lhs_field, needed_params = self._operation_params[self.operation] - rhs_values = {} - for param in needed_params: - if isinstance(param, str): - rhs_values[param] = kwargs.pop(param) - else: - for key in param: - if key in kwargs: - rhs_values[param] = kwargs.pop(key) - break - else: - raise ValueError(f"Not found either {', '.join(param)}") - - self.rhs_values = rhs_values - self.extra_params = kwargs - super().__init__(output_field=FloatField()) - - def as_mql(self, compiler, connection): - params = {**self.rhs_values, **self.extra_params} - if self.lhs: - lhs_mql = process_lhs(self, compiler, connection) - params[self.lhs_field] = lhs_mql[1:] - index = self._get_query_index(compiler, connection) - return {"$search": {self.operation: params, "index": index}} - - def get_source_expressions(self): - return [self.lhs, self.rhs, self.extra_params] - - def as_sql(self, compiler, connection): - return "", [] -""" - - def register_functions(): Cast.as_mql = cast Concat.as_mql = concat From 3057c017057eedb5b3d09d5eae9ad98714cc78a7 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sun, 29 Jun 2025 15:25:37 -0300 Subject: [PATCH 06/37] Add search index test --- django_mongodb_backend/compiler.py | 2 +- .../expressions/builtins.py | 143 +++++----- tests/queries_/models.py | 5 +- tests/queries_/test_search.py | 255 +++++++++++++++++- 4 files changed, 334 insertions(+), 71 deletions(-) diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 91da88659..41f278ea0 100644 --- 
a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -17,7 +17,7 @@ from django.utils.functional import cached_property from pymongo import ASCENDING, DESCENDING -from .functions import SearchExpression +from .expressions.builtins import SearchExpression from .query import MongoQuery, wrap_database_errors diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index 64ce7cb61..60dc1bd61 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -209,8 +209,7 @@ def value(self, compiler, connection): # noqa: ARG001 class SearchExpression(Expression): - def __init__(self): - super().__init__(output_field=FloatField()) + output_field = FloatField() def get_source_expressions(self): return [] @@ -235,34 +234,37 @@ def _get_query_index(self, fields, compiler): class SearchAutocomplete(SearchExpression): - def __init__(self, path, query, score=None): - self.path = F(path) - self.query = Value(query) + def __init__(self, path, query, fuzzy=None, score=None): + self.path = path + self.query = query + self.fuzzy = fuzzy self.score = score super().__init__() def as_mql(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection)[1:], - "query": self.query.as_mql(compiler, connection), + "path": self.path, + "query": self.query, } if self.score is not None: params["score"] = self.score + if self.fuzzy is not None: + params["fuzzy"] = self.fuzzy index = self._get_query_index([self.path], compiler) return {"$search": {"autocomplete": params, "index": index}} class SearchEquals(SearchExpression): def __init__(self, path, value, score=None): - self.path = F(path) - self.value = Value(query) + self.path = path + self.value = value self.score = score super().__init__() def as_mql(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection)[1:], - "value": self.value.as_mql(compiler, 
connection), + "path": self.path, + "value": self.value, } if self.score is not None: params["score"] = self.score @@ -272,13 +274,13 @@ def as_mql(self, compiler, connection): class SearchExists(SearchExpression): def __init__(self, path, score=None): - self.path = F(path) + self.path = path self.score = score super().__init__() def as_mql(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection)[1:], + "path": self.path, } if self.score is not None: params["score"] = self.score @@ -288,15 +290,15 @@ def as_mql(self, compiler, connection): class SearchIn(SearchExpression): def __init__(self, path, value, score=None): - self.path = F(path) - self.value = Value(value) + self.path = path + self.value = value self.score = score super().__init__() def as_mql(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection)[1:], - "value": self.value.as_mql(compiler, connection), + "path": self.path, + "value": self.value, } if self.score is not None: params["score"] = self.score @@ -305,9 +307,9 @@ def as_mql(self, compiler, connection): class SearchPhrase(SearchExpression): - def __init__(self, path, value, slop=None, synonyms=None, score=None): - self.path = F(path) - self.value = Value(value) + def __init__(self, path, query, slop=None, synonyms=None, score=None): + self.path = path + self.query = query self.score = score self.slop = slop self.synonyms = synonyms @@ -315,8 +317,8 @@ def __init__(self, path, value, slop=None, synonyms=None, score=None): def as_mql(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection)[1:], - "value": self.value.as_mql(compiler, connection), + "path": self.path, + "query": self.query, } if self.score is not None: params["score"] = self.score @@ -330,15 +332,15 @@ def as_mql(self, compiler, connection): class SearchQueryString(SearchExpression): def __init__(self, path, query, score=None): - self.path = F(path) - self.query = Value(query) + 
self.path = path + self.query = query self.score = score super().__init__() def as_mql(self, compiler, connection): params = { - "defaultPath": self.path.as_mql(compiler, connection)[1:], - "query": self.query.as_mql(compiler, connection), + "defaultPath": self.path, + "query": self.query, } if self.score is not None: params["score"] = self.score @@ -348,98 +350,104 @@ def as_mql(self, compiler, connection): class SearchRange(SearchExpression): def __init__(self, path, lt=None, lte=None, gt=None, gte=None, score=None): - self.path = F(path) - self.lt = Value(lt) - self.lte = Value(lte) - self.gt = Value(gt) - self.gte = Value(gte) + self.path = path + self.lt = lt + self.lte = lte + self.gt = gt + self.gte = gte self.score = score super().__init__() def as_mql(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection)[1:], + "path": self.path, } if self.score is not None: params["score"] = self.score if self.lt is not None: - params["lt"] = self.lt.as_mql(compiler, connection) + params["lt"] = self.lt if self.lte is not None: - params["lte"] = self.lte.as_mql(compiler, connection) + params["lte"] = self.lte if self.gt is not None: - params["gt"] = self.gt.as_mql(compiler, connection) + params["gt"] = self.gt if self.gte is not None: - params["gte"] = self.gte.as_mql(compiler, connection) + params["gte"] = self.gte index = self._get_query_index([self.path], compiler) return {"$search": {"range": params, "index": index}} class SearchRegex(SearchExpression): def __init__(self, path, query, allow_analyzed_field=None, score=None): - self.path = F(path) - self.allow_analyzed_field = Value(allow_analyzed_field) + self.path = path + self.query = query + self.allow_analyzed_field = allow_analyzed_field self.score = score super().__init__() def as_mql(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection)[1:], + "path": self.path, + "query": self.query, } if self.score: params["score"] = self.score if 
self.allow_analyzed_field is not None: - params["allowAnalyzedField"] = self.allow_analyzed_field.as_mql(compiler, connection) + params["allowAnalyzedField"] = self.allow_analyzed_field index = self._get_query_index([self.path], compiler) return {"$search": {"regex": params, "index": index}} class SearchText(SearchExpression): def __init__(self, path, query, fuzzy=None, match_criteria=None, synonyms=None, score=None): - self.path = F(path) - self.fuzzy = Value(fuzzy) - self.match_criteria = Value(match_criteria) - self.synonyms = Value(synonyms) + self.path = path + self.query = query + self.fuzzy = fuzzy + self.match_criteria = match_criteria + self.synonyms = synonyms self.score = score super().__init__() def as_mql(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection)[1:], + "path": self.path, + "query": self.query, } if self.score: params["score"] = self.score if self.fuzzy is not None: - params["fuzzy"] = self.fuzzy.as_mql(compiler, connection) + params["fuzzy"] = self.fuzzy if self.match_criteria is not None: - params["matchCriteria"] = self.match_criteria.as_mql(compiler, connection) + params["matchCriteria"] = self.match_criteria if self.synonyms is not None: - params["synonyms"] = self.synonyms.as_mql(compiler, connection) + params["synonyms"] = self.synonyms index = self._get_query_index([self.path], compiler) return {"$search": {"text": params, "index": index}} class SearchWildcard(SearchExpression): def __init__(self, path, query, allow_analyzed_field=None, score=None): - self.path = F(path) - self.allow_analyzed_field = Value(allow_analyzed_field) + self.path = path + self.query = query + self.allow_analyzed_field = allow_analyzed_field self.score = score super().__init__() def as_mql(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection)[1:], + "path": self.path, + "query": self.query, } if self.score: params["score"] = self.score if self.allow_analyzed_field is not None: - 
params["allowAnalyzedField"] = self.allow_analyzed_field.as_mql(compiler, connection) + params["allowAnalyzedField"] = self.allow_analyzed_field index = self._get_query_index([self.path], compiler) return {"$search": {"wildcard": params, "index": index}} class SearchGeoShape(SearchExpression): def __init__(self, path, relation, geometry, score=None): - self.path = F(path) + self.path = path self.relation = relation self.geometry = geometry self.score = score @@ -447,19 +455,19 @@ def __init__(self, path, relation, geometry, score=None): def as_mql(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection)[1:], + "path": self.path, "relation": self.relation, "geometry": self.geometry, } if self.score: params["score"] = self.score index = self._get_query_index([self.path], compiler) - return {"$search": {"wildcard": params, "index": index}} + return {"$search": {"geoShape": params, "index": index}} class SearchGeoWithin(SearchExpression): - def __init__(self, path, kind, geo_object, geometry, score=None): - self.path = F(path) + def __init__(self, path, kind, geo_object, score=None): + self.path = path self.kind = kind self.geo_object = geo_object self.score = score @@ -467,13 +475,13 @@ def __init__(self, path, kind, geo_object, geometry, score=None): def as_mql(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection)[1:], + "path": self.path, self.kind: self.geo_object, } if self.score: params["score"] = self.score index = self._get_query_index([self.path], compiler) - return {"$search": {"wildcard": params, "index": index}} + return {"$search": {"geoWithin": params, "index": index}} class SearchMoreLikeThis(SearchExpression): @@ -492,7 +500,22 @@ def as_mql(self, compiler, connection): for doc in self.documents: needed_fields += list(doc.keys()) index = self._get_query_index(needed_fields, compiler) - return {"$search": {"wildcard": params, "index": index}} + return {"$search": {"moreLikeThis": params, 
"index": index}} + + +class SearchScoreOption: + """Class to mutate scoring on a search operation""" + + def __init__(self, definitions=None): + self.definitions = definitions + + +class CombinedSearchExpression(SearchExpression): + def __init__(self, lhs, connector, rhs, output_field=None): + super().__init__(output_field=output_field) + self.connector = connector + self.lhs = lhs + self.rhs = rhs def register_expressions(): diff --git a/tests/queries_/models.py b/tests/queries_/models.py index 7bc21c540..fd70b395a 100644 --- a/tests/queries_/models.py +++ b/tests/queries_/models.py @@ -1,7 +1,6 @@ from django.db import models from django_mongodb_backend.fields import ObjectIdAutoField, ObjectIdField -from django_mongodb_backend.indexes import SearchIndex class Author(models.Model): @@ -60,6 +59,4 @@ class Article(models.Model): headline = models.CharField(max_length=100) number = models.IntegerField() body = models.TextField() - - class Meta: - indexes = [SearchIndex(fields=["headline"])] + location = models.JSONField(null=True) diff --git a/tests/queries_/test_search.py b/tests/queries_/test_search.py index c6ec37a9c..b8d3092f0 100644 --- a/tests/queries_/test_search.py +++ b/tests/queries_/test_search.py @@ -1,13 +1,256 @@ +import time + +from django.db import connection from django.test import TestCase +from pymongo.operations import SearchIndexModel -from django_mongodb_backend.functions import SearchEquals +from django_mongodb_backend.expressions.builtins import ( + SearchAutocomplete, + SearchEquals, + SearchExists, + SearchGeoShape, + SearchGeoWithin, + SearchIn, + SearchPhrase, + SearchRange, + SearchRegex, + SearchText, + SearchWildcard, +) from .models import Article -class SearchTests(TestCase): - def test_1(self): +class CreateIndexMixin: + def _get_collection(self, model): + return connection.database.get_collection(model._meta.db_table) + + def create_search_index(self, model, index_name, definition): + collection = self._get_collection(model) + idx 
= SearchIndexModel(definition=definition, name="test_index") + collection.create_search_index(idx) + + +class SearchEqualsTest(TestCase, CreateIndexMixin): + def setUp(self): + self.create_search_index( + Article, + "equals_headline_index", + {"mappings": {"dynamic": False, "fields": {"headline": {"type": "token"}}}}, + ) Article.objects.create(headline="cross", number=1, body="body") - aa = Article.objects.annotate(score=SearchEquals(path="headline", value="cross")).all() - self.assertEqual(aa.score == 1) - # print(aa) + time.sleep(1) + + def test_search_equals(self): + qs = Article.objects.annotate(score=SearchEquals(path="headline", value="cross")) + self.assertEqual(qs.first().headline, "cross") + + +class SearchAutocompleteTest(TestCase, CreateIndexMixin): + def setUp(self): + self.create_search_index( + Article, + "autocomplete_headline_index", + { + "mappings": { + "dynamic": False, + "fields": { + "headline": { + "type": "autocomplete", + "analyzer": "lucene.standard", + "tokenization": "edgeGram", + "minGrams": 3, + "maxGrams": 5, + "foldDiacritics": False, + } + }, + } + }, + ) + Article.objects.create(headline="crossing and something", number=2, body="river") + + def test_search_autocomplete(self): + qs = Article.objects.annotate(score=SearchAutocomplete(path="headline", query="crossing")) + self.assertEqual(qs.first().headline, "crossing and something") + + +class SearchExistsTest(TestCase, CreateIndexMixin): + def setUp(self): + self.create_search_index( + Article, + "exists_body_index", + {"mappings": {"dynamic": False, "fields": {"body": {"type": "token"}}}}, + ) + Article.objects.create(headline="ignored", number=3, body="something") + + def test_search_exists(self): + qs = Article.objects.annotate(score=SearchExists(path="body")) + self.assertEqual(qs.count(), 1) + self.assertEqual(qs.first().body, "something") + + +class SearchInTest(TestCase, CreateIndexMixin): + def setUp(self): + self.create_search_index( + Article, + "in_headline_index", + 
{"mappings": {"dynamic": False, "fields": {"headline": {"type": "token"}}}}, + ) + Article.objects.create(headline="cross", number=1, body="a") + Article.objects.create(headline="road", number=2, body="b") + time.sleep(1) + + def test_search_in(self): + qs = Article.objects.annotate(score=SearchIn(path="headline", value=["cross", "river"])) + self.assertEqual(qs.first().headline, "cross") + + +class SearchPhraseTest(TestCase, CreateIndexMixin): + def setUp(self): + self.create_search_index( + Article, + "phrase_body_index", + {"mappings": {"dynamic": False, "fields": {"body": {"type": "string"}}}}, + ) + Article.objects.create(headline="irrelevant", number=1, body="the quick brown fox") + time.sleep(1) + + def test_search_phrase(self): + qs = Article.objects.annotate(score=SearchPhrase(path="body", query="quick brown")) + self.assertIn("quick brown", qs.first().body) + + +class SearchRangeTest(TestCase, CreateIndexMixin): + def setUp(self): + self.create_search_index( + Article, + "range_number_index", + {"mappings": {"dynamic": False, "fields": {"number": {"type": "number"}}}}, + ) + Article.objects.create(headline="x", number=5, body="z") + Article.objects.create(headline="y", number=20, body="z") + time.sleep(1) + + def test_search_range(self): + qs = Article.objects.annotate(score=SearchRange(path="number", gte=10, lt=30)) + self.assertEqual(qs.first().number, 20) + + +class SearchRegexTest(TestCase, CreateIndexMixin): + def setUp(self): + self.create_search_index( + Article, + "regex_headline_index", + { + "mappings": { + "dynamic": False, + "fields": {"headline": {"type": "string", "analyzer": "lucene.keyword"}}, + } + }, + ) + Article.objects.create(headline="hello world", number=1, body="abc") + time.sleep(1) + + def test_search_regex(self): + qs = Article.objects.annotate( + score=SearchRegex(path="headline", query="hello.*", allow_analyzed_field=False) + ) + self.assertTrue(qs.first().headline.startswith("hello")) + + +class SearchTextTest(TestCase, 
CreateIndexMixin): + def setUp(self): + self.create_search_index( + Article, + "text_body_index", + {"mappings": {"dynamic": False, "fields": {"body": {"type": "string"}}}}, + ) + Article.objects.create(headline="ignored", number=1, body="The lazy dog sleeps") + time.sleep(1) + + def test_search_text(self): + qs = Article.objects.annotate(score=SearchText(path="body", query="lazy")) + self.assertIn("lazy", qs.first().body) + + def test_search_text_with_fuzzy_and_criteria(self): + qs = Article.objects.annotate( + score=SearchText( + path="body", query="lazzy", fuzzy={"maxEdits": 1}, match_criteria="all" + ) + ) + self.assertIn("lazy", qs.first().body) + + +class SearchWildcardTest(TestCase, CreateIndexMixin): + def setUp(self): + self.create_search_index( + Article, + "wildcard_headline_index", + { + "mappings": { + "dynamic": False, + "fields": {"headline": {"type": "string", "analyzer": "lucene.keyword"}}, + } + }, + ) + Article.objects.create(headline="dark-knight", number=1, body="") + time.sleep(1) + + def test_search_wildcard(self): + qs = Article.objects.annotate(score=SearchWildcard(path="headline", query="dark-*")) + self.assertIn("dark", qs.first().headline) + + +class SearchGeoShapeTest(TestCase, CreateIndexMixin): + def setUp(self): + self.create_search_index( + Article, + "geoshape_location_index", + { + "mappings": { + "dynamic": False, + "fields": {"location": {"type": "geo", "indexShapes": True}}, + } + }, + ) + Article.objects.create( + headline="any", number=1, body="", location={"type": "Point", "coordinates": [40, 5]} + ) + time.sleep(1) + + def test_search_geo_shape(self): + polygon = { + "type": "Polygon", + "coordinates": [[[30, 0], [50, 0], [50, 10], [30, 10], [30, 0]]], + } + qs = Article.objects.annotate( + score=SearchGeoShape(path="location", relation="within", geometry=polygon) + ) + self.assertEqual(qs.first().number, 1) + + +class SearchGeoWithinTest(TestCase, CreateIndexMixin): + def setUp(self): + self.create_search_index( + Article, 
+ "geowithin_location_index", + {"mappings": {"dynamic": False, "fields": {"location": {"type": "geo"}}}}, + ) + Article.objects.create( + headline="geo", number=2, body="", location={"type": "Point", "coordinates": [40, 5]} + ) + time.sleep(1) + + def test_search_geo_within(self): + polygon = { + "type": "Polygon", + "coordinates": [[[30, 0], [50, 0], [50, 10], [30, 10], [30, 0]]], + } + qs = Article.objects.annotate( + score=SearchGeoWithin( + path="location", + kind="geometry", + geo_object=polygon, + ) + ) + self.assertEqual(qs.first().number, 2) From 2768e248393a0b6bc5ab2a52829828c5e4eb74b1 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sun, 29 Jun 2025 22:30:30 -0300 Subject: [PATCH 07/37] Add moreLikeThis lookup. --- django_mongodb_backend/compiler.py | 27 ++++++------ .../expressions/builtins.py | 2 + tests/queries_/test_search.py | 44 ++++++++++++++++++- 3 files changed, 58 insertions(+), 15 deletions(-) diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 41f278ea0..9f79f0994 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -109,36 +109,35 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group replacements[sub_expr] = self._get_replace_expr(sub_expr, group, alias) return replacements, group - def _prepare_search_expressions_for_pipeline(self, expression, target, search_idx): + def _prepare_search_expressions_for_pipeline( + self, expression, target, search_idx, replacements + ): searches = {} - replacements = {} for sub_expr in self._get_search_expressions(expression): - alias = f"__search_expr.search{next(search_idx)}" - replacements[sub_expr] = self._get_replace_expr(sub_expr, searches, alias) - return replacements, list(searches.values()) + if sub_expr not in replacements: + alias = f"__search_expr.search{next(search_idx)}" + replacements[sub_expr] = self._get_replace_expr(sub_expr, searches, alias) + return list(searches.values()) def 
_prepare_search_query_for_aggregation_pipeline(self, order_by): replacements = {} searches = [] annotation_group_idx = itertools.count(start=1) for target, expr in self.query.annotation_select.items(): - new_replacements, expr_searches = self._prepare_search_expressions_for_pipeline( - expr, target, annotation_group_idx + expr_searches = self._prepare_search_expressions_for_pipeline( + expr, target, annotation_group_idx, replacements ) - replacements.update(new_replacements) searches += expr_searches for expr, _ in order_by: - new_replacements, expr_searches = self._prepare_search_expressions_for_pipeline( - expr, None, annotation_group_idx + expr_searches = self._prepare_search_expressions_for_pipeline( + expr, None, annotation_group_idx, replacements ) - replacements.update(new_replacements) searches += expr_searches - having_replacements, having_group = self._prepare_search_expressions_for_pipeline( - self.having, None, annotation_group_idx + having_group = self._prepare_search_expressions_for_pipeline( + self.having, None, annotation_group_idx, replacements ) - replacements.update(having_replacements) searches += having_group return searches, replacements diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index 60dc1bd61..c76cf8716 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -485,6 +485,8 @@ def as_mql(self, compiler, connection): class SearchMoreLikeThis(SearchExpression): + search_type = "more_like_this" + def __init__(self, documents, score=None): self.documents = documents self.score = score diff --git a/tests/queries_/test_search.py b/tests/queries_/test_search.py index b8d3092f0..ddb7d33a4 100644 --- a/tests/queries_/test_search.py +++ b/tests/queries_/test_search.py @@ -11,6 +11,7 @@ SearchGeoShape, SearchGeoWithin, SearchIn, + SearchMoreLikeThis, SearchPhrase, SearchRange, SearchRegex, @@ -27,7 +28,7 @@ def 
_get_collection(self, model): def create_search_index(self, model, index_name, definition): collection = self._get_collection(model) - idx = SearchIndexModel(definition=definition, name="test_index") + idx = SearchIndexModel(definition=definition, name=index_name) collection.create_search_index(idx) @@ -254,3 +255,44 @@ def test_search_geo_within(self): ) ) self.assertEqual(qs.first().number, 2) + + +class SearchMoreLikeThisTest(TestCase, CreateIndexMixin): + def setUp(self): + self.create_search_index( + Article, + "mlt_index", + { + "mappings": { + "dynamic": False, + "fields": {"body": {"type": "string"}, "headline": {"type": "string"}}, + } + }, + ) + self.article1 = Article.objects.create( + headline="Space exploration", number=1, body="Webb telescope" + ) + self.article2 = Article.objects.create( + headline="The commodities fall", + number=2, + body="Commodities dropped sharply due to inflation concerns", + ) + Article.objects.create( + headline="irrelevant", + number=3, + body="This is a completely unrelated article about cooking", + ) + time.sleep(1) + + def test_search_more_like_this(self): + like_docs = [ + {"headline": self.article1.headline, "body": self.article1.body}, + {"headline": self.article2.headline, "body": self.article2.body}, + ] + like_docs = [{"body": "NASA launches new satellite to explore the galaxy"}] + qs = Article.objects.annotate(score=SearchMoreLikeThis(documents=like_docs)).order_by( + "score" + ) + self.assertQuerySetEqual( + qs, ["space exploration", "The commodities fall"], lambda a: a.headline + ) From 228ea1feb0be4b24d397015438ee4891ea6353da Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Tue, 1 Jul 2025 00:44:40 -0300 Subject: [PATCH 08/37] CombinedSearchExpression --- .../expressions/builtins.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index c76cf8716..52c51d931 100644 --- 
a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -519,6 +519,35 @@ def __init__(self, lhs, connector, rhs, output_field=None): self.lhs = lhs self.rhs = rhs + def as_mql(self, compiler, connection): + if self.connector == self.AND: + return CompoundExpression(must=[self.lhs, self.rhs]) + if self.connector == self.NEGATION: + return CompoundExpression(must_must=[self.lhs]) + raise ValueError(":)") + + def __invert__(self): + # SHOULD BE MOVED TO THE PARENT + return self + + +class CompoundExpression(SearchExpression): + def __init__(self, must=None, must_not=None, should=None, filter=None, score=None): + self.must = must + self.must_not = must_not + self.should = should + self.filter = filter + self.score = score + + def as_mql(self, compiler, connection): + params = {} + for param in ["must", "must_not", "should", "filter"]: + clauses = getattr(self, param) + if clauses: + params[param] = [clause.as_mql(compiler, connection) for clause in clauses] + + return {"$compound": params} + def register_expressions(): Case.as_mql = case From c39ab783d333fedec463067fd3bc1a89707f1339 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Tue, 1 Jul 2025 23:54:12 -0300 Subject: [PATCH 09/37] Add vector search expression. 
--- .../expressions/builtins.py | 63 ++++++++++++------- 1 file changed, 40 insertions(+), 23 deletions(-) diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index 52c51d931..3e43b7893 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -505,6 +505,42 @@ def as_mql(self, compiler, connection): return {"$search": {"moreLikeThis": params, "index": index}} +class SearchVector(SearchExpression): + def __init__( + self, + path, + query_vector, + index, + limit, + num_candidates=None, + exact=None, + filter=None, + ): + self.path = path + self.query_vector = query_vector + self.index = index + self.limit = limit + self.num_candidates = num_candidates + self.exact = exact + self.filter = filter + super().__init__() + + def as_mql(self, compiler, connection): + params = { + "index": self.index, + "path": self.path, + "queryVector": self.query_vector, + "limit": self.limit, + } + if self.num_candidates is not None: + params["numCandidates"] = self.num_candidates + if self.exact is not None: + params["exact"] = self.exact + if self.filter is not None: + params["filter"] = self.filter + return {"$vectorSearch": params} + + class SearchScoreOption: """Class to mutate scoring on a search operation""" @@ -512,31 +548,12 @@ def __init__(self, definitions=None): self.definitions = definitions -class CombinedSearchExpression(SearchExpression): - def __init__(self, lhs, connector, rhs, output_field=None): - super().__init__(output_field=output_field) - self.connector = connector - self.lhs = lhs - self.rhs = rhs - - def as_mql(self, compiler, connection): - if self.connector == self.AND: - return CompoundExpression(must=[self.lhs, self.rhs]) - if self.connector == self.NEGATION: - return CompoundExpression(must_must=[self.lhs]) - raise ValueError(":)") - - def __invert__(self): - # SHOULD BE MOVED TO THE PARENT - return self - - class 
CompoundExpression(SearchExpression): def __init__(self, must=None, must_not=None, should=None, filter=None, score=None): - self.must = must - self.must_not = must_not - self.should = should - self.filter = filter + self.must = must or [] + self.must_not = must_not or [] + self.should = should or [] + self.filter = filter or [] self.score = score def as_mql(self, compiler, connection): From dfaefc795471e1164512313eee38e5ddb25791b3 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sat, 5 Jul 2025 17:16:18 -0300 Subject: [PATCH 10/37] Add combinable operators. --- .../expressions/builtins.py | 145 ++++++++++++++++-- 1 file changed, 132 insertions(+), 13 deletions(-) diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index 3e43b7893..27a2d6714 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -208,7 +208,56 @@ def value(self, compiler, connection): # noqa: ARG001 return value -class SearchExpression(Expression): +class Operator: + AND = "AND" + OR = "OR" + NOT = "NOT" + + def __init__(self, operator): + self.operator = operator + + def __eq__(self, other): + if isinstance(other, str): + return self.operator == other + return self.operator == other.operator + + def negate(self): + if self.operator == self.AND: + return Operator(self.OR) + if self.operator == self.OR: + return Operator(self.AND) + return Operator(self.operator) + + +class SearchCombinable: + def _combine(self, other, connector, reversed): + if not isinstance(self, CompoundExpression | CombinedSearchExpression): + lhs = CompoundExpression(must=[self]) + else: + lhs = self + if not isinstance(other, CompoundExpression | CombinedSearchExpression): + rhs = CompoundExpression(must=[other]) + else: + rhs = other + return CombinedSearchExpression(lhs, connector, rhs) + + def __invert__(self): + return CombinedSearchExpression(self, Operator(Operator.NOT), None) + + def __and__(self, 
other): + return CombinedSearchExpression(self, Operator(Operator.AND), other) + + def __rand__(self, other): + return CombinedSearchExpression(self, Operator(Operator.AND), other) + + def __or__(self, other): + return CombinedSearchExpression(self, Operator(Operator.OR), other) + + def __ror__(self, other): + return CombinedSearchExpression(self, Operator(Operator.OR), other) + + +class SearchExpression(SearchCombinable, Expression): output_field = FloatField() def get_source_expressions(self): @@ -525,6 +574,21 @@ def __init__( self.filter = filter super().__init__() + def __invert__(self): + return ValueError("SearchVector cannot be negated") + + def __and__(self, other): + raise NotSupportedError("SearchVector cannot be combined") + + def __rand__(self, other): + raise NotSupportedError("SearchVector cannot be combined") + + def __or__(self, other): + raise NotSupportedError("SearchVector cannot be combined") + + def __ror__(self, other): + raise NotSupportedError("SearchVector cannot be combined") + def as_mql(self, compiler, connection): params = { "index": self.index, @@ -541,15 +605,16 @@ def as_mql(self, compiler, connection): return {"$vectorSearch": params} -class SearchScoreOption: - """Class to mutate scoring on a search operation""" - - def __init__(self, definitions=None): - self.definitions = definitions - - class CompoundExpression(SearchExpression): - def __init__(self, must=None, must_not=None, should=None, filter=None, score=None): + def __init__( + self, + must=None, + must_not=None, + should=None, + filter=None, + score=None, + minimum_should_match=None, + ): self.must = must or [] self.must_not = must_not or [] self.should = should or [] @@ -558,13 +623,67 @@ def __init__(self, must=None, must_not=None, should=None, filter=None, score=Non def as_mql(self, compiler, connection): params = {} - for param in ["must", "must_not", "should", "filter"]: - clauses = getattr(self, param) - if clauses: - params[param] = [clause.as_mql(compiler, 
connection) for clause in clauses] + if self.must: + params["must"] = [clause.as_mql(compiler, connection) for clause in self.must] + if self.must_not: + params["mustNot"] = [clause.as_mql(compiler, connection) for clause in self.must_not] + if self.should: + params["should"] = [clause.as_mql(compiler, connection) for clause in self.should] + if self.filter: + params["filter"] = [clause.as_mql(compiler, connection) for clause in self.filter] + if self.minimum_should_match is not None: + params["minimumShouldMatch"] = self.minimum_should_match return {"$compound": params} + def negate(self): + return CompoundExpression(must=self.must_not, must_not=self.must + self.filter) + + +class CombinedSearchExpression(SearchExpression): + def __init__(self, lhs, operator, rhs): + self.lhs = lhs + self.operator = operator + self.rhs = rhs + + @staticmethod + def _flatten(node, negated=False): + if node is None: + return None + # Leaf, resolve the compoundExpression + if isinstance(node, CompoundExpression): + return node.negate() if negated else node + # Apply De Morgan's Laws. + operator = node.operator.negate() if negated else node.operator + negated = negated != (node.operator == Operator.NOT) + lhs_compound = node._flatten(node.lhs, negated) + rhs_compound = node._flatten(node.rhs, negated) + if operator == Operator.OR: + return CompoundExpression(should=[lhs_compound, rhs_compound], minimum_should_match=1) + if node.operator == Operator.AND: + return CompoundExpression( + must=lhs_compound.must + rhs_compound.must, + must_not=lhs_compound.must_not + rhs_compound.must_not, + should=lhs_compound.should + rhs_compound.should, + filter=lhs_compound.filter + rhs_compound.filter, + ) + # it also can be written as: + # this way is more consistent with OR, but the above is shorter in the debug query. 
+ # return CompoundExpression(must=[lhs_compound, rhs_compound]) + # not operator + return lhs_compound + + def as_mql(self, compiler, connection): + expression = self._flatten(self) + return expression.as_mql(compiler, connection) + + +class SearchScoreOption: + """Class to mutate scoring on a search operation""" + + def __init__(self, definitions=None): + self.definitions = definitions + def register_expressions(): Case.as_mql = case From 9249e059a7a469e742531bcdb9a6bc90268d3371 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sun, 6 Jul 2025 11:47:27 -0300 Subject: [PATCH 11/37] Edits --- .../expressions/builtins.py | 173 ++++++++++++------ tests/queries_/test_search.py | 75 +++++++- 2 files changed, 188 insertions(+), 60 deletions(-) diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index 27a2d6714..d110ab8a6 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -228,33 +228,36 @@ def negate(self): return Operator(self.AND) return Operator(self.operator) + def __hash__(self): + return hash(self.operator) + class SearchCombinable: - def _combine(self, other, connector, reversed): + def _combine(self, other, connector): if not isinstance(self, CompoundExpression | CombinedSearchExpression): lhs = CompoundExpression(must=[self]) else: lhs = self - if not isinstance(other, CompoundExpression | CombinedSearchExpression): + if other and not isinstance(other, CompoundExpression | CombinedSearchExpression): rhs = CompoundExpression(must=[other]) else: rhs = other return CombinedSearchExpression(lhs, connector, rhs) def __invert__(self): - return CombinedSearchExpression(self, Operator(Operator.NOT), None) + return self._combine(None, Operator(Operator.NOT)) def __and__(self, other): - return CombinedSearchExpression(self, Operator(Operator.AND), other) + return self._combine(other, Operator(Operator.AND)) def __rand__(self, other): - return 
CombinedSearchExpression(self, Operator(Operator.AND), other) + return self._combine(other, Operator(Operator.AND)) def __or__(self, other): - return CombinedSearchExpression(self, Operator(Operator.OR), other) + return self._combine(other, Operator(Operator.OR)) def __ror__(self, other): - return CombinedSearchExpression(self, Operator(Operator.OR), other) + return self._combine(self, Operator(Operator.OR), other) class SearchExpression(SearchCombinable, Expression): @@ -281,6 +284,13 @@ def _get_query_index(self, fields, compiler): return search_indexes["name"] return "default" + def search_operator(self, compiler, connection): + raise NotImplementedError + + def as_mql(self, compiler, connection): + index = self._get_query_index(self.get_search_fields(), compiler) + return {"$search": {**self.search_operator(compiler, connection), "index": index}} + class SearchAutocomplete(SearchExpression): def __init__(self, path, query, fuzzy=None, score=None): @@ -290,7 +300,10 @@ def __init__(self, path, query, fuzzy=None, score=None): self.score = score super().__init__() - def as_mql(self, compiler, connection): + def get_search_fields(self): + return {self.path} + + def search_operator(self, compiler, connection): params = { "path": self.path, "query": self.query, @@ -299,26 +312,29 @@ def as_mql(self, compiler, connection): params["score"] = self.score if self.fuzzy is not None: params["fuzzy"] = self.fuzzy - index = self._get_query_index([self.path], compiler) - return {"$search": {"autocomplete": params, "index": index}} + return {"autocomplete": params} class SearchEquals(SearchExpression): + search_type = "equals" + def __init__(self, path, value, score=None): self.path = path self.value = value self.score = score super().__init__() - def as_mql(self, compiler, connection): + def get_search_fields(self): + return {self.path} + + def search_operator(self, compiler, connection): params = { "path": self.path, "value": self.value, } if self.score is not None: 
params["score"] = self.score - index = self._get_query_index([self.path], compiler) - return {"$search": {"equals": params, "index": index}} + return {"equals": params} class SearchExists(SearchExpression): @@ -327,14 +343,16 @@ def __init__(self, path, score=None): self.score = score super().__init__() - def as_mql(self, compiler, connection): + def get_search_fields(self): + return {self.path} + + def search_operator(self, compiler, connection): params = { "path": self.path, } if self.score is not None: params["score"] = self.score - index = self._get_query_index([self.path], compiler) - return {"$search": {"exists": params, "index": index}} + return {"exists": params} class SearchIn(SearchExpression): @@ -344,15 +362,17 @@ def __init__(self, path, value, score=None): self.score = score super().__init__() - def as_mql(self, compiler, connection): + def get_search_fields(self): + return {self.path} + + def search_operator(self, compiler, connection): params = { "path": self.path, "value": self.value, } if self.score is not None: params["score"] = self.score - index = self._get_query_index([self.path], compiler) - return {"$search": {"in": params, "index": index}} + return {"in": params} class SearchPhrase(SearchExpression): @@ -364,7 +384,10 @@ def __init__(self, path, query, slop=None, synonyms=None, score=None): self.synonyms = synonyms super().__init__() - def as_mql(self, compiler, connection): + def get_search_fields(self): + return {self.path} + + def search_operator(self, compiler, connection): params = { "path": self.path, "query": self.query, @@ -375,8 +398,7 @@ def as_mql(self, compiler, connection): params["slop"] = self.slop if self.synonyms is not None: params["synonyms"] = self.synonyms - index = self._get_query_index([self.path], compiler) - return {"$search": {"phrase": params, "index": index}} + return {"phrase": params} class SearchQueryString(SearchExpression): @@ -386,15 +408,17 @@ def __init__(self, path, query, score=None): self.score = score 
super().__init__() - def as_mql(self, compiler, connection): + def get_search_fields(self): + return {self.path} + + def search_operator(self, compiler, connection): params = { "defaultPath": self.path, "query": self.query, } if self.score is not None: params["score"] = self.score - index = self._get_query_index([self.path], compiler) - return {"$search": {"queryString": params, "index": index}} + return {"queryString": params} class SearchRange(SearchExpression): @@ -407,7 +431,10 @@ def __init__(self, path, lt=None, lte=None, gt=None, gte=None, score=None): self.score = score super().__init__() - def as_mql(self, compiler, connection): + def get_search_fields(self): + return {self.path} + + def search_operator(self, compiler, connection): params = { "path": self.path, } @@ -421,8 +448,7 @@ def as_mql(self, compiler, connection): params["gt"] = self.gt if self.gte is not None: params["gte"] = self.gte - index = self._get_query_index([self.path], compiler) - return {"$search": {"range": params, "index": index}} + return {"range": params} class SearchRegex(SearchExpression): @@ -433,7 +459,10 @@ def __init__(self, path, query, allow_analyzed_field=None, score=None): self.score = score super().__init__() - def as_mql(self, compiler, connection): + def get_search_fields(self): + return {self.path} + + def search_operator(self, compiler, connection): params = { "path": self.path, "query": self.query, @@ -442,8 +471,7 @@ def as_mql(self, compiler, connection): params["score"] = self.score if self.allow_analyzed_field is not None: params["allowAnalyzedField"] = self.allow_analyzed_field - index = self._get_query_index([self.path], compiler) - return {"$search": {"regex": params, "index": index}} + return {"regex": params} class SearchText(SearchExpression): @@ -456,7 +484,10 @@ def __init__(self, path, query, fuzzy=None, match_criteria=None, synonyms=None, self.score = score super().__init__() - def as_mql(self, compiler, connection): + def get_search_fields(self): + 
return {self.path} + + def search_operator(self, compiler, connection): params = { "path": self.path, "query": self.query, @@ -469,8 +500,7 @@ def as_mql(self, compiler, connection): params["matchCriteria"] = self.match_criteria if self.synonyms is not None: params["synonyms"] = self.synonyms - index = self._get_query_index([self.path], compiler) - return {"$search": {"text": params, "index": index}} + return {"text": params} class SearchWildcard(SearchExpression): @@ -481,7 +511,10 @@ def __init__(self, path, query, allow_analyzed_field=None, score=None): self.score = score super().__init__() - def as_mql(self, compiler, connection): + def get_search_fields(self): + return {self.path} + + def search_operator(self, compiler, connection): params = { "path": self.path, "query": self.query, @@ -490,8 +523,7 @@ def as_mql(self, compiler, connection): params["score"] = self.score if self.allow_analyzed_field is not None: params["allowAnalyzedField"] = self.allow_analyzed_field - index = self._get_query_index([self.path], compiler) - return {"$search": {"wildcard": params, "index": index}} + return {"wildcard": params} class SearchGeoShape(SearchExpression): @@ -502,7 +534,10 @@ def __init__(self, path, relation, geometry, score=None): self.score = score super().__init__() - def as_mql(self, compiler, connection): + def get_search_fields(self): + return {self.path} + + def search_operator(self, compiler, connection): params = { "path": self.path, "relation": self.relation, @@ -510,8 +545,7 @@ def as_mql(self, compiler, connection): } if self.score: params["score"] = self.score - index = self._get_query_index([self.path], compiler) - return {"$search": {"geoShape": params, "index": index}} + return {"geoShape": params} class SearchGeoWithin(SearchExpression): @@ -522,15 +556,17 @@ def __init__(self, path, kind, geo_object, score=None): self.score = score super().__init__() - def as_mql(self, compiler, connection): + def search_operator(self, compiler, connection): params 
= { "path": self.path, self.kind: self.geo_object, } if self.score: params["score"] = self.score - index = self._get_query_index([self.path], compiler) - return {"$search": {"geoWithin": params, "index": index}} + return {"geoWithin": params} + + def get_search_fields(self): + return {self.path} class SearchMoreLikeThis(SearchExpression): @@ -541,17 +577,19 @@ def __init__(self, documents, score=None): self.score = score super().__init__() - def as_mql(self, compiler, connection): + def search_operator(self, compiler, connection): params = { "like": self.documents, } if self.score: params["score"] = self.score - needed_fields = [] + return {"moreLikeThis": params} + + def get_search_fields(self): + needed_fields = set() for doc in self.documents: - needed_fields += list(doc.keys()) - index = self._get_query_index(needed_fields, compiler) - return {"$search": {"moreLikeThis": params, "index": index}} + needed_fields.update(set(doc.keys())) + return needed_fields class SearchVector(SearchExpression): @@ -559,7 +597,6 @@ def __init__( self, path, query_vector, - index, limit, num_candidates=None, exact=None, @@ -567,7 +604,6 @@ def __init__( ): self.path = path self.query_vector = query_vector - self.index = index self.limit = limit self.num_candidates = num_candidates self.exact = exact @@ -589,9 +625,15 @@ def __or__(self, other): def __ror__(self, other): raise NotSupportedError("SearchVector cannot be combined") + def get_search_fields(self): + return {self.path} + + def _get_query_index(self, field, compiler): + return "default" + def as_mql(self, compiler, connection): params = { - "index": self.index, + "index": self._get_query_index(self.get_search_fields()), "path": self.path, "queryVector": self.query_vector, "limit": self.limit, @@ -606,6 +648,8 @@ def as_mql(self, compiler, connection): class CompoundExpression(SearchExpression): + search_type = "compound" + def __init__( self, must=None, @@ -620,27 +664,42 @@ def __init__( self.should = should or [] 
self.filter = filter or [] self.score = score + self.minimum_should_match = minimum_should_match - def as_mql(self, compiler, connection): + def get_search_fields(self): + fields = set() + for clause in self.must + self.should + self.filter + self.must_not: + fields.update(clause.get_search_fields()) + return fields + + def search_operator(self, compiler, connection): params = {} if self.must: - params["must"] = [clause.as_mql(compiler, connection) for clause in self.must] + params["must"] = [clause.search_operator(compiler, connection) for clause in self.must] if self.must_not: - params["mustNot"] = [clause.as_mql(compiler, connection) for clause in self.must_not] + params["mustNot"] = [ + clause.search_operator(compiler, connection) for clause in self.must_not + ] if self.should: - params["should"] = [clause.as_mql(compiler, connection) for clause in self.should] + params["should"] = [ + clause.search_operator(compiler, connection) for clause in self.should + ] if self.filter: - params["filter"] = [clause.as_mql(compiler, connection) for clause in self.filter] + params["filter"] = [ + clause.search_operator(compiler, connection) for clause in self.filter + ] if self.minimum_should_match is not None: params["minimumShouldMatch"] = self.minimum_should_match - return {"$compound": params} + return {"compound": params} def negate(self): return CompoundExpression(must=self.must_not, must_not=self.must + self.filter) class CombinedSearchExpression(SearchExpression): + search_type = "combined" + def __init__(self, lhs, operator, rhs): self.lhs = lhs self.operator = operator diff --git a/tests/queries_/test_search.py b/tests/queries_/test_search.py index ddb7d33a4..71e636b92 100644 --- a/tests/queries_/test_search.py +++ b/tests/queries_/test_search.py @@ -5,6 +5,7 @@ from pymongo.operations import SearchIndexModel from django_mongodb_backend.expressions.builtins import ( + CompoundExpression, SearchAutocomplete, SearchEquals, SearchExists, @@ -23,11 +24,13 @@ class 
CreateIndexMixin: - def _get_collection(self, model): + @staticmethod + def _get_collection(model): return connection.database.get_collection(model._meta.db_table) - def create_search_index(self, model, index_name, definition): - collection = self._get_collection(model) + @staticmethod + def create_search_index(model, index_name, definition): + collection = CreateIndexMixin._get_collection(model) idx = SearchIndexModel(definition=definition, name=index_name) collection.create_search_index(idx) @@ -296,3 +299,69 @@ def test_search_more_like_this(self): self.assertQuerySetEqual( qs, ["space exploration", "The commodities fall"], lambda a: a.headline ) + + +class CompoundSearchTest(TestCase, CreateIndexMixin): + @classmethod + def setUpTestData(cls): + cls.create_search_index( + Article, + "compound_index", + { + "mappings": { + "dynamic": False, + "fields": { + "headline": {"type": "token"}, + "body": {"type": "string"}, + "number": {"type": "number"}, + }, + } + }, + ) + cls.mars_mission = Article.objects.create( + number=1, + headline="space exploration", + body="NASA launches a new mission to Mars, aiming to study surface geology", + ) + + cls.exoplanet = Article.objects.create( + number=2, + headline="space exploration", + body="Astronomers discover exoplanets orbiting distant stars using Webb telescope", + ) + + cls.icy_moons = Article.objects.create( + number=3, + headline="space exploration", + body="ESA prepares a robotic expedition to explore the icy moons of Jupiter", + ) + + cls.comodities_drop = Article.objects.create( + number=4, + headline="astronomy news", + body="Commodities dropped sharply due to inflation concerns", + ) + + time.sleep(1) + + def test_compound_expression(self): + must_expr = SearchEquals(path="headline", value="space exploration") + must_not_expr = SearchPhrase(path="body", query="icy moons") + should_expr = SearchPhrase(path="body", query="exoplanets") + + compound = CompoundExpression( + must=[must_expr or should_expr], + 
must_not=[must_not_expr], + should=[should_expr], + minimum_should_match=1, + ) + + qs = Article.objects.annotate(score=compound).order_by("score") + self.assertCountEqual(qs, [self.exoplanet]) + + def test_compound_operations(self): + expr = SearchEquals(path="headline", value="space exploration") & ~SearchEquals( + path="number", value=3 + ) + qs = Article.objects.annotate(score=expr) + self.assertCountEqual(qs, [self.mars_mission, self.exoplanet]) From c1e949344455bea4e0b905b4c47267c9d5d1e46e Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sun, 6 Jul 2025 12:16:57 -0300 Subject: [PATCH 12/37] Add __str__ method --- django_mongodb_backend/expressions/builtins.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index d110ab8a6..193d996c3 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -263,12 +263,11 @@ def __ror__(self, other): class SearchExpression(SearchCombinable, Expression): output_field = FloatField() - def get_source_expressions(self): - return [] - def __str__(self): - args = ", ".join(map(str, self.get_source_expressions())) - return f"{self.search_type}({args})" + cls = self.identity[0] + kwargs = dict(self.identity[1:]) + arg_str = ", ".join(f"{k}={v!r}" for k, v in kwargs.items()) + return f"{cls.__name__}({arg_str})" def __repr__(self): return str(self) @@ -316,8 +315,6 @@ def search_operator(self, compiler, connection): class SearchEquals(SearchExpression): - search_type = "equals" - def __init__(self, path, value, score=None): self.path = path self.value = value @@ -570,8 +567,6 @@ def get_search_fields(self): class SearchMoreLikeThis(SearchExpression): - search_type = "more_like_this" - def __init__(self, documents, score=None): self.documents = documents self.score = score @@ -648,8 +643,6 @@ def as_mql(self, compiler, connection): class 
CompoundExpression(SearchExpression): - search_type = "compound" - def __init__( self, must=None, @@ -698,8 +691,6 @@ def negate(self): class CombinedSearchExpression(SearchExpression): - search_type = "combined" - def __init__(self, lhs, operator, rhs): self.lhs = lhs self.operator = operator From a32a8dedb6916123de077198440a545177d78c4c Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sun, 6 Jul 2025 17:55:30 -0300 Subject: [PATCH 13/37] Add combined expressions test. --- .../expressions/builtins.py | 81 +++++++++---------- 1 file changed, 40 insertions(+), 41 deletions(-) diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index 193d996c3..1da7ea198 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -231,6 +231,12 @@ def negate(self): def __hash__(self): return hash(self.operator) + def __str__(self): + return self.operator + + def __repr__(self): + return self.operator + class SearchCombinable: def _combine(self, other, connector): @@ -283,12 +289,12 @@ def _get_query_index(self, fields, compiler): return search_indexes["name"] return "default" - def search_operator(self, compiler, connection): + def search_operator(self): raise NotImplementedError def as_mql(self, compiler, connection): index = self._get_query_index(self.get_search_fields(), compiler) - return {"$search": {**self.search_operator(compiler, connection), "index": index}} + return {"$search": {**self.search_operator(), "index": index}} class SearchAutocomplete(SearchExpression): @@ -302,7 +308,7 @@ def __init__(self, path, query, fuzzy=None, score=None): def get_search_fields(self): return {self.path} - def search_operator(self, compiler, connection): + def search_operator(self): params = { "path": self.path, "query": self.query, @@ -324,7 +330,7 @@ def __init__(self, path, value, score=None): def get_search_fields(self): return {self.path} - def search_operator(self, 
compiler, connection): + def search_operator(self): params = { "path": self.path, "value": self.value, @@ -343,7 +349,7 @@ def __init__(self, path, score=None): def get_search_fields(self): return {self.path} - def search_operator(self, compiler, connection): + def search_operator(self): params = { "path": self.path, } @@ -362,7 +368,7 @@ def __init__(self, path, value, score=None): def get_search_fields(self): return {self.path} - def search_operator(self, compiler, connection): + def search_operator(self): params = { "path": self.path, "value": self.value, @@ -384,7 +390,7 @@ def __init__(self, path, query, slop=None, synonyms=None, score=None): def get_search_fields(self): return {self.path} - def search_operator(self, compiler, connection): + def search_operator(self): params = { "path": self.path, "query": self.query, @@ -408,7 +414,7 @@ def __init__(self, path, query, score=None): def get_search_fields(self): return {self.path} - def search_operator(self, compiler, connection): + def search_operator(self): params = { "defaultPath": self.path, "query": self.query, @@ -431,7 +437,7 @@ def __init__(self, path, lt=None, lte=None, gt=None, gte=None, score=None): def get_search_fields(self): return {self.path} - def search_operator(self, compiler, connection): + def search_operator(self): params = { "path": self.path, } @@ -459,7 +465,7 @@ def __init__(self, path, query, allow_analyzed_field=None, score=None): def get_search_fields(self): return {self.path} - def search_operator(self, compiler, connection): + def search_operator(self): params = { "path": self.path, "query": self.query, @@ -484,7 +490,7 @@ def __init__(self, path, query, fuzzy=None, match_criteria=None, synonyms=None, def get_search_fields(self): return {self.path} - def search_operator(self, compiler, connection): + def search_operator(self): params = { "path": self.path, "query": self.query, @@ -511,7 +517,7 @@ def __init__(self, path, query, allow_analyzed_field=None, score=None): def 
get_search_fields(self): return {self.path} - def search_operator(self, compiler, connection): + def search_operator(self): params = { "path": self.path, "query": self.query, @@ -534,7 +540,7 @@ def __init__(self, path, relation, geometry, score=None): def get_search_fields(self): return {self.path} - def search_operator(self, compiler, connection): + def search_operator(self): params = { "path": self.path, "relation": self.relation, @@ -553,7 +559,7 @@ def __init__(self, path, kind, geo_object, score=None): self.score = score super().__init__() - def search_operator(self, compiler, connection): + def search_operator(self): params = { "path": self.path, self.kind: self.geo_object, @@ -572,7 +578,7 @@ def __init__(self, documents, score=None): self.score = score super().__init__() - def search_operator(self, compiler, connection): + def search_operator(self): params = { "like": self.documents, } @@ -665,29 +671,23 @@ def get_search_fields(self): fields.update(clause.get_search_fields()) return fields - def search_operator(self, compiler, connection): + def search_operator(self): params = {} if self.must: - params["must"] = [clause.search_operator(compiler, connection) for clause in self.must] + params["must"] = [clause.search_operator() for clause in self.must] if self.must_not: - params["mustNot"] = [ - clause.search_operator(compiler, connection) for clause in self.must_not - ] + params["mustNot"] = [clause.search_operator() for clause in self.must_not] if self.should: - params["should"] = [ - clause.search_operator(compiler, connection) for clause in self.should - ] + params["should"] = [clause.search_operator() for clause in self.should] if self.filter: - params["filter"] = [ - clause.search_operator(compiler, connection) for clause in self.filter - ] + params["filter"] = [clause.search_operator() for clause in self.filter] if self.minimum_should_match is not None: params["minimumShouldMatch"] = self.minimum_should_match return {"compound": params} def 
negate(self): - return CompoundExpression(must=self.must_not, must_not=self.must + self.filter) + return CompoundExpression(must_not=[self]) class CombinedSearchExpression(SearchExpression): @@ -697,7 +697,7 @@ def __init__(self, lhs, operator, rhs): self.rhs = rhs @staticmethod - def _flatten(node, negated=False): + def resolve(node, negated=False): if node is None: return None # Leaf, resolve the compoundExpression @@ -706,25 +706,24 @@ def _flatten(node, negated=False): # Apply De Morgan's Laws. operator = node.operator.negate() if negated else node.operator negated = negated != (node.operator == Operator.NOT) - lhs_compound = node._flatten(node.lhs, negated) - rhs_compound = node._flatten(node.rhs, negated) + lhs_compound = node.resolve(node.lhs, negated) + rhs_compound = node.resolve(node.rhs, negated) if operator == Operator.OR: return CompoundExpression(should=[lhs_compound, rhs_compound], minimum_should_match=1) - if node.operator == Operator.AND: - return CompoundExpression( - must=lhs_compound.must + rhs_compound.must, - must_not=lhs_compound.must_not + rhs_compound.must_not, - should=lhs_compound.should + rhs_compound.should, - filter=lhs_compound.filter + rhs_compound.filter, - ) - # it also can be written as: - # this way is more consistent with OR, but the above is shorter in the debug query. 
- # return CompoundExpression(must=[lhs_compound, rhs_compound]) + if operator == Operator.AND: + # NOTE: we can't just do the code below, think about this case (A | B) & (C | D) + # return CompoundExpression( + # must=lhs_compound.must + rhs_compound.must, + # must_not=lhs_compound.must_not + rhs_compound.must_not, + # should=lhs_compound.should + rhs_compound.should, + # filter=lhs_compound.filter + rhs_compound.filter, + # ) + return CompoundExpression(must=[lhs_compound, rhs_compound]) # not operator return lhs_compound def as_mql(self, compiler, connection): - expression = self._flatten(self) + expression = self.resolve(self) return expression.as_mql(compiler, connection) From 4845480ea36aa295fe756858df6d9f6b1f8d62b0 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sun, 6 Jul 2025 22:32:13 -0300 Subject: [PATCH 14/37] Add vector search test. --- django_mongodb_backend/compiler.py | 7 +- .../expressions/builtins.py | 125 +++++++++--------- tests/queries_/models.py | 3 +- tests/queries_/test_search.py | 51 ++++++- 4 files changed, 117 insertions(+), 69 deletions(-) diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 9f79f0994..71400c3d2 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -248,12 +248,13 @@ def _build_aggregation_pipeline(self, ids, group): pipeline.append({"$unset": "_id"}) return pipeline - def _compound_searches_queries(self, searches): + def _compound_searches_queries(self, searches, search_replacements): if not searches: return [] if len(searches) > 1: raise ValueError("Cannot perform more than one search operation.") - return [searches[0], {"$addFields": {"__search_expr.search1": {"$meta": "searchScore"}}}] + score_function = "searchScore" if "$search" in searches[0] else "vectorSearchScore" + return [searches[0], {"$addFields": {"__search_expr.search1": {"$meta": score_function}}}] def pre_sql_setup(self, with_col_aliases=False): extra_select, order_by, 
group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases) @@ -262,7 +263,7 @@ def pre_sql_setup(self, with_col_aliases=False): ) group, group_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by) all_replacements = {**search_replacements, **group_replacements} - self.search_pipeline = self._compound_searches_queries(searches) + self.search_pipeline = self._compound_searches_queries(searches, search_replacements) # query.group_by is either: # - None: no GROUP BY # - True: group by select fields diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index 1da7ea198..7b019e57b 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -593,61 +593,6 @@ def get_search_fields(self): return needed_fields -class SearchVector(SearchExpression): - def __init__( - self, - path, - query_vector, - limit, - num_candidates=None, - exact=None, - filter=None, - ): - self.path = path - self.query_vector = query_vector - self.limit = limit - self.num_candidates = num_candidates - self.exact = exact - self.filter = filter - super().__init__() - - def __invert__(self): - return ValueError("SearchVector cannot be negated") - - def __and__(self, other): - raise NotSupportedError("SearchVector cannot be combined") - - def __rand__(self, other): - raise NotSupportedError("SearchVector cannot be combined") - - def __or__(self, other): - raise NotSupportedError("SearchVector cannot be combined") - - def __ror__(self, other): - raise NotSupportedError("SearchVector cannot be combined") - - def get_search_fields(self): - return {self.path} - - def _get_query_index(self, field, compiler): - return "default" - - def as_mql(self, compiler, connection): - params = { - "index": self._get_query_index(self.get_search_fields()), - "path": self.path, - "queryVector": self.query_vector, - "limit": self.limit, - } - if self.num_candidates is not None: - 
params["numCandidates"] = self.num_candidates - if self.exact is not None: - params["exact"] = self.exact - if self.filter is not None: - params["filter"] = self.filter - return {"$vectorSearch": params} - - class CompoundExpression(SearchExpression): def __init__( self, @@ -711,15 +656,7 @@ def resolve(node, negated=False): if operator == Operator.OR: return CompoundExpression(should=[lhs_compound, rhs_compound], minimum_should_match=1) if operator == Operator.AND: - # NOTE: we can't just do the code below, think about this case (A | B) & (C | D) - # return CompoundExpression( - # must=lhs_compound.must + rhs_compound.must, - # must_not=lhs_compound.must_not + rhs_compound.must_not, - # should=lhs_compound.should + rhs_compound.should, - # filter=lhs_compound.filter + rhs_compound.filter, - # ) return CompoundExpression(must=[lhs_compound, rhs_compound]) - # not operator return lhs_compound def as_mql(self, compiler, connection): @@ -727,6 +664,68 @@ def as_mql(self, compiler, connection): return expression.as_mql(compiler, connection) +class SearchVector(SearchExpression): + def __init__( + self, + path, + query_vector, + limit, + num_candidates=None, + exact=None, + filter=None, + ): + self.path = path + self.query_vector = query_vector + self.limit = limit + self.num_candidates = num_candidates + self.exact = exact + self.filter = filter + super().__init__() + + def __invert__(self): + return ValueError("SearchVector cannot be negated") + + def __and__(self, other): + raise NotSupportedError("SearchVector cannot be combined") + + def __rand__(self, other): + raise NotSupportedError("SearchVector cannot be combined") + + def __or__(self, other): + raise NotSupportedError("SearchVector cannot be combined") + + def __ror__(self, other): + raise NotSupportedError("SearchVector cannot be combined") + + def get_search_fields(self): + return {self.path} + + def _get_query_index(self, fields, compiler): + for search_indexes in compiler.collection.list_search_indexes(): 
+ if search_indexes["type"] == "vectorSearch": + index_field = { + field["path"] for field in search_indexes["latestDefinition"]["fields"] + } + if fields.issubset(index_field): + return search_indexes["name"] + return "default" + + def as_mql(self, compiler, connection): + params = { + "index": self._get_query_index(self.get_search_fields(), compiler), + "path": self.path, + "queryVector": self.query_vector, + "limit": self.limit, + } + if self.num_candidates is not None: + params["numCandidates"] = self.num_candidates + if self.exact is not None: + params["exact"] = self.exact + if self.filter is not None: + params["filter"] = self.filter + return {"$vectorSearch": params} + + class SearchScoreOption: """Class to mutate scoring on a search operation""" diff --git a/tests/queries_/models.py b/tests/queries_/models.py index fd70b395a..c9ff129cc 100644 --- a/tests/queries_/models.py +++ b/tests/queries_/models.py @@ -1,6 +1,6 @@ from django.db import models -from django_mongodb_backend.fields import ObjectIdAutoField, ObjectIdField +from django_mongodb_backend.fields import ArrayField, ObjectIdAutoField, ObjectIdField class Author(models.Model): @@ -60,3 +60,4 @@ class Article(models.Model): number = models.IntegerField() body = models.TextField() location = models.JSONField(null=True) + plot_embedding = ArrayField(models.FloatField(), size=3) diff --git a/tests/queries_/test_search.py b/tests/queries_/test_search.py index 71e636b92..8d8866d7d 100644 --- a/tests/queries_/test_search.py +++ b/tests/queries_/test_search.py @@ -17,6 +17,7 @@ SearchRange, SearchRegex, SearchText, + SearchVector, SearchWildcard, ) @@ -29,9 +30,9 @@ def _get_collection(model): return connection.database.get_collection(model._meta.db_table) @staticmethod - def create_search_index(model, index_name, definition): + def create_search_index(model, index_name, definition, type="search"): collection = CreateIndexMixin._get_collection(model) - idx = SearchIndexModel(definition=definition, 
name=index_name) + idx = SearchIndexModel(definition=definition, name=index_name, type=type) collection.create_search_index(idx) @@ -365,3 +366,49 @@ def test_compound_operations(self): ) qs = Article.objects.annotate(score=expr) self.assertCountEqual(qs, [self.mars_mission, self.exoplanet]) + + +class SearchVectorTest(TestCase, CreateIndexMixin): + @classmethod + def setUpTestData(cls): + cls.create_search_index( + Article, + "vector_index", + { + "fields": [ + { + "type": "vector", + "path": "plot_embedding", + "numDimensions": 3, + "similarity": "cosine", + "quantization": "scalar", + } + ] + }, + type="vectorSearch", + ) + + cls.mars = Article.objects.create( + headline="Mars landing", + number=1, + body="The rover has landed on Mars", + plot_embedding=[0.1, 0.2, 0.3], + ) + Article.objects.create( + headline="Cooking tips", + number=2, + body="This article is about pasta", + plot_embedding=[0.9, 0.8, 0.7], + ) + time.sleep(1) + + def test_vector_search(self): + vector_query = [0.1, 0.2, 0.3] + expr = SearchVector( + path="plot_embedding", + query_vector=vector_query, + num_candidates=5, + limit=2, + ) + qs = Article.objects.annotate(score=expr).order_by("-score") + self.assertEqual(qs.first(), self.mars) From 152aa46880c934b04bff21d7d04e8e1c7d1b7b39 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Mon, 7 Jul 2025 00:15:05 -0300 Subject: [PATCH 15/37] Add combinable test --- .../test_combinable_search_expression.py | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 tests/expressions_/test_combinable_search_expression.py diff --git a/tests/expressions_/test_combinable_search_expression.py b/tests/expressions_/test_combinable_search_expression.py new file mode 100644 index 000000000..7a2fc487f --- /dev/null +++ b/tests/expressions_/test_combinable_search_expression.py @@ -0,0 +1,76 @@ +from django.test import SimpleTestCase + +from django_mongodb_backend.expressions.builtins import ( + CombinedSearchExpression, + CompoundExpression, + 
SearchEquals, +) + + +class CombinedSearchExpressionResolutionTest(SimpleTestCase): + def test_combined_expression_and_or_not_resolution(self): + A = SearchEquals(path="headline", value="A") + B = SearchEquals(path="headline", value="B") + C = SearchEquals(path="headline", value="C") + D = SearchEquals(path="headline", value="D") + expr = (~A | B) & (C | D) + solved = CombinedSearchExpression.resolve(expr) + self.assertIsInstance(solved, CompoundExpression) + solved_A = CompoundExpression(must_not=[CompoundExpression(must=[A])]) + solved_B = CompoundExpression(must=[B]) + solved_C = CompoundExpression(must=[C]) + solved_D = CompoundExpression(must=[D]) + self.assertCountEqual(solved.must[0].should, [solved_A, solved_B]) + self.assertEqual(solved.must[0].minimum_should_match, 1) + self.assertEqual(solved.must[1].should, [solved_C, solved_D]) + + def test_combined_expression_de_morgans_resolution(self): + A = SearchEquals(path="headline", value="A") + B = SearchEquals(path="headline", value="B") + C = SearchEquals(path="headline", value="C") + D = SearchEquals(path="headline", value="D") + expr = ~(A | B) & (C | D) + solved_A = CompoundExpression(must_not=[CompoundExpression(must=[A])]) + solved_B = CompoundExpression(must_not=[CompoundExpression(must=[B])]) + solved_C = CompoundExpression(must=[C]) + solved_D = CompoundExpression(must=[D]) + solved = CombinedSearchExpression.resolve(expr) + self.assertIsInstance(solved, CompoundExpression) + self.assertCountEqual(solved.must[0].must, [solved_A, solved_B]) + self.assertEqual(solved.must[0].minimum_should_match, None) + self.assertEqual(solved.must[1].should, [solved_C, solved_D]) + self.assertEqual(solved.minimum_should_match, None) + + def test_combined_expression_doble_negation(self): + A = SearchEquals(path="headline", value="A") + expr = ~~A + solved = CombinedSearchExpression.resolve(expr) + solved_A = CompoundExpression(must=[A]) + self.assertIsInstance(solved, CompoundExpression) + self.assertEqual(solved, 
solved_A) + + def test_combined_expression_long_right_tree(self): + A = SearchEquals(path="headline", value="A") + B = SearchEquals(path="headline", value="B") + C = SearchEquals(path="headline", value="C") + D = SearchEquals(path="headline", value="D") + solved_A = CompoundExpression(must=[A]) + solved_B = CompoundExpression(must_not=[CompoundExpression(must=[B])]) + solved_C = CompoundExpression(must=[C]) + solved_D = CompoundExpression(must=[D]) + expr = A & ~(B & ~(C & D)) + solved = CombinedSearchExpression.resolve(expr) + self.assertIsInstance(solved, CompoundExpression) + self.assertEqual(len(solved.must), 2) + self.assertEqual(solved.must[0], solved_A) + self.assertEqual(len(solved.must[1].should), 2) + self.assertEqual(solved.must[1].should[0], solved_B) + self.assertCountEqual(solved.must[1].should[1].must, [solved_C, solved_D]) + expr = A | ~(B | ~(C | D)) + solved = CombinedSearchExpression.resolve(expr) + self.assertIsInstance(solved, CompoundExpression) + self.assertEqual(len(solved.should), 2) + self.assertEqual(solved.should[0], solved_A) + self.assertEqual(len(solved.should[1].must), 2) + self.assertEqual(solved.should[1].must[0], solved_B) + self.assertCountEqual(solved.should[1].must[1].should, [solved_C, solved_D]) From d19aa10c002f4b473b9956bb9fdf2b1af0ec4609 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Mon, 7 Jul 2025 00:15:30 -0300 Subject: [PATCH 16/37] Refactor --- django_mongodb_backend/compiler.py | 47 ++++++++++++++++++------------ 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 71400c3d2..cdb0639a4 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -17,7 +17,7 @@ from django.utils.functional import cached_property from pymongo import ASCENDING, DESCENDING -from .expressions.builtins import SearchExpression +from .expressions.builtins import SearchExpression, SearchVector from .query import 
MongoQuery, wrap_database_errors @@ -117,29 +117,24 @@ def _prepare_search_expressions_for_pipeline( if sub_expr not in replacements: alias = f"__search_expr.search{next(search_idx)}" replacements[sub_expr] = self._get_replace_expr(sub_expr, searches, alias) - return list(searches.values()) def _prepare_search_query_for_aggregation_pipeline(self, order_by): replacements = {} - searches = [] annotation_group_idx = itertools.count(start=1) for target, expr in self.query.annotation_select.items(): - expr_searches = self._prepare_search_expressions_for_pipeline( + self._prepare_search_expressions_for_pipeline( expr, target, annotation_group_idx, replacements ) - searches += expr_searches for expr, _ in order_by: - expr_searches = self._prepare_search_expressions_for_pipeline( + self._prepare_search_expressions_for_pipeline( expr, None, annotation_group_idx, replacements ) - searches += expr_searches - having_group = self._prepare_search_expressions_for_pipeline( + self._prepare_search_expressions_for_pipeline( self.having, None, annotation_group_idx, replacements ) - searches += having_group - return searches, replacements + return replacements def _prepare_annotations_for_aggregation_pipeline(self, order_by): """Prepare annotations for the aggregation pipeline.""" @@ -248,22 +243,36 @@ def _build_aggregation_pipeline(self, ids, group): pipeline.append({"$unset": "_id"}) return pipeline - def _compound_searches_queries(self, searches, search_replacements): - if not searches: + def _compound_searches_queries(self, search_replacements): + if not search_replacements: return [] - if len(searches) > 1: + if len(search_replacements) > 1: raise ValueError("Cannot perform more than one search operation.") - score_function = "searchScore" if "$search" in searches[0] else "vectorSearchScore" - return [searches[0], {"$addFields": {"__search_expr.search1": {"$meta": score_function}}}] + pipeline = [] + for search, result_col in search_replacements.items(): + score_function = ( + 
"vectorSearchScore" if isinstance(search, SearchVector) else "searchScore" + ) + pipeline.extend( + [ + search.as_mql(self, self.connection), + { + "$addFields": { + result_col.as_mql(self, self.connection).removeprefix("$"): { + "$meta": score_function + } + } + }, + ] + ) + return pipeline def pre_sql_setup(self, with_col_aliases=False): extra_select, order_by, group_by = super().pre_sql_setup(with_col_aliases=with_col_aliases) - searches, search_replacements = self._prepare_search_query_for_aggregation_pipeline( - order_by - ) + search_replacements = self._prepare_search_query_for_aggregation_pipeline(order_by) group, group_replacements = self._prepare_annotations_for_aggregation_pipeline(order_by) all_replacements = {**search_replacements, **group_replacements} - self.search_pipeline = self._compound_searches_queries(searches, search_replacements) + self.search_pipeline = self._compound_searches_queries(search_replacements) # query.group_by is either: # - None: no GROUP BY # - True: group by select fields From 995a6b142a46e739dec8fa26137b9362a3396653 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Mon, 7 Jul 2025 00:26:14 -0300 Subject: [PATCH 17/37] Remove unused parameter --- django_mongodb_backend/compiler.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index cdb0639a4..627451837 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -109,9 +109,7 @@ def _prepare_expressions_for_pipeline(self, expression, target, annotation_group replacements[sub_expr] = self._get_replace_expr(sub_expr, group, alias) return replacements, group - def _prepare_search_expressions_for_pipeline( - self, expression, target, search_idx, replacements - ): + def _prepare_search_expressions_for_pipeline(self, expression, search_idx, replacements): searches = {} for sub_expr in self._get_search_expressions(expression): if sub_expr not in 
replacements: @@ -121,18 +119,14 @@ def _prepare_search_expressions_for_pipeline( def _prepare_search_query_for_aggregation_pipeline(self, order_by): replacements = {} annotation_group_idx = itertools.count(start=1) - for target, expr in self.query.annotation_select.items(): - self._prepare_search_expressions_for_pipeline( - expr, target, annotation_group_idx, replacements - ) + for expr in self.query.annotation_select.values(): + self._prepare_search_expressions_for_pipeline(expr, annotation_group_idx, replacements) for expr, _ in order_by: - self._prepare_search_expressions_for_pipeline( - expr, None, annotation_group_idx, replacements - ) + self._prepare_search_expressions_for_pipeline(expr, annotation_group_idx, replacements) self._prepare_search_expressions_for_pipeline( - self.having, None, annotation_group_idx, replacements + self.having, annotation_group_idx, replacements ) return replacements From 9a1543c7cc85f720ed8898dc35784fcf5a05d6e9 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Wed, 9 Jul 2025 21:34:24 -0300 Subject: [PATCH 18/37] Fix unit test --- django_mongodb_backend/creation.py | 3 + tests/queries_/models.py | 2 +- tests/queries_/test_search.py | 254 ++++++++++++++++++++--------- 3 files changed, 185 insertions(+), 74 deletions(-) diff --git a/django_mongodb_backend/creation.py b/django_mongodb_backend/creation.py index 572e770fa..4057e786f 100644 --- a/django_mongodb_backend/creation.py +++ b/django_mongodb_backend/creation.py @@ -22,6 +22,9 @@ def _destroy_test_db(self, test_database_name, verbosity): for collection in self.connection.introspection.table_names(): if not collection.startswith("system."): + db_collection = self.connection.database.get_collection(collection) + for search_indexes in db_collection.list_search_indexes(): + db_collection.drop_search_index(search_indexes["name"]) self.connection.database.drop_collection(collection) def create_test_db(self, *args, **kwargs): diff --git a/tests/queries_/models.py 
b/tests/queries_/models.py index c9ff129cc..df60a6725 100644 --- a/tests/queries_/models.py +++ b/tests/queries_/models.py @@ -60,4 +60,4 @@ class Article(models.Model): number = models.IntegerField() body = models.TextField() location = models.JSONField(null=True) - plot_embedding = ArrayField(models.FloatField(), size=3) + plot_embedding = ArrayField(models.FloatField(), size=3, null=True) diff --git a/tests/queries_/test_search.py b/tests/queries_/test_search.py index 8d8866d7d..19f8db468 100644 --- a/tests/queries_/test_search.py +++ b/tests/queries_/test_search.py @@ -1,7 +1,10 @@ -import time +import unittest +from collections.abc import Callable +from time import monotonic, sleep from django.db import connection -from django.test import TestCase +from django.db.utils import DatabaseError +from django.test import TransactionTestCase from pymongo.operations import SearchIndexModel from django_mongodb_backend.expressions.builtins import ( @@ -24,34 +27,95 @@ from .models import Article -class CreateIndexMixin: +def _wait_for_assertion(timeout: float = 120, interval: float = 0.5) -> None: + """Generic to block until the predicate returns true + + Args: + timeout (float, optional): Wait time for predicate. Defaults to TIMEOUT. + interval (float, optional): Interval to check predicate. Defaults to DELAY. + + Raises: + AssertionError: _description_ + """ + + @staticmethod + def _inner_wait_loop(predicate: Callable): + """ + Waits until the given predicate stops raising AssertionError or DatabaseError. + + Args: + predicate (Callable): A function that raises AssertionError (or DatabaseError) + if a condition is not yet met. It should refresh its query each time + it's called (e.g., by using `qs.all()` to avoid cached results). + + Raises: + AssertionError or DatabaseError: If the predicate keeps failing beyond the timeout. 
+ """ + start = monotonic() + while True: + try: + predicate() + except (AssertionError, DatabaseError): + if monotonic() - start > timeout: + raise + sleep(interval) + else: + break + + return _inner_wait_loop + + +class SearchUtilsMixin(TransactionTestCase): + available_apps = [] + @staticmethod def _get_collection(model): return connection.database.get_collection(model._meta.db_table) @staticmethod def create_search_index(model, index_name, definition, type="search"): - collection = CreateIndexMixin._get_collection(model) + collection = SearchUtilsMixin._get_collection(model) idx = SearchIndexModel(definition=definition, name=index_name, type=type) collection.create_search_index(idx) + def _tear_down(self, model): + collection = SearchUtilsMixin._get_collection(model) + for search_indexes in collection.list_search_indexes(): + collection.drop_search_index(search_indexes["name"]) + collection.delete_many({}) + + wait_for_assertion = _wait_for_assertion(timeout=3) -class SearchEqualsTest(TestCase, CreateIndexMixin): + +class SearchTest(SearchUtilsMixin): + @classmethod + def setUpTestData(cls): + cls.create_search_index( + Article, + "equals_headline_index", + {"mappings": {"dynamic": False, "fields": {"headline": {"type": "token"}}}}, + ) + + +class SearchEqualsTest(SearchUtilsMixin): def setUp(self): self.create_search_index( Article, "equals_headline_index", {"mappings": {"dynamic": False, "fields": {"headline": {"type": "token"}}}}, ) - Article.objects.create(headline="cross", number=1, body="body") - time.sleep(1) + self.cross = Article.objects.create(headline="cross", number=1, body="body") def test_search_equals(self): qs = Article.objects.annotate(score=SearchEquals(path="headline", value="cross")) - self.assertEqual(qs.first().headline, "cross") + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.cross])) + + def tearDown(self): + self._tear_down(Article) + super().tearDown() -class SearchAutocompleteTest(TestCase, CreateIndexMixin): 
+class SearchAutocompleteTest(SearchUtilsMixin): def setUp(self): self.create_search_index( Article, @@ -72,60 +136,73 @@ def setUp(self): } }, ) - Article.objects.create(headline="crossing and something", number=2, body="river") + self.article = Article.objects.create( + headline="crossing and something", number=2, body="river" + ) + + def tearDown(self): + self._tear_down(Article) + super().tearDown() def test_search_autocomplete(self): qs = Article.objects.annotate(score=SearchAutocomplete(path="headline", query="crossing")) - self.assertEqual(qs.first().headline, "crossing and something") + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) -class SearchExistsTest(TestCase, CreateIndexMixin): +class SearchExistsTest(SearchUtilsMixin): def setUp(self): self.create_search_index( Article, "exists_body_index", {"mappings": {"dynamic": False, "fields": {"body": {"type": "token"}}}}, ) - Article.objects.create(headline="ignored", number=3, body="something") + self.article = Article.objects.create(headline="ignored", number=3, body="something") def test_search_exists(self): qs = Article.objects.annotate(score=SearchExists(path="body")) - self.assertEqual(qs.count(), 1) - self.assertEqual(qs.first().body, "something") + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) -class SearchInTest(TestCase, CreateIndexMixin): +class SearchInTest(SearchUtilsMixin): def setUp(self): self.create_search_index( Article, "in_headline_index", {"mappings": {"dynamic": False, "fields": {"headline": {"type": "token"}}}}, ) - Article.objects.create(headline="cross", number=1, body="a") + self.cross = Article.objects.create(headline="cross", number=1, body="a") Article.objects.create(headline="road", number=2, body="b") - time.sleep(1) + + def tearDown(self): + self._tear_down(Article) + super().tearDown() def test_search_in(self): qs = Article.objects.annotate(score=SearchIn(path="headline", value=["cross", "river"])) - 
self.assertEqual(qs.first().headline, "cross") + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.cross])) -class SearchPhraseTest(TestCase, CreateIndexMixin): +class SearchPhraseTest(SearchUtilsMixin): def setUp(self): self.create_search_index( Article, "phrase_body_index", {"mappings": {"dynamic": False, "fields": {"body": {"type": "string"}}}}, ) - Article.objects.create(headline="irrelevant", number=1, body="the quick brown fox") - time.sleep(1) + self.irrelevant = Article.objects.create( + headline="irrelevant", number=1, body="the quick brown fox" + ) + + def tearDown(self): + self._tear_down(Article) + super().tearDown() def test_search_phrase(self): qs = Article.objects.annotate(score=SearchPhrase(path="body", query="quick brown")) - self.assertIn("quick brown", qs.first().body) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.irrelevant])) -class SearchRangeTest(TestCase, CreateIndexMixin): +class SearchRangeTest(SearchUtilsMixin): def setUp(self): self.create_search_index( Article, @@ -133,15 +210,18 @@ def setUp(self): {"mappings": {"dynamic": False, "fields": {"number": {"type": "number"}}}}, ) Article.objects.create(headline="x", number=5, body="z") - Article.objects.create(headline="y", number=20, body="z") - time.sleep(1) + self.number20 = Article.objects.create(headline="y", number=20, body="z") + + def tearDown(self): + self._tear_down(Article) + super().tearDown() def test_search_range(self): qs = Article.objects.annotate(score=SearchRange(path="number", gte=10, lt=30)) - self.assertEqual(qs.first().number, 20) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.number20])) -class SearchRegexTest(TestCase, CreateIndexMixin): +class SearchRegexTest(SearchUtilsMixin): def setUp(self): self.create_search_index( Article, @@ -153,40 +233,48 @@ def setUp(self): } }, ) - Article.objects.create(headline="hello world", number=1, body="abc") - time.sleep(1) + self.article = 
Article.objects.create(headline="hello world", number=1, body="abc") + + def tearDown(self): + self._tear_down(Article) + super().tearDown() def test_search_regex(self): qs = Article.objects.annotate( - score=SearchRegex(path="headline", query="hello.*", allow_analyzed_field=False) + score=SearchRegex(path="headline", query="hello.*", allow_analyzed_field=True) ) - self.assertTrue(qs.first().headline.startswith("hello")) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) -class SearchTextTest(TestCase, CreateIndexMixin): +class SearchTextTest(SearchUtilsMixin): def setUp(self): self.create_search_index( Article, "text_body_index", {"mappings": {"dynamic": False, "fields": {"body": {"type": "string"}}}}, ) - Article.objects.create(headline="ignored", number=1, body="The lazy dog sleeps") - time.sleep(1) + self.article = Article.objects.create( + headline="ignored", number=1, body="The lazy dog sleeps" + ) + + def tearDown(self): + self._tear_down(Article) + super().tearDown() def test_search_text(self): qs = Article.objects.annotate(score=SearchText(path="body", query="lazy")) - self.assertIn("lazy", qs.first().body) + self.wait_for_assertion(lambda: self.assertCountEqual([self.article], qs)) def test_search_text_with_fuzzy_and_criteria(self): qs = Article.objects.annotate( score=SearchText( - path="body", query="lazzy", fuzzy={"maxEdits": 1}, match_criteria="all" + path="body", query="lazzy", fuzzy={"maxEdits": 2}, match_criteria="all" ) ) - self.assertIn("lazy", qs.first().body) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) -class SearchWildcardTest(TestCase, CreateIndexMixin): +class SearchWildcardTest(SearchUtilsMixin): def setUp(self): self.create_search_index( Article, @@ -198,15 +286,18 @@ def setUp(self): } }, ) - Article.objects.create(headline="dark-knight", number=1, body="") - time.sleep(1) + self.article = Article.objects.create(headline="dark-knight", number=1, body="") + + def 
tearDown(self): + self._tear_down(Article) + super().tearDown() def test_search_wildcard(self): qs = Article.objects.annotate(score=SearchWildcard(path="headline", query="dark-*")) - self.assertIn("dark", qs.first().headline) + self.wait_for_assertion(lambda: self.assertCountEqual([self.article], qs)) -class SearchGeoShapeTest(TestCase, CreateIndexMixin): +class SearchGeoShapeTest(SearchUtilsMixin): def setUp(self): self.create_search_index( Article, @@ -218,10 +309,13 @@ def setUp(self): } }, ) - Article.objects.create( + self.article = Article.objects.create( headline="any", number=1, body="", location={"type": "Point", "coordinates": [40, 5]} ) - time.sleep(1) + + def tearDown(self): + self._tear_down(Article) + super().tearDown() def test_search_geo_shape(self): polygon = { @@ -231,20 +325,23 @@ def test_search_geo_shape(self): qs = Article.objects.annotate( score=SearchGeoShape(path="location", relation="within", geometry=polygon) ) - self.assertEqual(qs.first().number, 1) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) -class SearchGeoWithinTest(TestCase, CreateIndexMixin): +class SearchGeoWithinTest(SearchUtilsMixin): def setUp(self): self.create_search_index( Article, "geowithin_location_index", {"mappings": {"dynamic": False, "fields": {"location": {"type": "geo"}}}}, ) - Article.objects.create( + self.article = Article.objects.create( headline="geo", number=2, body="", location={"type": "Point", "coordinates": [40, 5]} ) - time.sleep(1) + + def tearDown(self): + self._tear_down(Article) + super().tearDown() def test_search_geo_within(self): polygon = { @@ -258,10 +355,11 @@ def test_search_geo_within(self): geo_object=polygon, ) ) - self.assertEqual(qs.first().number, 2) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) -class SearchMoreLikeThisTest(TestCase, CreateIndexMixin): +@unittest.expectedFailure +class SearchMoreLikeThisTest(SearchUtilsMixin): def setUp(self): 
self.create_search_index( Article, @@ -286,7 +384,10 @@ def setUp(self): number=3, body="This is a completely unrelated article about cooking", ) - time.sleep(1) + + def tearDown(self): + self._tear_down(Article) + super().tearDown() def test_search_more_like_this(self): like_docs = [ @@ -297,15 +398,16 @@ def test_search_more_like_this(self): qs = Article.objects.annotate(score=SearchMoreLikeThis(documents=like_docs)).order_by( "score" ) - self.assertQuerySetEqual( - qs, ["space exploration", "The commodities fall"], lambda a: a.headline + self.wait_for_assertion( + lambda: self.assertQuerySetEqual( + qs.all(), [self.article1, self.article2], lambda a: a.headline + ) ) -class CompoundSearchTest(TestCase, CreateIndexMixin): - @classmethod - def setUpTestData(cls): - cls.create_search_index( +class CompoundSearchTest(SearchUtilsMixin): + def setUp(self): + self.create_search_index( Article, "compound_index", { @@ -319,31 +421,33 @@ def setUpTestData(cls): } }, ) - cls.mars_mission = Article.objects.create( + self.mars_mission = Article.objects.create( number=1, headline="space exploration", body="NASA launches a new mission to Mars, aiming to study surface geology", ) - cls.exoplanet = Article.objects.create( + self.exoplanet = Article.objects.create( number=2, headline="space exploration", body="Astronomers discover exoplanets orbiting distant stars using Webb telescope", ) - cls.icy_moons = Article.objects.create( + self.icy_moons = Article.objects.create( number=3, headline="space exploration", body="ESA prepares a robotic expedition to explore the icy moons of Jupiter", ) - cls.comodities_drop = Article.objects.create( + self.comodities_drop = Article.objects.create( number=4, headline="astronomy news", body="Commodities dropped sharply due to inflation concerns", ) - time.sleep(1) + def tearDown(self): + self._tear_down(Article) + super().tearDown() def test_compound_expression(self): must_expr = SearchEquals(path="headline", value="space exploration") @@ 
-358,20 +462,21 @@ def test_compound_expression(self): ) qs = Article.objects.annotate(score=compound).order_by("score") - self.assertCountEqual(qs, [self.exoplanet]) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.exoplanet])) def test_compound_operations(self): expr = SearchEquals(path="headline", value="space exploration") & ~SearchEquals( path="number", value=3 ) qs = Article.objects.annotate(score=expr) - self.assertCountEqual(qs, [self.mars_mission, self.exoplanet]) + self.wait_for_assertion( + lambda: self.assertCountEqual(qs.all(), [self.mars_mission, self.exoplanet]) + ) -class SearchVectorTest(TestCase, CreateIndexMixin): - @classmethod - def setUpTestData(cls): - cls.create_search_index( +class SearchVectorTest(SearchUtilsMixin): + def setUp(self): + self.create_search_index( Article, "vector_index", { @@ -388,19 +493,22 @@ def setUpTestData(cls): type="vectorSearch", ) - cls.mars = Article.objects.create( + self.mars = Article.objects.create( headline="Mars landing", number=1, body="The rover has landed on Mars", plot_embedding=[0.1, 0.2, 0.3], ) - Article.objects.create( + self.cooking = Article.objects.create( headline="Cooking tips", number=2, body="This article is about pasta", plot_embedding=[0.9, 0.8, 0.7], ) - time.sleep(1) + + def tearDown(self): + self._tear_down(Article) + super().tearDown() def test_vector_search(self): vector_query = [0.1, 0.2, 0.3] @@ -411,4 +519,4 @@ def test_vector_search(self): limit=2, ) qs = Article.objects.annotate(score=expr).order_by("-score") - self.assertEqual(qs.first(), self.mars) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.mars, self.cooking])) From 2df372908d39559ed2f288972b7bf0e24fdb6328 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Wed, 9 Jul 2025 21:41:56 -0300 Subject: [PATCH 19/37] Move search expression to expression/search --- django_mongodb_backend/compiler.py | 2 +- .../expressions/builtins.py | 526 ----------------- 
django_mongodb_backend/expressions/search.py | 527 ++++++++++++++++++ .../test_combinable_search_expression.py | 2 +- tests/queries_/test_search.py | 2 +- 5 files changed, 530 insertions(+), 529 deletions(-) create mode 100644 django_mongodb_backend/expressions/search.py diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 627451837..a6b373afd 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -17,7 +17,7 @@ from django.utils.functional import cached_property from pymongo import ASCENDING, DESCENDING -from .expressions.builtins import SearchExpression, SearchVector +from .expressions.search import SearchExpression, SearchVector from .query import MongoQuery, wrap_database_errors diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index 7b019e57b..4f6575052 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -5,7 +5,6 @@ from bson import Decimal128 from django.core.exceptions import EmptyResultSet, FullResultSet from django.db import NotSupportedError -from django.db.models import Expression, FloatField from django.db.models.expressions import ( Case, Col, @@ -208,531 +207,6 @@ def value(self, compiler, connection): # noqa: ARG001 return value -class Operator: - AND = "AND" - OR = "OR" - NOT = "NOT" - - def __init__(self, operator): - self.operator = operator - - def __eq__(self, other): - if isinstance(other, str): - return self.operator == other - return self.operator == other.operator - - def negate(self): - if self.operator == self.AND: - return Operator(self.OR) - if self.operator == self.OR: - return Operator(self.AND) - return Operator(self.operator) - - def __hash__(self): - return hash(self.operator) - - def __str__(self): - return self.operator - - def __repr__(self): - return self.operator - - -class SearchCombinable: - def _combine(self, other, connector): - 
if not isinstance(self, CompoundExpression | CombinedSearchExpression): - lhs = CompoundExpression(must=[self]) - else: - lhs = self - if other and not isinstance(other, CompoundExpression | CombinedSearchExpression): - rhs = CompoundExpression(must=[other]) - else: - rhs = other - return CombinedSearchExpression(lhs, connector, rhs) - - def __invert__(self): - return self._combine(None, Operator(Operator.NOT)) - - def __and__(self, other): - return self._combine(other, Operator(Operator.AND)) - - def __rand__(self, other): - return self._combine(other, Operator(Operator.AND)) - - def __or__(self, other): - return self._combine(other, Operator(Operator.OR)) - - def __ror__(self, other): - return self._combine(self, Operator(Operator.OR), other) - - -class SearchExpression(SearchCombinable, Expression): - output_field = FloatField() - - def __str__(self): - cls = self.identity[0] - kwargs = dict(self.identity[1:]) - arg_str = ", ".join(f"{k}={v!r}" for k, v in kwargs.items()) - return f"{cls.__name__}({arg_str})" - - def __repr__(self): - return str(self) - - def as_sql(self, compiler, connection): - return "", [] - - def _get_query_index(self, fields, compiler): - fields = set(fields) - for search_indexes in compiler.collection.list_search_indexes(): - mappings = search_indexes["latestDefinition"]["mappings"] - if mappings["dynamic"] or fields.issubset(set(mappings["fields"])): - return search_indexes["name"] - return "default" - - def search_operator(self): - raise NotImplementedError - - def as_mql(self, compiler, connection): - index = self._get_query_index(self.get_search_fields(), compiler) - return {"$search": {**self.search_operator(), "index": index}} - - -class SearchAutocomplete(SearchExpression): - def __init__(self, path, query, fuzzy=None, score=None): - self.path = path - self.query = query - self.fuzzy = fuzzy - self.score = score - super().__init__() - - def get_search_fields(self): - return {self.path} - - def search_operator(self): - params = { - 
"path": self.path, - "query": self.query, - } - if self.score is not None: - params["score"] = self.score - if self.fuzzy is not None: - params["fuzzy"] = self.fuzzy - return {"autocomplete": params} - - -class SearchEquals(SearchExpression): - def __init__(self, path, value, score=None): - self.path = path - self.value = value - self.score = score - super().__init__() - - def get_search_fields(self): - return {self.path} - - def search_operator(self): - params = { - "path": self.path, - "value": self.value, - } - if self.score is not None: - params["score"] = self.score - return {"equals": params} - - -class SearchExists(SearchExpression): - def __init__(self, path, score=None): - self.path = path - self.score = score - super().__init__() - - def get_search_fields(self): - return {self.path} - - def search_operator(self): - params = { - "path": self.path, - } - if self.score is not None: - params["score"] = self.score - return {"exists": params} - - -class SearchIn(SearchExpression): - def __init__(self, path, value, score=None): - self.path = path - self.value = value - self.score = score - super().__init__() - - def get_search_fields(self): - return {self.path} - - def search_operator(self): - params = { - "path": self.path, - "value": self.value, - } - if self.score is not None: - params["score"] = self.score - return {"in": params} - - -class SearchPhrase(SearchExpression): - def __init__(self, path, query, slop=None, synonyms=None, score=None): - self.path = path - self.query = query - self.score = score - self.slop = slop - self.synonyms = synonyms - super().__init__() - - def get_search_fields(self): - return {self.path} - - def search_operator(self): - params = { - "path": self.path, - "query": self.query, - } - if self.score is not None: - params["score"] = self.score - if self.slop is not None: - params["slop"] = self.slop - if self.synonyms is not None: - params["synonyms"] = self.synonyms - return {"phrase": params} - - -class 
SearchQueryString(SearchExpression): - def __init__(self, path, query, score=None): - self.path = path - self.query = query - self.score = score - super().__init__() - - def get_search_fields(self): - return {self.path} - - def search_operator(self): - params = { - "defaultPath": self.path, - "query": self.query, - } - if self.score is not None: - params["score"] = self.score - return {"queryString": params} - - -class SearchRange(SearchExpression): - def __init__(self, path, lt=None, lte=None, gt=None, gte=None, score=None): - self.path = path - self.lt = lt - self.lte = lte - self.gt = gt - self.gte = gte - self.score = score - super().__init__() - - def get_search_fields(self): - return {self.path} - - def search_operator(self): - params = { - "path": self.path, - } - if self.score is not None: - params["score"] = self.score - if self.lt is not None: - params["lt"] = self.lt - if self.lte is not None: - params["lte"] = self.lte - if self.gt is not None: - params["gt"] = self.gt - if self.gte is not None: - params["gte"] = self.gte - return {"range": params} - - -class SearchRegex(SearchExpression): - def __init__(self, path, query, allow_analyzed_field=None, score=None): - self.path = path - self.query = query - self.allow_analyzed_field = allow_analyzed_field - self.score = score - super().__init__() - - def get_search_fields(self): - return {self.path} - - def search_operator(self): - params = { - "path": self.path, - "query": self.query, - } - if self.score: - params["score"] = self.score - if self.allow_analyzed_field is not None: - params["allowAnalyzedField"] = self.allow_analyzed_field - return {"regex": params} - - -class SearchText(SearchExpression): - def __init__(self, path, query, fuzzy=None, match_criteria=None, synonyms=None, score=None): - self.path = path - self.query = query - self.fuzzy = fuzzy - self.match_criteria = match_criteria - self.synonyms = synonyms - self.score = score - super().__init__() - - def get_search_fields(self): - return 
{self.path} - - def search_operator(self): - params = { - "path": self.path, - "query": self.query, - } - if self.score: - params["score"] = self.score - if self.fuzzy is not None: - params["fuzzy"] = self.fuzzy - if self.match_criteria is not None: - params["matchCriteria"] = self.match_criteria - if self.synonyms is not None: - params["synonyms"] = self.synonyms - return {"text": params} - - -class SearchWildcard(SearchExpression): - def __init__(self, path, query, allow_analyzed_field=None, score=None): - self.path = path - self.query = query - self.allow_analyzed_field = allow_analyzed_field - self.score = score - super().__init__() - - def get_search_fields(self): - return {self.path} - - def search_operator(self): - params = { - "path": self.path, - "query": self.query, - } - if self.score: - params["score"] = self.score - if self.allow_analyzed_field is not None: - params["allowAnalyzedField"] = self.allow_analyzed_field - return {"wildcard": params} - - -class SearchGeoShape(SearchExpression): - def __init__(self, path, relation, geometry, score=None): - self.path = path - self.relation = relation - self.geometry = geometry - self.score = score - super().__init__() - - def get_search_fields(self): - return {self.path} - - def search_operator(self): - params = { - "path": self.path, - "relation": self.relation, - "geometry": self.geometry, - } - if self.score: - params["score"] = self.score - return {"geoShape": params} - - -class SearchGeoWithin(SearchExpression): - def __init__(self, path, kind, geo_object, score=None): - self.path = path - self.kind = kind - self.geo_object = geo_object - self.score = score - super().__init__() - - def search_operator(self): - params = { - "path": self.path, - self.kind: self.geo_object, - } - if self.score: - params["score"] = self.score - return {"geoWithin": params} - - def get_search_fields(self): - return {self.path} - - -class SearchMoreLikeThis(SearchExpression): - def __init__(self, documents, score=None): - 
self.documents = documents - self.score = score - super().__init__() - - def search_operator(self): - params = { - "like": self.documents, - } - if self.score: - params["score"] = self.score - return {"moreLikeThis": params} - - def get_search_fields(self): - needed_fields = set() - for doc in self.documents: - needed_fields.update(set(doc.keys())) - return needed_fields - - -class CompoundExpression(SearchExpression): - def __init__( - self, - must=None, - must_not=None, - should=None, - filter=None, - score=None, - minimum_should_match=None, - ): - self.must = must or [] - self.must_not = must_not or [] - self.should = should or [] - self.filter = filter or [] - self.score = score - self.minimum_should_match = minimum_should_match - - def get_search_fields(self): - fields = set() - for clause in self.must + self.should + self.filter + self.must_not: - fields.update(clause.get_search_fields()) - return fields - - def search_operator(self): - params = {} - if self.must: - params["must"] = [clause.search_operator() for clause in self.must] - if self.must_not: - params["mustNot"] = [clause.search_operator() for clause in self.must_not] - if self.should: - params["should"] = [clause.search_operator() for clause in self.should] - if self.filter: - params["filter"] = [clause.search_operator() for clause in self.filter] - if self.minimum_should_match is not None: - params["minimumShouldMatch"] = self.minimum_should_match - - return {"compound": params} - - def negate(self): - return CompoundExpression(must_not=[self]) - - -class CombinedSearchExpression(SearchExpression): - def __init__(self, lhs, operator, rhs): - self.lhs = lhs - self.operator = operator - self.rhs = rhs - - @staticmethod - def resolve(node, negated=False): - if node is None: - return None - # Leaf, resolve the compoundExpression - if isinstance(node, CompoundExpression): - return node.negate() if negated else node - # Apply De Morgan's Laws. 
- operator = node.operator.negate() if negated else node.operator - negated = negated != (node.operator == Operator.NOT) - lhs_compound = node.resolve(node.lhs, negated) - rhs_compound = node.resolve(node.rhs, negated) - if operator == Operator.OR: - return CompoundExpression(should=[lhs_compound, rhs_compound], minimum_should_match=1) - if operator == Operator.AND: - return CompoundExpression(must=[lhs_compound, rhs_compound]) - return lhs_compound - - def as_mql(self, compiler, connection): - expression = self.resolve(self) - return expression.as_mql(compiler, connection) - - -class SearchVector(SearchExpression): - def __init__( - self, - path, - query_vector, - limit, - num_candidates=None, - exact=None, - filter=None, - ): - self.path = path - self.query_vector = query_vector - self.limit = limit - self.num_candidates = num_candidates - self.exact = exact - self.filter = filter - super().__init__() - - def __invert__(self): - return ValueError("SearchVector cannot be negated") - - def __and__(self, other): - raise NotSupportedError("SearchVector cannot be combined") - - def __rand__(self, other): - raise NotSupportedError("SearchVector cannot be combined") - - def __or__(self, other): - raise NotSupportedError("SearchVector cannot be combined") - - def __ror__(self, other): - raise NotSupportedError("SearchVector cannot be combined") - - def get_search_fields(self): - return {self.path} - - def _get_query_index(self, fields, compiler): - for search_indexes in compiler.collection.list_search_indexes(): - if search_indexes["type"] == "vectorSearch": - index_field = { - field["path"] for field in search_indexes["latestDefinition"]["fields"] - } - if fields.issubset(index_field): - return search_indexes["name"] - return "default" - - def as_mql(self, compiler, connection): - params = { - "index": self._get_query_index(self.get_search_fields(), compiler), - "path": self.path, - "queryVector": self.query_vector, - "limit": self.limit, - } - if self.num_candidates is 
not None: - params["numCandidates"] = self.num_candidates - if self.exact is not None: - params["exact"] = self.exact - if self.filter is not None: - params["filter"] = self.filter - return {"$vectorSearch": params} - - -class SearchScoreOption: - """Class to mutate scoring on a search operation""" - - def __init__(self, definitions=None): - self.definitions = definitions - - def register_expressions(): Case.as_mql = case Col.as_mql = col diff --git a/django_mongodb_backend/expressions/search.py b/django_mongodb_backend/expressions/search.py new file mode 100644 index 000000000..0f126d596 --- /dev/null +++ b/django_mongodb_backend/expressions/search.py @@ -0,0 +1,527 @@ +from django.db import NotSupportedError +from django.db.models import Expression, FloatField + + +class Operator: + AND = "AND" + OR = "OR" + NOT = "NOT" + + def __init__(self, operator): + self.operator = operator + + def __eq__(self, other): + if isinstance(other, str): + return self.operator == other + return self.operator == other.operator + + def negate(self): + if self.operator == self.AND: + return Operator(self.OR) + if self.operator == self.OR: + return Operator(self.AND) + return Operator(self.operator) + + def __hash__(self): + return hash(self.operator) + + def __str__(self): + return self.operator + + def __repr__(self): + return self.operator + + +class SearchCombinable: + def _combine(self, other, connector): + if not isinstance(self, CompoundExpression | CombinedSearchExpression): + lhs = CompoundExpression(must=[self]) + else: + lhs = self + if other and not isinstance(other, CompoundExpression | CombinedSearchExpression): + rhs = CompoundExpression(must=[other]) + else: + rhs = other + return CombinedSearchExpression(lhs, connector, rhs) + + def __invert__(self): + return self._combine(None, Operator(Operator.NOT)) + + def __and__(self, other): + return self._combine(other, Operator(Operator.AND)) + + def __rand__(self, other): + return self._combine(other, Operator(Operator.AND)) 
+ + def __or__(self, other): + return self._combine(other, Operator(Operator.OR)) + + def __ror__(self, other): + return self._combine(self, Operator(Operator.OR), other) + + +class SearchExpression(SearchCombinable, Expression): + output_field = FloatField() + + def __str__(self): + cls = self.identity[0] + kwargs = dict(self.identity[1:]) + arg_str = ", ".join(f"{k}={v!r}" for k, v in kwargs.items()) + return f"{cls.__name__}({arg_str})" + + def __repr__(self): + return str(self) + + def as_sql(self, compiler, connection): + return "", [] + + def _get_query_index(self, fields, compiler): + fields = set(fields) + for search_indexes in compiler.collection.list_search_indexes(): + mappings = search_indexes["latestDefinition"]["mappings"] + if mappings["dynamic"] or fields.issubset(set(mappings["fields"])): + return search_indexes["name"] + return "default" + + def search_operator(self): + raise NotImplementedError + + def as_mql(self, compiler, connection): + index = self._get_query_index(self.get_search_fields(), compiler) + return {"$search": {**self.search_operator(), "index": index}} + + +class SearchAutocomplete(SearchExpression): + def __init__(self, path, query, fuzzy=None, score=None): + self.path = path + self.query = query + self.fuzzy = fuzzy + self.score = score + super().__init__() + + def get_search_fields(self): + return {self.path} + + def search_operator(self): + params = { + "path": self.path, + "query": self.query, + } + if self.score is not None: + params["score"] = self.score + if self.fuzzy is not None: + params["fuzzy"] = self.fuzzy + return {"autocomplete": params} + + +class SearchEquals(SearchExpression): + def __init__(self, path, value, score=None): + self.path = path + self.value = value + self.score = score + super().__init__() + + def get_search_fields(self): + return {self.path} + + def search_operator(self): + params = { + "path": self.path, + "value": self.value, + } + if self.score is not None: + params["score"] = self.score + 
return {"equals": params} + + +class SearchExists(SearchExpression): + def __init__(self, path, score=None): + self.path = path + self.score = score + super().__init__() + + def get_search_fields(self): + return {self.path} + + def search_operator(self): + params = { + "path": self.path, + } + if self.score is not None: + params["score"] = self.score + return {"exists": params} + + +class SearchIn(SearchExpression): + def __init__(self, path, value, score=None): + self.path = path + self.value = value + self.score = score + super().__init__() + + def get_search_fields(self): + return {self.path} + + def search_operator(self): + params = { + "path": self.path, + "value": self.value, + } + if self.score is not None: + params["score"] = self.score + return {"in": params} + + +class SearchPhrase(SearchExpression): + def __init__(self, path, query, slop=None, synonyms=None, score=None): + self.path = path + self.query = query + self.score = score + self.slop = slop + self.synonyms = synonyms + super().__init__() + + def get_search_fields(self): + return {self.path} + + def search_operator(self): + params = { + "path": self.path, + "query": self.query, + } + if self.score is not None: + params["score"] = self.score + if self.slop is not None: + params["slop"] = self.slop + if self.synonyms is not None: + params["synonyms"] = self.synonyms + return {"phrase": params} + + +class SearchQueryString(SearchExpression): + def __init__(self, path, query, score=None): + self.path = path + self.query = query + self.score = score + super().__init__() + + def get_search_fields(self): + return {self.path} + + def search_operator(self): + params = { + "defaultPath": self.path, + "query": self.query, + } + if self.score is not None: + params["score"] = self.score + return {"queryString": params} + + +class SearchRange(SearchExpression): + def __init__(self, path, lt=None, lte=None, gt=None, gte=None, score=None): + self.path = path + self.lt = lt + self.lte = lte + self.gt = gt + 
self.gte = gte + self.score = score + super().__init__() + + def get_search_fields(self): + return {self.path} + + def search_operator(self): + params = { + "path": self.path, + } + if self.score is not None: + params["score"] = self.score + if self.lt is not None: + params["lt"] = self.lt + if self.lte is not None: + params["lte"] = self.lte + if self.gt is not None: + params["gt"] = self.gt + if self.gte is not None: + params["gte"] = self.gte + return {"range": params} + + +class SearchRegex(SearchExpression): + def __init__(self, path, query, allow_analyzed_field=None, score=None): + self.path = path + self.query = query + self.allow_analyzed_field = allow_analyzed_field + self.score = score + super().__init__() + + def get_search_fields(self): + return {self.path} + + def search_operator(self): + params = { + "path": self.path, + "query": self.query, + } + if self.score: + params["score"] = self.score + if self.allow_analyzed_field is not None: + params["allowAnalyzedField"] = self.allow_analyzed_field + return {"regex": params} + + +class SearchText(SearchExpression): + def __init__(self, path, query, fuzzy=None, match_criteria=None, synonyms=None, score=None): + self.path = path + self.query = query + self.fuzzy = fuzzy + self.match_criteria = match_criteria + self.synonyms = synonyms + self.score = score + super().__init__() + + def get_search_fields(self): + return {self.path} + + def search_operator(self): + params = { + "path": self.path, + "query": self.query, + } + if self.score: + params["score"] = self.score + if self.fuzzy is not None: + params["fuzzy"] = self.fuzzy + if self.match_criteria is not None: + params["matchCriteria"] = self.match_criteria + if self.synonyms is not None: + params["synonyms"] = self.synonyms + return {"text": params} + + +class SearchWildcard(SearchExpression): + def __init__(self, path, query, allow_analyzed_field=None, score=None): + self.path = path + self.query = query + self.allow_analyzed_field = allow_analyzed_field 
+ self.score = score + super().__init__() + + def get_search_fields(self): + return {self.path} + + def search_operator(self): + params = { + "path": self.path, + "query": self.query, + } + if self.score: + params["score"] = self.score + if self.allow_analyzed_field is not None: + params["allowAnalyzedField"] = self.allow_analyzed_field + return {"wildcard": params} + + +class SearchGeoShape(SearchExpression): + def __init__(self, path, relation, geometry, score=None): + self.path = path + self.relation = relation + self.geometry = geometry + self.score = score + super().__init__() + + def get_search_fields(self): + return {self.path} + + def search_operator(self): + params = { + "path": self.path, + "relation": self.relation, + "geometry": self.geometry, + } + if self.score: + params["score"] = self.score + return {"geoShape": params} + + +class SearchGeoWithin(SearchExpression): + def __init__(self, path, kind, geo_object, score=None): + self.path = path + self.kind = kind + self.geo_object = geo_object + self.score = score + super().__init__() + + def search_operator(self): + params = { + "path": self.path, + self.kind: self.geo_object, + } + if self.score: + params["score"] = self.score + return {"geoWithin": params} + + def get_search_fields(self): + return {self.path} + + +class SearchMoreLikeThis(SearchExpression): + def __init__(self, documents, score=None): + self.documents = documents + self.score = score + super().__init__() + + def search_operator(self): + params = { + "like": self.documents, + } + if self.score: + params["score"] = self.score + return {"moreLikeThis": params} + + def get_search_fields(self): + needed_fields = set() + for doc in self.documents: + needed_fields.update(set(doc.keys())) + return needed_fields + + +class CompoundExpression(SearchExpression): + def __init__( + self, + must=None, + must_not=None, + should=None, + filter=None, + score=None, + minimum_should_match=None, + ): + self.must = must or [] + self.must_not = must_not 
or [] + self.should = should or [] + self.filter = filter or [] + self.score = score + self.minimum_should_match = minimum_should_match + + def get_search_fields(self): + fields = set() + for clause in self.must + self.should + self.filter + self.must_not: + fields.update(clause.get_search_fields()) + return fields + + def search_operator(self): + params = {} + if self.must: + params["must"] = [clause.search_operator() for clause in self.must] + if self.must_not: + params["mustNot"] = [clause.search_operator() for clause in self.must_not] + if self.should: + params["should"] = [clause.search_operator() for clause in self.should] + if self.filter: + params["filter"] = [clause.search_operator() for clause in self.filter] + if self.minimum_should_match is not None: + params["minimumShouldMatch"] = self.minimum_should_match + + return {"compound": params} + + def negate(self): + return CompoundExpression(must_not=[self]) + + +class CombinedSearchExpression(SearchExpression): + def __init__(self, lhs, operator, rhs): + self.lhs = lhs + self.operator = operator + self.rhs = rhs + + @staticmethod + def resolve(node, negated=False): + if node is None: + return None + # Leaf, resolve the compoundExpression + if isinstance(node, CompoundExpression): + return node.negate() if negated else node + # Apply De Morgan's Laws. 
+ operator = node.operator.negate() if negated else node.operator + negated = negated != (node.operator == Operator.NOT) + lhs_compound = node.resolve(node.lhs, negated) + rhs_compound = node.resolve(node.rhs, negated) + if operator == Operator.OR: + return CompoundExpression(should=[lhs_compound, rhs_compound], minimum_should_match=1) + if operator == Operator.AND: + return CompoundExpression(must=[lhs_compound, rhs_compound]) + return lhs_compound + + def as_mql(self, compiler, connection): + expression = self.resolve(self) + return expression.as_mql(compiler, connection) + + +class SearchVector(SearchExpression): + def __init__( + self, + path, + query_vector, + limit, + num_candidates=None, + exact=None, + filter=None, + ): + self.path = path + self.query_vector = query_vector + self.limit = limit + self.num_candidates = num_candidates + self.exact = exact + self.filter = filter + super().__init__() + + def __invert__(self): + return ValueError("SearchVector cannot be negated") + + def __and__(self, other): + raise NotSupportedError("SearchVector cannot be combined") + + def __rand__(self, other): + raise NotSupportedError("SearchVector cannot be combined") + + def __or__(self, other): + raise NotSupportedError("SearchVector cannot be combined") + + def __ror__(self, other): + raise NotSupportedError("SearchVector cannot be combined") + + def get_search_fields(self): + return {self.path} + + def _get_query_index(self, fields, compiler): + for search_indexes in compiler.collection.list_search_indexes(): + if search_indexes["type"] == "vectorSearch": + index_field = { + field["path"] for field in search_indexes["latestDefinition"]["fields"] + } + if fields.issubset(index_field): + return search_indexes["name"] + return "default" + + def as_mql(self, compiler, connection): + params = { + "index": self._get_query_index(self.get_search_fields(), compiler), + "path": self.path, + "queryVector": self.query_vector, + "limit": self.limit, + } + if self.num_candidates is 
not None: + params["numCandidates"] = self.num_candidates + if self.exact is not None: + params["exact"] = self.exact + if self.filter is not None: + params["filter"] = self.filter + return {"$vectorSearch": params} + + +class SearchScoreOption: + """Class to mutate scoring on a search operation""" + + def __init__(self, definitions=None): + self.definitions = definitions diff --git a/tests/expressions_/test_combinable_search_expression.py b/tests/expressions_/test_combinable_search_expression.py index 7a2fc487f..2ff597050 100644 --- a/tests/expressions_/test_combinable_search_expression.py +++ b/tests/expressions_/test_combinable_search_expression.py @@ -1,6 +1,6 @@ from django.test import SimpleTestCase -from django_mongodb_backend.expressions.builtins import ( +from django_mongodb_backend.expressions.search import ( CombinedSearchExpression, CompoundExpression, SearchEquals, diff --git a/tests/queries_/test_search.py b/tests/queries_/test_search.py index 19f8db468..15e1bef1d 100644 --- a/tests/queries_/test_search.py +++ b/tests/queries_/test_search.py @@ -7,7 +7,7 @@ from django.test import TransactionTestCase from pymongo.operations import SearchIndexModel -from django_mongodb_backend.expressions.builtins import ( +from django_mongodb_backend.expressions.search import ( CompoundExpression, SearchAutocomplete, SearchEquals, From 9efd0ebc6e111fc9c026e245c40172310c6e8023 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Wed, 9 Jul 2025 22:06:02 -0300 Subject: [PATCH 20/37] Improve unit test. 
--- tests/queries_/test_search.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/tests/queries_/test_search.py b/tests/queries_/test_search.py index 15e1bef1d..4cd763747 100644 --- a/tests/queries_/test_search.py +++ b/tests/queries_/test_search.py @@ -104,11 +104,12 @@ def setUp(self): "equals_headline_index", {"mappings": {"dynamic": False, "fields": {"headline": {"type": "token"}}}}, ) - self.cross = Article.objects.create(headline="cross", number=1, body="body") + self.article = Article.objects.create(headline="cross", number=1, body="body") + Article.objects.create(headline="other thing", number=2, body="body") def test_search_equals(self): qs = Article.objects.annotate(score=SearchEquals(path="headline", value="cross")) - self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.cross])) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) def tearDown(self): self._tear_down(Article) @@ -139,6 +140,7 @@ def setUp(self): self.article = Article.objects.create( headline="crossing and something", number=2, body="river" ) + Article.objects.create(headline="Some random text", number=3, body="river") def tearDown(self): self._tear_down(Article) @@ -170,7 +172,7 @@ def setUp(self): "in_headline_index", {"mappings": {"dynamic": False, "fields": {"headline": {"type": "token"}}}}, ) - self.cross = Article.objects.create(headline="cross", number=1, body="a") + self.article = Article.objects.create(headline="cross", number=1, body="a") Article.objects.create(headline="road", number=2, body="b") def tearDown(self): @@ -179,7 +181,7 @@ def tearDown(self): def test_search_in(self): qs = Article.objects.annotate(score=SearchIn(path="headline", value=["cross", "river"])) - self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.cross])) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) class SearchPhraseTest(SearchUtilsMixin): @@ -189,9 +191,10 
@@ def setUp(self): "phrase_body_index", {"mappings": {"dynamic": False, "fields": {"body": {"type": "string"}}}}, ) - self.irrelevant = Article.objects.create( + self.article = Article.objects.create( headline="irrelevant", number=1, body="the quick brown fox" ) + Article.objects.create(headline="cheetah", number=2, body="fastest animal") def tearDown(self): self._tear_down(Article) @@ -199,7 +202,7 @@ def tearDown(self): def test_search_phrase(self): qs = Article.objects.annotate(score=SearchPhrase(path="body", query="quick brown")) - self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.irrelevant])) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) class SearchRangeTest(SearchUtilsMixin): @@ -234,6 +237,7 @@ def setUp(self): }, ) self.article = Article.objects.create(headline="hello world", number=1, body="abc") + Article.objects.create(headline="hola mundo", number=2, body="abc") def tearDown(self): self._tear_down(Article) @@ -256,6 +260,7 @@ def setUp(self): self.article = Article.objects.create( headline="ignored", number=1, body="The lazy dog sleeps" ) + Article.objects.create(headline="ignored", number=2, body="The sleepy bear") def tearDown(self): self._tear_down(Article) @@ -287,6 +292,7 @@ def setUp(self): }, ) self.article = Article.objects.create(headline="dark-knight", number=1, body="") + Article.objects.create(headline="batman", number=2, body="") def tearDown(self): self._tear_down(Article) @@ -312,6 +318,9 @@ def setUp(self): self.article = Article.objects.create( headline="any", number=1, body="", location={"type": "Point", "coordinates": [40, 5]} ) + Article.objects.create( + headline="any", number=2, body="", location={"type": "Point", "coordinates": [400, 50]} + ) def tearDown(self): self._tear_down(Article) @@ -338,6 +347,9 @@ def setUp(self): self.article = Article.objects.create( headline="geo", number=2, body="", location={"type": "Point", "coordinates": [40, 5]} ) + 
Article.objects.create( + headline="geo2", number=3, body="", location={"type": "Point", "coordinates": [-40, -5]} + ) def tearDown(self): self._tear_down(Article) From 11311117e866452be8518b42a270f7e6b3abef8e Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Fri, 11 Jul 2025 23:44:27 -0300 Subject: [PATCH 21/37] Add expression wrap to parameters. --- .../expressions/builtins.py | 8 +- django_mongodb_backend/expressions/search.py | 188 ++++++++++-------- .../fields/embedded_model.py | 6 +- tests/queries_/models.py | 13 +- tests/queries_/test_search.py | 47 +++-- 5 files changed, 158 insertions(+), 104 deletions(-) diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index 4f6575052..6f1289c5f 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -53,7 +53,7 @@ def case(self, compiler, connection): } -def col(self, compiler, connection): # noqa: ARG001 +def col(self, compiler, connection, as_path=False): # noqa: ARG001 # If the column is part of a subquery and belongs to one of the parent # queries, it will be stored for reference using $let in a $lookup stage. # If the query is built with `alias_cols=False`, treat the column as @@ -71,7 +71,11 @@ def col(self, compiler, connection): # noqa: ARG001 # Add the column's collection's alias for columns in joined collections. has_alias = self.alias and self.alias != compiler.collection_name prefix = f"{self.alias}." 
if has_alias else "" - return f"${prefix}{self.target.column}" + return f"{prefix}{self.target.column}" if as_path else f"${prefix}{self.target.column}" + + +def col_as_path(self, compiler, connection): + return col(self, compiler, connection).lstrip("$") def col_pairs(self, compiler, connection): diff --git a/django_mongodb_backend/expressions/search.py b/django_mongodb_backend/expressions/search.py index 0f126d596..623cc40e2 100644 --- a/django_mongodb_backend/expressions/search.py +++ b/django_mongodb_backend/expressions/search.py @@ -1,5 +1,6 @@ from django.db import NotSupportedError -from django.db.models import Expression, FloatField +from django.db.models import Expression, FloatField, JSONField +from django.db.models.expressions import F, Value class Operator: @@ -75,42 +76,57 @@ def __repr__(self): def as_sql(self, compiler, connection): return "", [] + def _get_indexed_fields(self, mappings): + for field, definition in mappings.get("fields", {}).items(): + yield field + for path in self._get_indexed_fields(definition): + yield f"{field}.{path}" + def _get_query_index(self, fields, compiler): fields = set(fields) for search_indexes in compiler.collection.list_search_indexes(): mappings = search_indexes["latestDefinition"]["mappings"] - if mappings["dynamic"] or fields.issubset(set(mappings["fields"])): + indexed_fields = set(self._get_indexed_fields(mappings)) + if mappings["dynamic"] or fields.issubset(indexed_fields): return search_indexes["name"] return "default" - def search_operator(self): + def search_operator(self, compiler, connection): raise NotImplementedError def as_mql(self, compiler, connection): - index = self._get_query_index(self.get_search_fields(), compiler) - return {"$search": {**self.search_operator(), "index": index}} + index = self._get_query_index(self.get_search_fields(compiler, connection), compiler) + return {"$search": {**self.search_operator(compiler, connection), "index": index}} class SearchAutocomplete(SearchExpression): - 
def __init__(self, path, query, fuzzy=None, score=None): - self.path = path - self.query = query + def __init__(self, path, query, fuzzy=None, token_order=None, score=None): + self.path = F(path) if isinstance(path, str) else path + self.query = Value(query) if not hasattr(query, "resolve_expression") else query + if fuzzy is not None and not hasattr(fuzzy, "resolve_expression"): + fuzzy = Value(fuzzy, output_field=JSONField()) self.fuzzy = fuzzy + if token_order is not None and not hasattr(token_order, "resolve_expression"): + token_order = Value(token_order) + self.token_order = token_order self.score = score super().__init__() - def get_search_fields(self): - return {self.path} + def get_search_fields(self, compiler, connection): + # Shall i implement resolve_something? I think I have to do + return {self.path.as_mql(compiler, connection, as_path=True)} - def search_operator(self): + def search_operator(self, compiler, connection): params = { - "path": self.path, - "query": self.query, + "path": self.path.as_mql(compiler, connection, as_path=True), + "query": self.query.as_mql(compiler, connection), } if self.score is not None: - params["score"] = self.score + params["score"] = self.score.as_mql(compiler, connection) if self.fuzzy is not None: - params["fuzzy"] = self.fuzzy + params["fuzzy"] = self.fuzzy.as_mql(compiler, connection) + if self.token_order is not None: + params["tokenOrder"] = self.token_order.as_mql(compiler, connection) return {"autocomplete": params} @@ -121,16 +137,16 @@ def __init__(self, path, value, score=None): self.score = score super().__init__() - def get_search_fields(self): - return {self.path} + def get_search_fields(self, compiler, connection): + return {self.path.as_mql(compiler, connection, as_path=True)} - def search_operator(self): + def search_operator(self, compiler, connection): params = { - "path": self.path, - "value": self.value, + "path": self.path.as_mql(compiler, connection, as_path=True), + "value": 
self.value.as_mql(compiler, connection, as_path=True), } if self.score is not None: - params["score"] = self.score + params["score"] = self.score.as_mql(compiler, connection, as_path=True) return {"equals": params} @@ -140,15 +156,15 @@ def __init__(self, path, score=None): self.score = score super().__init__() - def get_search_fields(self): - return {self.path} + def get_search_fields(self, compiler, connection): + return {self.path.as_mql(compiler, connection, as_path=True)} - def search_operator(self): + def search_operator(self, compiler, connection): params = { - "path": self.path, + "path": self.path.as_mql(compiler, connection, as_path=True), } if self.score is not None: - params["score"] = self.score + params["score"] = self.score.definitions return {"exists": params} @@ -159,16 +175,16 @@ def __init__(self, path, value, score=None): self.score = score super().__init__() - def get_search_fields(self): - return {self.path} + def get_search_fields(self, compiler, connection): + return {self.path.as_mql(compiler, connection, as_path=True)} - def search_operator(self): + def search_operator(self, compiler, connection): params = { - "path": self.path, - "value": self.value, + "path": self.path.as_mql(compiler, connection, as_path=True), + "value": self.value.as_mql(compiler, connection, as_path=True), } if self.score is not None: - params["score"] = self.score + params["score"] = self.score.definitions return {"in": params} @@ -181,20 +197,20 @@ def __init__(self, path, query, slop=None, synonyms=None, score=None): self.synonyms = synonyms super().__init__() - def get_search_fields(self): - return {self.path} + def get_search_fields(self, compiler, connection): + return {self.path.as_mql(compiler, connection, as_path=True)} - def search_operator(self): + def search_operator(self, compiler, connection): params = { - "path": self.path, - "query": self.query, + "path": self.path.as_mql(compiler, connection, as_path=True), + "query": self.query.as_mql(compiler, 
connection, as_path=True), } if self.score is not None: - params["score"] = self.score + params["score"] = self.score.as_mql(compiler, connection, as_path=True) if self.slop is not None: - params["slop"] = self.slop + params["slop"] = self.slop.as_mql(compiler, connection, as_path=True) if self.synonyms is not None: - params["synonyms"] = self.synonyms + params["synonyms"] = self.synonyms.as_mql(compiler, connection, as_path=True) return {"phrase": params} @@ -205,16 +221,16 @@ def __init__(self, path, query, score=None): self.score = score super().__init__() - def get_search_fields(self): - return {self.path} + def get_search_fields(self, compiler, connection): + return {self.path.as_mql(compiler, connection, as_path=True)} - def search_operator(self): + def search_operator(self, compiler, connection): params = { "defaultPath": self.path, - "query": self.query, + "query": self.query.as_mql(compiler, connection, as_path=True), } if self.score is not None: - params["score"] = self.score + params["score"] = self.score.definitions return {"queryString": params} @@ -228,15 +244,15 @@ def __init__(self, path, lt=None, lte=None, gt=None, gte=None, score=None): self.score = score super().__init__() - def get_search_fields(self): - return {self.path} + def get_search_fields(self, compiler, connection): + return {self.path.as_mql(compiler, connection, as_path=True)} - def search_operator(self): + def search_operator(self, compiler, connection): params = { - "path": self.path, + "path": self.path.as_mql(compiler, connection, as_path=True), } if self.score is not None: - params["score"] = self.score + params["score"] = self.score.definitions if self.lt is not None: params["lt"] = self.lt if self.lte is not None: @@ -256,16 +272,16 @@ def __init__(self, path, query, allow_analyzed_field=None, score=None): self.score = score super().__init__() - def get_search_fields(self): - return {self.path} + def get_search_fields(self, compiler, connection): + return 
{self.path.as_mql(compiler, connection, as_path=True)} - def search_operator(self): + def search_operator(self, compiler, connection): params = { - "path": self.path, - "query": self.query, + "path": self.path.as_mql(compiler, connection, as_path=True), + "query": self.query.as_mql(compiler, connection, as_path=True), } if self.score: - params["score"] = self.score + params["score"] = self.score.definitions if self.allow_analyzed_field is not None: params["allowAnalyzedField"] = self.allow_analyzed_field return {"regex": params} @@ -281,16 +297,16 @@ def __init__(self, path, query, fuzzy=None, match_criteria=None, synonyms=None, self.score = score super().__init__() - def get_search_fields(self): - return {self.path} + def get_search_fields(self, compiler, connection): + return {self.path.as_mql(compiler, connection, as_path=True)} - def search_operator(self): + def search_operator(self, compiler, connection): params = { - "path": self.path, - "query": self.query, + "path": self.path.as_mql(compiler, connection, as_path=True), + "query": self.query.as_mql(compiler, connection, as_path=True), } if self.score: - params["score"] = self.score + params["score"] = self.score.definitions if self.fuzzy is not None: params["fuzzy"] = self.fuzzy if self.match_criteria is not None: @@ -308,16 +324,16 @@ def __init__(self, path, query, allow_analyzed_field=None, score=None): self.score = score super().__init__() - def get_search_fields(self): - return {self.path} + def get_search_fields(self, compiler, connection): + return {self.path.as_mql(compiler, connection, as_path=True)} - def search_operator(self): + def search_operator(self, compiler, connection): params = { - "path": self.path, - "query": self.query, + "path": self.path.as_mql(compiler, connection, as_path=True), + "query": self.query.as_mql(compiler, connection, as_path=True), } if self.score: - params["score"] = self.score + params["score"] = self.score.definitions if self.allow_analyzed_field is not None: 
params["allowAnalyzedField"] = self.allow_analyzed_field return {"wildcard": params} @@ -331,17 +347,17 @@ def __init__(self, path, relation, geometry, score=None): self.score = score super().__init__() - def get_search_fields(self): - return {self.path} + def get_search_fields(self, compiler, connection): + return {self.path.as_mql(compiler, connection, as_path=True)} - def search_operator(self): + def search_operator(self, compiler, connection): params = { - "path": self.path, + "path": self.path.as_mql(compiler, connection, as_path=True), "relation": self.relation, "geometry": self.geometry, } if self.score: - params["score"] = self.score + params["score"] = self.score.definitions return {"geoShape": params} @@ -353,17 +369,17 @@ def __init__(self, path, kind, geo_object, score=None): self.score = score super().__init__() - def search_operator(self): + def search_operator(self, compiler, connection): params = { - "path": self.path, + "path": self.path.as_mql(compiler, connection, as_path=True), self.kind: self.geo_object, } if self.score: - params["score"] = self.score + params["score"] = self.score.definitions return {"geoWithin": params} - def get_search_fields(self): - return {self.path} + def get_search_fields(self, compiler, connection): + return {self.path.as_mql(compiler, connection, as_path=True)} class SearchMoreLikeThis(SearchExpression): @@ -372,15 +388,15 @@ def __init__(self, documents, score=None): self.score = score super().__init__() - def search_operator(self): + def search_operator(self, compiler, connection): params = { "like": self.documents, } if self.score: - params["score"] = self.score + params["score"] = self.score.definitions return {"moreLikeThis": params} - def get_search_fields(self): + def get_search_fields(self, compiler, connection): needed_fields = set() for doc in self.documents: needed_fields.update(set(doc.keys())) @@ -404,13 +420,13 @@ def __init__( self.score = score self.minimum_should_match = minimum_should_match - def 
get_search_fields(self): + def get_search_fields(self, compiler, connection): fields = set() for clause in self.must + self.should + self.filter + self.must_not: fields.update(clause.get_search_fields()) return fields - def search_operator(self): + def search_operator(self, compiler, connection): params = {} if self.must: params["must"] = [clause.search_operator() for clause in self.must] @@ -491,8 +507,8 @@ def __or__(self, other): def __ror__(self, other): raise NotSupportedError("SearchVector cannot be combined") - def get_search_fields(self): - return {self.path} + def get_search_fields(self, compiler, connection): + return {self.path.as_mql(compiler, connection, as_path=True)} def _get_query_index(self, fields, compiler): for search_indexes in compiler.collection.list_search_indexes(): @@ -507,7 +523,7 @@ def _get_query_index(self, fields, compiler): def as_mql(self, compiler, connection): params = { "index": self._get_query_index(self.get_search_fields(), compiler), - "path": self.path, + "path": self.path.as_mql(compiler, connection, as_path=True), "queryVector": self.query_vector, "limit": self.limit, } diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py index 4b49a4710..b7f562841 100644 --- a/django_mongodb_backend/fields/embedded_model.py +++ b/django_mongodb_backend/fields/embedded_model.py @@ -184,12 +184,16 @@ def get_transform(self, name): f"{suggestion}" ) - def as_mql(self, compiler, connection): + def as_mql(self, compiler, connection, as_path=False): previous = self key_transforms = [] while isinstance(previous, KeyTransform): key_transforms.insert(0, previous.key_name) previous = previous.lhs + if as_path: + mql = previous.as_mql(compiler, connection, as_path=True) + mql_path = ".".join(key_transforms) + return f"{mql}.{mql_path}" mql = previous.as_mql(compiler, connection) for key in key_transforms: mql = {"$getField": {"input": mql, "field": key}} diff --git a/tests/queries_/models.py 
b/tests/queries_/models.py index df60a6725..21af6fafd 100644 --- a/tests/queries_/models.py +++ b/tests/queries_/models.py @@ -1,6 +1,12 @@ from django.db import models -from django_mongodb_backend.fields import ArrayField, ObjectIdAutoField, ObjectIdField +from django_mongodb_backend.fields import ( + ArrayField, + EmbeddedModelField, + ObjectIdAutoField, + ObjectIdField, +) +from django_mongodb_backend.models import EmbeddedModel class Author(models.Model): @@ -55,9 +61,14 @@ def __str__(self): return str(self.pk) +class Writer(EmbeddedModel): + name = models.CharField(max_length=10) + + class Article(models.Model): headline = models.CharField(max_length=100) number = models.IntegerField() body = models.TextField() location = models.JSONField(null=True) plot_embedding = ArrayField(models.FloatField(), size=3, null=True) + writer = EmbeddedModelField(Writer, null=True) diff --git a/tests/queries_/test_search.py b/tests/queries_/test_search.py index 4cd763747..7261be882 100644 --- a/tests/queries_/test_search.py +++ b/tests/queries_/test_search.py @@ -24,7 +24,7 @@ SearchWildcard, ) -from .models import Article +from .models import Article, Writer def _wait_for_assertion(timeout: float = 120, interval: float = 0.5) -> None: @@ -87,16 +87,6 @@ def _tear_down(self, model): wait_for_assertion = _wait_for_assertion(timeout=3) -class SearchTest(SearchUtilsMixin): - @classmethod - def setUpTestData(cls): - cls.create_search_index( - Article, - "equals_headline_index", - {"mappings": {"dynamic": False, "fields": {"headline": {"type": "token"}}}}, - ) - - class SearchEqualsTest(SearchUtilsMixin): def setUp(self): self.create_search_index( @@ -132,13 +122,29 @@ def setUp(self): "minGrams": 3, "maxGrams": 5, "foldDiacritics": False, - } + }, + "writer": { + "type": "document", + "fields": { + "name": { + "type": "autocomplete", + "analyzer": "lucene.standard", + "tokenization": "edgeGram", + "minGrams": 3, + "maxGrams": 5, + "foldDiacritics": False, + } + }, + }, }, } }, ) 
self.article = Article.objects.create( - headline="crossing and something", number=2, body="river" + headline="crossing and something", + number=2, + body="river", + writer=Writer(name="Joselina A. Ramirez"), ) Article.objects.create(headline="Some random text", number=3, body="river") @@ -147,7 +153,20 @@ def tearDown(self): super().tearDown() def test_search_autocomplete(self): - qs = Article.objects.annotate(score=SearchAutocomplete(path="headline", query="crossing")) + qs = Article.objects.annotate( + score=SearchAutocomplete( + path="headline", + query="crossing", + token_order="sequential", # noqa: S106 + fuzzy={"maxEdits": 2}, + ) + ) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + + def test_search_autocomplete_embedded_model(self): + qs = Article.objects.annotate( + score=SearchAutocomplete(path="writer__name", query="Joselina") + ) self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) From 7ce99b5250fde360d019bcc97182ef757541db87 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sat, 12 Jul 2025 00:46:26 -0300 Subject: [PATCH 22/37] Adding source and set source. 
--- django_mongodb_backend/expressions/search.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/django_mongodb_backend/expressions/search.py b/django_mongodb_backend/expressions/search.py index 623cc40e2..f29f636be 100644 --- a/django_mongodb_backend/expressions/search.py +++ b/django_mongodb_backend/expressions/search.py @@ -76,6 +76,9 @@ def __repr__(self): def as_sql(self, compiler, connection): return "", [] + def get_source_expressions(self): + return [] + def _get_indexed_fields(self, mappings): for field, definition in mappings.get("fields", {}).items(): yield field @@ -112,6 +115,12 @@ def __init__(self, path, query, fuzzy=None, token_order=None, score=None): self.score = score super().__init__() + def get_source_expressions(self): + return [self.path, self.query, self.fuzzy, self.token_order] + + def set_source_expressions(self, exprs): + self.path, self.query, self.fuzzy, self.token_order = exprs + def get_search_fields(self, compiler, connection): # Shall i implement resolve_something? I think I have to do return {self.path.as_mql(compiler, connection, as_path=True)} From aaadb9dcdb993b19e1d6ca1b35f0e7adec758488 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sat, 12 Jul 2025 01:30:55 -0300 Subject: [PATCH 23/37] Edits. 
--- django_mongodb_backend/expressions/search.py | 237 ++++++++++++++----- 1 file changed, 178 insertions(+), 59 deletions(-) diff --git a/django_mongodb_backend/expressions/search.py b/django_mongodb_backend/expressions/search.py index f29f636be..327bff6f4 100644 --- a/django_mongodb_backend/expressions/search.py +++ b/django_mongodb_backend/expressions/search.py @@ -1,5 +1,5 @@ from django.db import NotSupportedError -from django.db.models import Expression, FloatField, JSONField +from django.db.models import Expression, FloatField from django.db.models.expressions import F, Value @@ -79,6 +79,14 @@ def as_sql(self, compiler, connection): def get_source_expressions(self): return [] + @staticmethod + def cast_as_value(value): + return Value(value) if not hasattr(value, "resolve_expression") else value + + @staticmethod + def cast_as_field(path): + return F(path) if isinstance(path, str) else path + def _get_indexed_fields(self, mappings): for field, definition in mappings.get("fields", {}).items(): yield field @@ -104,14 +112,10 @@ def as_mql(self, compiler, connection): class SearchAutocomplete(SearchExpression): def __init__(self, path, query, fuzzy=None, token_order=None, score=None): - self.path = F(path) if isinstance(path, str) else path - self.query = Value(query) if not hasattr(query, "resolve_expression") else query - if fuzzy is not None and not hasattr(fuzzy, "resolve_expression"): - fuzzy = Value(fuzzy, output_field=JSONField()) - self.fuzzy = fuzzy - if token_order is not None and not hasattr(token_order, "resolve_expression"): - token_order = Value(token_order) - self.token_order = token_order + self.path = self.cast_as_field(path) + self.query = self.cast_as_value(query) + self.fuzzy = self.cast_as_value(fuzzy) + self.token_order = self.cast_as_value(token_order) self.score = score super().__init__() @@ -122,7 +126,6 @@ def set_source_expressions(self, exprs): self.path, self.query, self.fuzzy, self.token_order = exprs def get_search_fields(self, 
compiler, connection): - # Shall i implement resolve_something? I think I have to do return {self.path.as_mql(compiler, connection, as_path=True)} def search_operator(self, compiler, connection): @@ -141,14 +144,20 @@ def search_operator(self, compiler, connection): class SearchEquals(SearchExpression): def __init__(self, path, value, score=None): - self.path = path - self.value = value + self.path = self.cast_as_field(path) + self.value = self.cast_as_value(value) self.score = score super().__init__() def get_search_fields(self, compiler, connection): return {self.path.as_mql(compiler, connection, as_path=True)} + def get_source_expressions(self): + return [self.path, self.value] + + def set_source_expressions(self, exprs): + self.path, self.value = exprs + def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), @@ -161,13 +170,19 @@ def search_operator(self, compiler, connection): class SearchExists(SearchExpression): def __init__(self, path, score=None): - self.path = path + self.path = self.cast_as_field(path) self.score = score super().__init__() def get_search_fields(self, compiler, connection): return {self.path.as_mql(compiler, connection, as_path=True)} + def get_source_expressions(self): + return [self.path] + + def set_source_expressions(self, exprs): + (self.path,) = exprs + def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), @@ -179,14 +194,20 @@ def search_operator(self, compiler, connection): class SearchIn(SearchExpression): def __init__(self, path, value, score=None): - self.path = path - self.value = value + self.path = self.cast_as_field(path) + self.value = self.cast_as_value(value) self.score = score super().__init__() def get_search_fields(self, compiler, connection): return {self.path.as_mql(compiler, connection, as_path=True)} + def get_source_expressions(self): + return [self.path, self.value] + + def 
set_source_expressions(self, exprs): + self.path, self.value = exprs + def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), @@ -199,16 +220,22 @@ def search_operator(self, compiler, connection): class SearchPhrase(SearchExpression): def __init__(self, path, query, slop=None, synonyms=None, score=None): - self.path = path - self.query = query + self.path = self.cast_as_field(path) + self.query = self.cast_as_value(query) + self.slop = self.cast_as_value(slop) + self.synonyms = self.cast_as_value(synonyms) self.score = score - self.slop = slop - self.synonyms = synonyms super().__init__() def get_search_fields(self, compiler, connection): return {self.path.as_mql(compiler, connection, as_path=True)} + def get_source_expressions(self): + return [self.path, self.query, self.slop, self.synonyms] + + def set_source_expressions(self, exprs): + self.path, self.query, self.score, self.slop, self.synonyms = exprs + def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), @@ -225,14 +252,20 @@ def search_operator(self, compiler, connection): class SearchQueryString(SearchExpression): def __init__(self, path, query, score=None): - self.path = path - self.query = query + self.path = self.cast_as_field(path) + self.query = self.cast_as_value(query) self.score = score super().__init__() def get_search_fields(self, compiler, connection): return {self.path.as_mql(compiler, connection, as_path=True)} + def get_source_expressions(self): + return [self.path, self.query] + + def set_source_expressions(self, exprs): + self.path, self.query = exprs + def search_operator(self, compiler, connection): params = { "defaultPath": self.path, @@ -245,17 +278,23 @@ def search_operator(self, compiler, connection): class SearchRange(SearchExpression): def __init__(self, path, lt=None, lte=None, gt=None, gte=None, score=None): - self.path = path - self.lt = lt - 
self.lte = lte - self.gt = gt - self.gte = gte + self.path = self.cast_as_field(path) + self.lt = self.cast_as_value(lt) + self.lte = self.cast_as_value(lte) + self.gt = self.cast_as_value(gt) + self.gte = self.cast_as_value(gte) self.score = score super().__init__() def get_search_fields(self, compiler, connection): return {self.path.as_mql(compiler, connection, as_path=True)} + def get_source_expressions(self): + return [self.path, self.query, self.lt, self.lte, self.gt, self.gte] + + def set_source_expressions(self, exprs): + self.path, self.query, self.lt, self.lte, self.gt, self.gte = exprs + def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), @@ -275,15 +314,21 @@ def search_operator(self, compiler, connection): class SearchRegex(SearchExpression): def __init__(self, path, query, allow_analyzed_field=None, score=None): - self.path = path - self.query = query - self.allow_analyzed_field = allow_analyzed_field + self.path = self.cast_as_field(path) + self.query = self.cast_as_value(query) + self.allow_analyzed_field = self.cast_as_value(allow_analyzed_field) self.score = score super().__init__() def get_search_fields(self, compiler, connection): return {self.path.as_mql(compiler, connection, as_path=True)} + def get_source_expressions(self): + return [self.path, self.query, self.allow_analyzed_field] + + def set_source_expressions(self, exprs): + self.path, self.query, self.allow_analyzed_field = exprs + def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), @@ -298,17 +343,23 @@ def search_operator(self, compiler, connection): class SearchText(SearchExpression): def __init__(self, path, query, fuzzy=None, match_criteria=None, synonyms=None, score=None): - self.path = path - self.query = query - self.fuzzy = fuzzy - self.match_criteria = match_criteria - self.synonyms = synonyms + self.path = self.cast_as_field(path) + 
self.query = self.cast_as_value(query) + self.fuzzy = self.cast_as_value(fuzzy) + self.match_criteria = self.cast_as_value(match_criteria) + self.synonyms = self.cast_as_value(synonyms) self.score = score super().__init__() def get_search_fields(self, compiler, connection): return {self.path.as_mql(compiler, connection, as_path=True)} + def get_source_expressions(self): + return [self.path, self.query, self.fuzzy, self.match_criteria, self.synonyms] + + def set_source_expressions(self, exprs): + self.path, self.query, self.fuzzy, self.match_criteria, self.synonyms = exprs + def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), @@ -327,15 +378,21 @@ def search_operator(self, compiler, connection): class SearchWildcard(SearchExpression): def __init__(self, path, query, allow_analyzed_field=None, score=None): - self.path = path - self.query = query - self.allow_analyzed_field = allow_analyzed_field + self.path = self.cast_as_field(path) + self.query = self.cast_as_value(query) + self.allow_analyzed_field = self.cast_as_value(allow_analyzed_field) self.score = score super().__init__() def get_search_fields(self, compiler, connection): return {self.path.as_mql(compiler, connection, as_path=True)} + def get_source_expressions(self): + return [self.path, self.query, self.allow_analyzed_field] + + def set_source_expressions(self, exprs): + self.path, self.query, self.allow_analyzed_field = exprs + def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), @@ -350,15 +407,21 @@ def search_operator(self, compiler, connection): class SearchGeoShape(SearchExpression): def __init__(self, path, relation, geometry, score=None): - self.path = path - self.relation = relation - self.geometry = geometry + self.path = self.cast_as_field(path) + self.relation = self.cast_as_value(relation) + self.geometry = self.cast_as_value(geometry) self.score = score 
super().__init__() def get_search_fields(self, compiler, connection): return {self.path.as_mql(compiler, connection, as_path=True)} + def get_source_expressions(self): + return [self.path, self.relation, self.geometry] + + def set_source_expressions(self, exprs): + self.path, self.relation, self.geometry = exprs + def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), @@ -372,12 +435,21 @@ def search_operator(self, compiler, connection): class SearchGeoWithin(SearchExpression): def __init__(self, path, kind, geo_object, score=None): - self.path = path - self.kind = kind - self.geo_object = geo_object + self.path = self.cast_as_field(path) + self.kind = self.cast_as_value(kind) + self.geo_object = self.cast_as_value(geo_object) self.score = score super().__init__() + def get_search_fields(self, compiler, connection): + return {self.path.as_mql(compiler, connection, as_path=True)} + + def get_source_expressions(self): + return [self.path, self.kind, self.geo_object] + + def set_source_expressions(self, exprs): + self.path, self.kind, self.geo_object = exprs + def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), @@ -387,16 +459,19 @@ def search_operator(self, compiler, connection): params["score"] = self.score.definitions return {"geoWithin": params} - def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} - class SearchMoreLikeThis(SearchExpression): def __init__(self, documents, score=None): - self.documents = documents + self.documents = self.cast_as_value(documents) self.score = score super().__init__() + def get_source_expressions(self): + return [self.documents] + + def set_source_expressions(self, exprs): + (self.documents,) = exprs + def search_operator(self, compiler, connection): params = { "like": self.documents, @@ -435,6 +510,25 @@ def get_search_fields(self, 
compiler, connection): fields.update(clause.get_search_fields()) return fields + def resolve_expression( + self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False + ): + c = self.copy() + c.is_summary = summarize + c.must = [ + expr.resolve_expression(query, allow_joins, reuse, summarize) for expr in self.must + ] + c.must_not = [ + expr.resolve_expression(query, allow_joins, reuse, summarize) for expr in self.must_not + ] + c.should = [ + expr.resolve_expression(query, allow_joins, reuse, summarize) for expr in self.should + ] + c.filter = [ + expr.resolve_expression(query, allow_joins, reuse, summarize) for expr in self.filter + ] + return c + def search_operator(self, compiler, connection): params = {} if self.must: @@ -447,7 +541,6 @@ def search_operator(self, compiler, connection): params["filter"] = [clause.search_operator() for clause in self.filter] if self.minimum_should_match is not None: params["minimumShouldMatch"] = self.minimum_should_match - return {"compound": params} def negate(self): @@ -460,6 +553,12 @@ def __init__(self, lhs, operator, rhs): self.operator = operator self.rhs = rhs + def get_source_expressions(self): + return [self.lhs, self.rhs] + + def set_source_expressions(self, exprs): + self.lhs, self.rhs = exprs + @staticmethod def resolve(node, negated=False): if node is None: @@ -493,12 +592,12 @@ def __init__( exact=None, filter=None, ): - self.path = path - self.query_vector = query_vector - self.limit = limit - self.num_candidates = num_candidates - self.exact = exact - self.filter = filter + self.path = self.cast_as_field(path) + self.query_vector = self.cast_as_value(query_vector) + self.limit = self.cast_as_value(limit) + self.num_candidates = self.cast_as_value(num_candidates) + self.exact = self.cast_as_value(exact) + self.filter = self.cast_as_value(filter) super().__init__() def __invert__(self): @@ -519,6 +618,26 @@ def __ror__(self, other): def get_search_fields(self, compiler, connection): return 
{self.path.as_mql(compiler, connection, as_path=True)} + def get_source_expressions(self): + return [ + self.path, + self.query_vector, + self.limit, + self.num_candidates, + self.exact, + self.filter, + ] + + def set_source_expressions(self, exprs): + ( + self.path, + self.query_vector, + self.limit, + self.num_candidates, + self.exact, + self.filter, + ) = exprs + def _get_query_index(self, fields, compiler): for search_indexes in compiler.collection.list_search_indexes(): if search_indexes["type"] == "vectorSearch": @@ -533,15 +652,15 @@ def as_mql(self, compiler, connection): params = { "index": self._get_query_index(self.get_search_fields(), compiler), "path": self.path.as_mql(compiler, connection, as_path=True), - "queryVector": self.query_vector, - "limit": self.limit, + "queryVector": self.query_vector.as_mql(compiler, connection), + "limit": self.limit.as_mql(compiler, connection), } if self.num_candidates is not None: - params["numCandidates"] = self.num_candidates + params["numCandidates"] = self.num_candidates.as_mql(compiler, connection) if self.exact is not None: - params["exact"] = self.exact + params["exact"] = self.exact.as_mql(compiler, connection) if self.filter is not None: - params["filter"] = self.filter + params["filter"] = self.filter.as_mql(compiler, connection) return {"$vectorSearch": params} From aaeea0d097eb82d8280d0973e378c3dfd0b58806 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sat, 12 Jul 2025 12:38:32 -0300 Subject: [PATCH 24/37] Edits --- django_mongodb_backend/expressions/search.py | 32 ++++++++++---------- tests/queries_/test_search.py | 8 ++--- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/django_mongodb_backend/expressions/search.py b/django_mongodb_backend/expressions/search.py index 327bff6f4..4468d3cb5 100644 --- a/django_mongodb_backend/expressions/search.py +++ b/django_mongodb_backend/expressions/search.py @@ -161,10 +161,10 @@ def set_source_expressions(self, exprs): def search_operator(self, 
compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), - "value": self.value.as_mql(compiler, connection, as_path=True), + "value": self.value.as_mql(compiler, connection), } if self.score is not None: - params["score"] = self.score.as_mql(compiler, connection, as_path=True) + params["score"] = self.score.as_mql(compiler, connection) return {"equals": params} @@ -211,7 +211,7 @@ def set_source_expressions(self, exprs): def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), - "value": self.value.as_mql(compiler, connection, as_path=True), + "value": self.value.as_mql(compiler, connection), } if self.score is not None: params["score"] = self.score.definitions @@ -234,19 +234,19 @@ def get_source_expressions(self): return [self.path, self.query, self.slop, self.synonyms] def set_source_expressions(self, exprs): - self.path, self.query, self.score, self.slop, self.synonyms = exprs + self.path, self.query, self.slop, self.synonyms = exprs def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), - "query": self.query.as_mql(compiler, connection, as_path=True), + "query": self.query.as_mql(compiler, connection), } if self.score is not None: - params["score"] = self.score.as_mql(compiler, connection, as_path=True) + params["score"] = self.score.as_mql(compiler, connection) if self.slop is not None: - params["slop"] = self.slop.as_mql(compiler, connection, as_path=True) + params["slop"] = self.slop.as_mql(compiler, connection) if self.synonyms is not None: - params["synonyms"] = self.synonyms.as_mql(compiler, connection, as_path=True) + params["synonyms"] = self.synonyms.as_mql(compiler, connection) return {"phrase": params} @@ -268,8 +268,8 @@ def set_source_expressions(self, exprs): def search_operator(self, compiler, connection): params = { - "defaultPath": self.path, - "query": 
self.query.as_mql(compiler, connection, as_path=True), + "defaultPath": self.path.as_mql(compiler, connection, as_path=True), + "query": self.query.as_mql(compiler, connection), } if self.score is not None: params["score"] = self.score.definitions @@ -290,10 +290,10 @@ def get_search_fields(self, compiler, connection): return {self.path.as_mql(compiler, connection, as_path=True)} def get_source_expressions(self): - return [self.path, self.query, self.lt, self.lte, self.gt, self.gte] + return [self.path, self.lt, self.lte, self.gt, self.gte] def set_source_expressions(self, exprs): - self.path, self.query, self.lt, self.lte, self.gt, self.gte = exprs + self.path, self.lt, self.lte, self.gt, self.gte = exprs def search_operator(self, compiler, connection): params = { @@ -302,13 +302,13 @@ def search_operator(self, compiler, connection): if self.score is not None: params["score"] = self.score.definitions if self.lt is not None: - params["lt"] = self.lt + params["lt"] = self.lt.as_mql(compiler, connection) if self.lte is not None: - params["lte"] = self.lte + params["lte"] = self.lte.as_mql(compiler, connection) if self.gt is not None: - params["gt"] = self.gt + params["gt"] = self.gt.as_mql(compiler, connection) if self.gte is not None: - params["gte"] = self.gte + params["gte"] = self.gte.as_mql(compiler, connection) return {"range": params} diff --git a/tests/queries_/test_search.py b/tests/queries_/test_search.py index 7261be882..03436d1c7 100644 --- a/tests/queries_/test_search.py +++ b/tests/queries_/test_search.py @@ -97,14 +97,14 @@ def setUp(self): self.article = Article.objects.create(headline="cross", number=1, body="body") Article.objects.create(headline="other thing", number=2, body="body") - def test_search_equals(self): - qs = Article.objects.annotate(score=SearchEquals(path="headline", value="cross")) - self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) - def tearDown(self): self._tear_down(Article) super().tearDown() + def 
test_search_equals(self): + qs = Article.objects.annotate(score=SearchEquals(path="headline", value="cross")) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + class SearchAutocompleteTest(SearchUtilsMixin): def setUp(self): From c351f6b18d5db92bfc83ec692e1f9bb02ce5c511 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sat, 12 Jul 2025 13:29:07 -0300 Subject: [PATCH 25/37] Resolve value as direct value. --- django_mongodb_backend/expressions/search.py | 96 +++++++++++--------- 1 file changed, 52 insertions(+), 44 deletions(-) diff --git a/django_mongodb_backend/expressions/search.py b/django_mongodb_backend/expressions/search.py index 4468d3cb5..47c46bd65 100644 --- a/django_mongodb_backend/expressions/search.py +++ b/django_mongodb_backend/expressions/search.py @@ -81,6 +81,8 @@ def get_source_expressions(self): @staticmethod def cast_as_value(value): + if value is None: + return None return Value(value) if not hasattr(value, "resolve_expression") else value @staticmethod @@ -131,14 +133,14 @@ def get_search_fields(self, compiler, connection): def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), - "query": self.query.as_mql(compiler, connection), + "query": self.query.value, } if self.score is not None: params["score"] = self.score.as_mql(compiler, connection) if self.fuzzy is not None: - params["fuzzy"] = self.fuzzy.as_mql(compiler, connection) + params["fuzzy"] = self.fuzzy.value if self.token_order is not None: - params["tokenOrder"] = self.token_order.as_mql(compiler, connection) + params["tokenOrder"] = self.token_order.value return {"autocomplete": params} @@ -161,7 +163,7 @@ def set_source_expressions(self, exprs): def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), - "value": self.value.as_mql(compiler, connection), + "value": self.value.value, } if self.score is not None: 
params["score"] = self.score.as_mql(compiler, connection) @@ -211,10 +213,10 @@ def set_source_expressions(self, exprs): def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), - "value": self.value.as_mql(compiler, connection), + "value": self.value.value, } if self.score is not None: - params["score"] = self.score.definitions + params["score"] = self.score.as_mql(compiler, connection) return {"in": params} @@ -239,14 +241,14 @@ def set_source_expressions(self, exprs): def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), - "query": self.query.as_mql(compiler, connection), + "query": self.query.value, } if self.score is not None: params["score"] = self.score.as_mql(compiler, connection) if self.slop is not None: - params["slop"] = self.slop.as_mql(compiler, connection) + params["slop"] = self.slop.value if self.synonyms is not None: - params["synonyms"] = self.synonyms.as_mql(compiler, connection) + params["synonyms"] = self.synonyms.value return {"phrase": params} @@ -269,10 +271,10 @@ def set_source_expressions(self, exprs): def search_operator(self, compiler, connection): params = { "defaultPath": self.path.as_mql(compiler, connection, as_path=True), - "query": self.query.as_mql(compiler, connection), + "query": self.query.value, } if self.score is not None: - params["score"] = self.score.definitions + params["score"] = self.score.as_mql(compiler, connection) return {"queryString": params} @@ -300,15 +302,15 @@ def search_operator(self, compiler, connection): "path": self.path.as_mql(compiler, connection, as_path=True), } if self.score is not None: - params["score"] = self.score.definitions + params["score"] = self.score.as_mql(compiler, connection) if self.lt is not None: - params["lt"] = self.lt.as_mql(compiler, connection) + params["lt"] = self.lt.value if self.lte is not None: - params["lte"] = self.lte.as_mql(compiler, 
connection) + params["lte"] = self.lte.value if self.gt is not None: - params["gt"] = self.gt.as_mql(compiler, connection) + params["gt"] = self.gt.value if self.gte is not None: - params["gte"] = self.gte.as_mql(compiler, connection) + params["gte"] = self.gte.value return {"range": params} @@ -332,12 +334,12 @@ def set_source_expressions(self, exprs): def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), - "query": self.query.as_mql(compiler, connection, as_path=True), + "query": self.query.value, } if self.score: - params["score"] = self.score.definitions + params["score"] = self.score.as_mql(compiler, connection) if self.allow_analyzed_field is not None: - params["allowAnalyzedField"] = self.allow_analyzed_field + params["allowAnalyzedField"] = self.allow_analyzed_field.value return {"regex": params} @@ -363,16 +365,16 @@ def set_source_expressions(self, exprs): def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), - "query": self.query.as_mql(compiler, connection, as_path=True), + "query": self.query.value, } if self.score: - params["score"] = self.score.definitions + params["score"] = self.score.as_mql(compiler, connection) if self.fuzzy is not None: - params["fuzzy"] = self.fuzzy + params["fuzzy"] = self.fuzzy.value if self.match_criteria is not None: - params["matchCriteria"] = self.match_criteria + params["matchCriteria"] = self.match_criteria.value if self.synonyms is not None: - params["synonyms"] = self.synonyms + params["synonyms"] = self.synonyms.value return {"text": params} @@ -396,12 +398,12 @@ def set_source_expressions(self, exprs): def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), - "query": self.query.as_mql(compiler, connection, as_path=True), + "query": self.query.value, } if self.score: - params["score"] = self.score.definitions + 
params["score"] = self.score.query.as_mql(compiler, connection) if self.allow_analyzed_field is not None: - params["allowAnalyzedField"] = self.allow_analyzed_field + params["allowAnalyzedField"] = self.allow_analyzed_field.value return {"wildcard": params} @@ -425,11 +427,11 @@ def set_source_expressions(self, exprs): def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), - "relation": self.relation, - "geometry": self.geometry, + "relation": self.relation.value, + "geometry": self.geometry.value, } if self.score: - params["score"] = self.score.definitions + params["score"] = self.score.as_mql(compiler, connection) return {"geoShape": params} @@ -453,10 +455,10 @@ def set_source_expressions(self, exprs): def search_operator(self, compiler, connection): params = { "path": self.path.as_mql(compiler, connection, as_path=True), - self.kind: self.geo_object, + self.kind.value: self.geo_object.value, } if self.score: - params["score"] = self.score.definitions + params["score"] = self.score.as_mql(compiler, connection) return {"geoWithin": params} @@ -474,10 +476,10 @@ def set_source_expressions(self, exprs): def search_operator(self, compiler, connection): params = { - "like": self.documents, + "like": self.documents.as_mql(compiler, connection), } if self.score: - params["score"] = self.score.definitions + params["score"] = self.score.as_mql(compiler, connection) return {"moreLikeThis": params} def get_search_fields(self, compiler, connection): @@ -507,7 +509,7 @@ def __init__( def get_search_fields(self, compiler, connection): fields = set() for clause in self.must + self.should + self.filter + self.must_not: - fields.update(clause.get_search_fields()) + fields.update(clause.get_search_fields(compiler, connection)) return fields def resolve_expression( @@ -532,13 +534,19 @@ def resolve_expression( def search_operator(self, compiler, connection): params = {} if self.must: - params["must"] = 
[clause.search_operator() for clause in self.must] + params["must"] = [clause.search_operator(compiler, connection) for clause in self.must] if self.must_not: - params["mustNot"] = [clause.search_operator() for clause in self.must_not] + params["mustNot"] = [ + clause.search_operator(compiler, connection) for clause in self.must_not + ] if self.should: - params["should"] = [clause.search_operator() for clause in self.should] + params["should"] = [ + clause.search_operator(compiler, connection) for clause in self.should + ] if self.filter: - params["filter"] = [clause.search_operator() for clause in self.filter] + params["filter"] = [ + clause.search_operator(compiler, connection) for clause in self.filter + ] if self.minimum_should_match is not None: params["minimumShouldMatch"] = self.minimum_should_match return {"compound": params} @@ -650,15 +658,15 @@ def _get_query_index(self, fields, compiler): def as_mql(self, compiler, connection): params = { - "index": self._get_query_index(self.get_search_fields(), compiler), + "index": self._get_query_index(self.get_search_fields(compiler, connection), compiler), "path": self.path.as_mql(compiler, connection, as_path=True), - "queryVector": self.query_vector.as_mql(compiler, connection), - "limit": self.limit.as_mql(compiler, connection), + "queryVector": self.query_vector.value, + "limit": self.limit.value, } if self.num_candidates is not None: - params["numCandidates"] = self.num_candidates.as_mql(compiler, connection) + params["numCandidates"] = self.num_candidates.value if self.exact is not None: - params["exact"] = self.exact.as_mql(compiler, connection) + params["exact"] = self.exact.value if self.filter is not None: params["filter"] = self.filter.as_mql(compiler, connection) return {"$vectorSearch": params} From e6660f9e76eb9ffc151e6c22ee4646269a9e7e01 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sat, 12 Jul 2025 14:01:32 -0300 Subject: [PATCH 26/37] Add vibe docstrings to MongoDB Atlas search expressions. 
--- django_mongodb_backend/expressions/search.py | 329 +++++++++++++++++++ 1 file changed, 329 insertions(+) diff --git a/django_mongodb_backend/expressions/search.py b/django_mongodb_backend/expressions/search.py index 47c46bd65..f7a70626b 100644 --- a/django_mongodb_backend/expressions/search.py +++ b/django_mongodb_backend/expressions/search.py @@ -62,6 +62,14 @@ def __ror__(self, other): class SearchExpression(SearchCombinable, Expression): + """Base expression node for MongoDB Atlas **$search** stages. + + This class bridges Django's ``Expression`` API with the MongoDB Atlas + Search engine. Subclasses produce the operator document placed under + **$search** and expose the stage to queryset methods such as + ``annotate()``, ``filter()``, or ``order_by()``. + """ + output_field = FloatField() def __str__(self): @@ -113,6 +121,29 @@ def as_mql(self, compiler, connection): class SearchAutocomplete(SearchExpression): + """ + Atlas Search expression that matches input using the **autocomplete** operator. + + This expression enables autocomplete behavior by querying against a field + indexed as `"type": "autocomplete"` in MongoDB Atlas. It can be used in + `filter()`, `annotate()` or any context that accepts a Django expression. + + Example: + SearchAutocomplete("title", "harry", fuzzy={"maxEdits": 1}) + + Args: + path: The document path to search (as string or expression). + query: The input string to autocomplete. + fuzzy: Optional dictionary of fuzzy matching parameters. + token_order: Optional value for `"tokenOrder"`; controls sequential vs. + any-order token matching. + score: Optional expression to adjust score relevance (e.g., `{"boost": {"value": 5}}`). + + Notes: + * Requires an Atlas Search index with `autocomplete` mappings. + * The operator is injected under the `$search` stage in the aggregation pipeline. 
+ """ + def __init__(self, path, query, fuzzy=None, token_order=None, score=None): self.path = self.cast_as_field(path) self.query = self.cast_as_value(query) @@ -145,6 +176,26 @@ def search_operator(self, compiler, connection): class SearchEquals(SearchExpression): + """ + Atlas Search expression that matches documents with a field equal to the given value. + + This expression uses the **equals** operator to perform exact matches + on fields indexed in a MongoDB Atlas Search index. + + Example: + SearchEquals("category", "fiction") + + Args: + path: The document path to compare (as string or expression). + value: The exact value to match against. + score: Optional expression to modify the relevance score. + + Notes: + * The field must be indexed with a supported type for `equals`. + * Supports numeric, string, boolean, and date values. + * Score boosting can be applied using the `score` parameter. + """ + def __init__(self, path, value, score=None): self.path = self.cast_as_field(path) self.value = self.cast_as_value(value) @@ -171,6 +222,25 @@ def search_operator(self, compiler, connection): class SearchExists(SearchExpression): + """ + Atlas Search expression that matches documents where a field exists. + + This expression uses the **exists** operator to check whether a given + path is present in the document. Useful for filtering documents that + include (or exclude) optional fields. + + Example: + SearchExists("metadata__author") + + Args: + path: The document path to check (as string or expression). + score: Optional expression to modify the relevance score. + + Notes: + * The target field must be mapped in the Atlas Search index. + * This does not test for null—only for presence. 
+ """ + def __init__(self, path, score=None): self.path = self.cast_as_field(path) self.score = score @@ -221,6 +291,28 @@ def search_operator(self, compiler, connection): class SearchPhrase(SearchExpression): + """ + Atlas Search expression that matches a phrase in the specified field. + + This expression uses the **phrase** operator to search for exact or near-exact + sequences of terms. It supports optional slop (word distance) and synonym sets. + + Example: + SearchPhrase("description__text", "climate change", slop=2) + + Args: + path: The document path to search (as string or expression). + query: The phrase to match as a single string or list of terms. + slop: Optional maximum word distance allowed between phrase terms. + synonyms: Optional name of a synonym mapping defined in the Atlas index. + score: Optional expression to modify the relevance score. + + Notes: + * The field must be mapped as `"type": "string"` with appropriate analyzers. + * Slop allows flexibility in word positioning, like `"quick brown fox"` + matching `"quick fox"` if `slop=1`. + """ + def __init__(self, path, query, slop=None, synonyms=None, score=None): self.path = self.cast_as_field(path) self.query = self.cast_as_value(query) @@ -253,6 +345,26 @@ def search_operator(self, compiler, connection): class SearchQueryString(SearchExpression): + """ + Atlas Search expression that matches using a Lucene-style query string. + + This expression uses the **queryString** operator to parse and execute + full-text queries written in a simplified Lucene syntax. It supports + advanced constructs like boolean operators, wildcards, and field-specific terms. + + Example: + SearchQueryString("content__text", "django AND (search OR query)") + + Args: + path: The document path to query (as string or expression). + query: The Lucene-style query string. + score: Optional expression to modify the relevance score. + + Notes: + * The query string syntax must conform to Atlas Search rules. 
+ * This operator is powerful but can be harder to validate or sanitize. + """ + def __init__(self, path, query, score=None): self.path = self.cast_as_field(path) self.query = self.cast_as_value(query) @@ -279,6 +391,28 @@ def search_operator(self, compiler, connection): class SearchRange(SearchExpression): + """ + Atlas Search expression that filters documents within a range of values. + + This expression uses the **range** operator to match numeric, date, or + other comparable fields based on upper and/or lower bounds. + + Example: + SearchRange("published__year", gte=2000, lt=2020) + + Args: + path: The document path to filter (as string or expression). + lt: Optional exclusive upper bound (`<`). + lte: Optional inclusive upper bound (`<=`). + gt: Optional exclusive lower bound (`>`). + gte: Optional inclusive lower bound (`>=`). + score: Optional expression to modify the relevance score. + + Notes: + * At least one of `lt`, `lte`, `gt`, or `gte` must be provided. + * The field must be mapped in the Atlas Search index as a comparable type. + """ + def __init__(self, path, lt=None, lte=None, gt=None, gte=None, score=None): self.path = self.cast_as_field(path) self.lt = self.cast_as_value(lt) @@ -315,6 +449,27 @@ def search_operator(self, compiler, connection): class SearchRegex(SearchExpression): + """ + Atlas Search expression that matches strings using a regular expression. + + This expression uses the **regex** operator to apply a regular expression + against the contents of a specified field. + + Example: + SearchRegex("username", r"^admin_") + + Args: + path: The document path to match (as string or expression). + query: The regular expression pattern to apply. + allow_analyzed_field: Whether to allow matching against analyzed fields (default is False). + score: Optional expression to modify the relevance score. + + Notes: + * Regular expressions must follow JavaScript regex syntax. 
+ * By default, the field must be mapped as `"analyzer": "keyword"` + unless `allow_analyzed_field=True`. + """ + def __init__(self, path, query, allow_analyzed_field=None, score=None): self.path = self.cast_as_field(path) self.query = self.cast_as_value(query) @@ -344,6 +499,28 @@ def search_operator(self, compiler, connection): class SearchText(SearchExpression): + """ + Atlas Search expression that performs full-text search using the **text** operator. + + This expression matches terms in a specified field with options for + fuzzy matching, match criteria, and synonyms. + + Example: + SearchText("description__content", "mongodb", fuzzy={"maxEdits": 1}, match_criteria="all") + + Args: + path: The document path to search (as string or expression). + query: The search term or phrase. + fuzzy: Optional dictionary to configure fuzzy matching parameters. + match_criteria: Optional criteria for term matching (e.g., "all" or "any"). + synonyms: Optional name of a synonym mapping defined in the Atlas index. + score: Optional expression to adjust relevance scoring. + + Notes: + * The target field must be indexed for full-text search in Atlas. + * Fuzzy matching helps match terms with minor typos or variations. + """ + def __init__(self, path, query, fuzzy=None, match_criteria=None, synonyms=None, score=None): self.path = self.cast_as_field(path) self.query = self.cast_as_value(query) @@ -379,6 +556,28 @@ def search_operator(self, compiler, connection): class SearchWildcard(SearchExpression): + """ + Atlas Search expression that matches strings using wildcard patterns. + + This expression uses the **wildcard** operator to search for terms + matching a pattern with `*` and `?` wildcards. + + Example: + SearchWildcard("filename", "report_202?_final*") + + Args: + path: The document path to search (as string or expression). + query: The wildcard pattern to match. + allow_analyzed_field: Whether to allow matching against analyzed fields (default is False). 
+ score: Optional expression to modify the relevance score. + + Notes: + * Wildcard patterns follow standard syntax, where `*` matches any sequence of characters + and `?` matches a single character. + * By default, the field should be keyword or unanalyzed + unless `allow_analyzed_field=True`. + """ + def __init__(self, path, query, allow_analyzed_field=None, score=None): self.path = self.cast_as_field(path) self.query = self.cast_as_value(query) @@ -408,6 +607,26 @@ def search_operator(self, compiler, connection): class SearchGeoShape(SearchExpression): + """ + Atlas Search expression that filters documents by spatial relationship with a geometry. + + This expression uses the **geoShape** operator to match documents where + a geo field relates to a specified geometry by a spatial relation. + + Example: + SearchGeoShape("location", "within", {"type": "Polygon", "coordinates": [...]}) + + Args: + path: The document path to the geo field (as string or expression). + relation: The spatial relation to test (e.g., "within", "intersects", "disjoint"). + geometry: The GeoJSON geometry to compare against. + score: Optional expression to modify the relevance score. + + Notes: + * The field must be indexed as a geo shape type in Atlas Search. + * Geometry must conform to GeoJSON specification. + """ + def __init__(self, path, relation, geometry, score=None): self.path = self.cast_as_field(path) self.relation = self.cast_as_value(relation) @@ -436,6 +655,27 @@ def search_operator(self, compiler, connection): class SearchGeoWithin(SearchExpression): + """ + Atlas Search expression that filters documents with geo fields + contained within a specified shape. + + This expression uses the **geoWithin** operator to match documents where + the geo field lies entirely within the given geometry. + + Example: + SearchGeoWithin("location", "Polygon", {"type": "Polygon", "coordinates": [...]}) + + Args: + path: The document path to the geo field (as string or expression). 
+ kind: The GeoJSON geometry type (e.g., "Polygon", "MultiPolygon"). + geo_object: The GeoJSON geometry defining the boundary. + score: Optional expression to adjust the relevance score. + + Notes: + * The geo field must be indexed appropriately in the Atlas Search index. + * The geometry must follow GeoJSON format. + """ + def __init__(self, path, kind, geo_object, score=None): self.path = self.cast_as_field(path) self.kind = self.cast_as_value(kind) @@ -463,6 +703,24 @@ def search_operator(self, compiler, connection): class SearchMoreLikeThis(SearchExpression): + """ + Atlas Search expression that finds documents similar to given examples. + + This expression uses the **moreLikeThis** operator to search for documents + that resemble the specified sample documents. + + Example: + SearchMoreLikeThis([{"_id": ObjectId("...")}, {"title": "Example"}]) + + Args: + documents: A list of example documents or expressions to find similar documents. + score: Optional expression to modify the relevance scoring. + + Notes: + * The documents should be representative examples to base similarity on. + * Supports various field types depending on the Atlas Search configuration. + """ + def __init__(self, documents, score=None): self.documents = self.cast_as_value(documents) self.score = score @@ -490,6 +748,34 @@ def get_search_fields(self, compiler, connection): class CompoundExpression(SearchExpression): + """ + Compound expression that combines multiple search clauses using boolean logic. + + This expression corresponds to the **compound** operator in MongoDB Atlas Search, + allowing fine-grained control by combining multiple sub-expressions with + `must`, `must_not`, `should`, and `filter` clauses. + + Example: + CompoundExpression( + must=[expr1, expr2], + must_not=[expr3], + should=[expr4], + minimum_should_match=1 + ) + + Args: + must: List of expressions that **must** match. + must_not: List of expressions that **must not** match. 
+ should: List of expressions that **should** match (optional relevance boost). + filter: List of expressions to filter results without affecting relevance. + score: Optional expression to adjust scoring. + minimum_should_match: Minimum number of `should` clauses that must match. + + Notes: + * This is the most flexible way to build complex Atlas Search queries. + * Supports nesting of expressions to any depth. + """ + def __init__( self, must=None, @@ -556,6 +842,26 @@ def negate(self): class CombinedSearchExpression(SearchExpression): + """ + Combines two search expressions with a logical operator. + + This expression allows combining two Atlas Search expressions + (left-hand side and right-hand side) using a boolean operator + such as `and`, `or`, or `not`. + + Example: + CombinedSearchExpression(expr1, "and", expr2) + + Args: + lhs: The left-hand search expression. + operator: The boolean operator as a string (e.g., "and", "or", "not"). + rhs: The right-hand search expression. + + Notes: + * The operator must be supported by MongoDB Atlas Search boolean logic. + * This class enables building complex nested search queries. + """ + def __init__(self, lhs, operator, rhs): self.lhs = lhs self.operator = operator @@ -591,6 +897,29 @@ def as_mql(self, compiler, connection): class SearchVector(SearchExpression): + """ + Atlas Search expression that performs vector similarity search on embedded vectors. + + This expression uses the **knnBeta** operator to find documents whose vector + embeddings are most similar to a given query vector. + + Example: + SearchVector("embedding", [0.1, 0.2, 0.3], limit=10, num_candidates=100) + + Args: + path: The document path to the vector field (as string or expression). + query_vector: The query vector to compare against. + limit: Maximum number of matching documents to return. + num_candidates: Optional number of candidates to consider during search. + exact: Optional flag to enforce exact matching. 
+ filter: Optional filter expression to narrow candidate documents. + + Notes: + * The vector field must be indexed as a vector type in Atlas Search. + * Parameters like `num_candidates` and `exact` control search + performance and accuracy trade-offs. + """ + def __init__( self, path, From 318b01055201abca1e59f1bf3607a8b3822ff12a Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sat, 12 Jul 2025 14:20:11 -0300 Subject: [PATCH 27/37] Fix invalid operation. --- django_mongodb_backend/creation.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/django_mongodb_backend/creation.py b/django_mongodb_backend/creation.py index 4057e786f..b787ad06c 100644 --- a/django_mongodb_backend/creation.py +++ b/django_mongodb_backend/creation.py @@ -22,9 +22,10 @@ def _destroy_test_db(self, test_database_name, verbosity): for collection in self.connection.introspection.table_names(): if not collection.startswith("system."): - db_collection = self.connection.database.get_collection(collection) - for search_indexes in db_collection.list_search_indexes(): - db_collection.drop_search_index(search_indexes["name"]) + if self.connection.features.supports_atlas_search: + db_collection = self.connection.database.get_collection(collection) + for search_indexes in db_collection.list_search_indexes(): + db_collection.drop_search_index(search_indexes["name"]) self.connection.database.drop_collection(collection) def create_test_db(self, *args, **kwargs): From 62fb3eaeee451bbdb8d3cd514853337550cd084e Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sat, 12 Jul 2025 14:20:32 -0300 Subject: [PATCH 28/37] Refactor utils function --- django_mongodb_backend/expressions/search.py | 113 ++++++++++--------- 1 file changed, 58 insertions(+), 55 deletions(-) diff --git a/django_mongodb_backend/expressions/search.py b/django_mongodb_backend/expressions/search.py index f7a70626b..f826a17c6 100644 --- a/django_mongodb_backend/expressions/search.py +++ 
b/django_mongodb_backend/expressions/search.py @@ -3,6 +3,16 @@ from django.db.models.expressions import F, Value +def cast_as_value(value): + if value is None: + return None + return Value(value) if not hasattr(value, "resolve_expression") else value + + +def cast_as_field(path): + return F(path) if isinstance(path, str) else path + + class Operator: AND = "AND" OR = "OR" @@ -87,16 +97,6 @@ def as_sql(self, compiler, connection): def get_source_expressions(self): return [] - @staticmethod - def cast_as_value(value): - if value is None: - return None - return Value(value) if not hasattr(value, "resolve_expression") else value - - @staticmethod - def cast_as_field(path): - return F(path) if isinstance(path, str) else path - def _get_indexed_fields(self, mappings): for field, definition in mappings.get("fields", {}).items(): yield field @@ -145,10 +145,10 @@ class SearchAutocomplete(SearchExpression): """ def __init__(self, path, query, fuzzy=None, token_order=None, score=None): - self.path = self.cast_as_field(path) - self.query = self.cast_as_value(query) - self.fuzzy = self.cast_as_value(fuzzy) - self.token_order = self.cast_as_value(token_order) + self.path = cast_as_field(path) + self.query = cast_as_value(query) + self.fuzzy = cast_as_value(fuzzy) + self.token_order = cast_as_value(token_order) self.score = score super().__init__() @@ -197,8 +197,8 @@ class SearchEquals(SearchExpression): """ def __init__(self, path, value, score=None): - self.path = self.cast_as_field(path) - self.value = self.cast_as_value(value) + self.path = cast_as_field(path) + self.value = cast_as_value(value) self.score = score super().__init__() @@ -242,7 +242,7 @@ class SearchExists(SearchExpression): """ def __init__(self, path, score=None): - self.path = self.cast_as_field(path) + self.path = cast_as_field(path) self.score = score super().__init__() @@ -266,8 +266,8 @@ def search_operator(self, compiler, connection): class SearchIn(SearchExpression): def __init__(self, path, value, 
score=None): - self.path = self.cast_as_field(path) - self.value = self.cast_as_value(value) + self.path = cast_as_field(path) + self.value = cast_as_value(value) self.score = score super().__init__() @@ -314,10 +314,10 @@ class SearchPhrase(SearchExpression): """ def __init__(self, path, query, slop=None, synonyms=None, score=None): - self.path = self.cast_as_field(path) - self.query = self.cast_as_value(query) - self.slop = self.cast_as_value(slop) - self.synonyms = self.cast_as_value(synonyms) + self.path = cast_as_field(path) + self.query = cast_as_value(query) + self.slop = cast_as_value(slop) + self.synonyms = cast_as_value(synonyms) self.score = score super().__init__() @@ -366,8 +366,8 @@ class SearchQueryString(SearchExpression): """ def __init__(self, path, query, score=None): - self.path = self.cast_as_field(path) - self.query = self.cast_as_value(query) + self.path = cast_as_field(path) + self.query = cast_as_value(query) self.score = score super().__init__() @@ -414,11 +414,11 @@ class SearchRange(SearchExpression): """ def __init__(self, path, lt=None, lte=None, gt=None, gte=None, score=None): - self.path = self.cast_as_field(path) - self.lt = self.cast_as_value(lt) - self.lte = self.cast_as_value(lte) - self.gt = self.cast_as_value(gt) - self.gte = self.cast_as_value(gte) + self.path = cast_as_field(path) + self.lt = cast_as_value(lt) + self.lte = cast_as_value(lte) + self.gt = cast_as_value(gt) + self.gte = cast_as_value(gte) self.score = score super().__init__() @@ -471,9 +471,9 @@ class SearchRegex(SearchExpression): """ def __init__(self, path, query, allow_analyzed_field=None, score=None): - self.path = self.cast_as_field(path) - self.query = self.cast_as_value(query) - self.allow_analyzed_field = self.cast_as_value(allow_analyzed_field) + self.path = cast_as_field(path) + self.query = cast_as_value(query) + self.allow_analyzed_field = cast_as_value(allow_analyzed_field) self.score = score super().__init__() @@ -522,11 +522,11 @@ class 
SearchText(SearchExpression): """ def __init__(self, path, query, fuzzy=None, match_criteria=None, synonyms=None, score=None): - self.path = self.cast_as_field(path) - self.query = self.cast_as_value(query) - self.fuzzy = self.cast_as_value(fuzzy) - self.match_criteria = self.cast_as_value(match_criteria) - self.synonyms = self.cast_as_value(synonyms) + self.path = cast_as_field(path) + self.query = cast_as_value(query) + self.fuzzy = cast_as_value(fuzzy) + self.match_criteria = cast_as_value(match_criteria) + self.synonyms = cast_as_value(synonyms) self.score = score super().__init__() @@ -579,9 +579,9 @@ class SearchWildcard(SearchExpression): """ def __init__(self, path, query, allow_analyzed_field=None, score=None): - self.path = self.cast_as_field(path) - self.query = self.cast_as_value(query) - self.allow_analyzed_field = self.cast_as_value(allow_analyzed_field) + self.path = cast_as_field(path) + self.query = cast_as_value(query) + self.allow_analyzed_field = cast_as_value(allow_analyzed_field) self.score = score super().__init__() @@ -628,9 +628,9 @@ class SearchGeoShape(SearchExpression): """ def __init__(self, path, relation, geometry, score=None): - self.path = self.cast_as_field(path) - self.relation = self.cast_as_value(relation) - self.geometry = self.cast_as_value(geometry) + self.path = cast_as_field(path) + self.relation = cast_as_value(relation) + self.geometry = cast_as_value(geometry) self.score = score super().__init__() @@ -677,9 +677,9 @@ class SearchGeoWithin(SearchExpression): """ def __init__(self, path, kind, geo_object, score=None): - self.path = self.cast_as_field(path) - self.kind = self.cast_as_value(kind) - self.geo_object = self.cast_as_value(geo_object) + self.path = cast_as_field(path) + self.kind = cast_as_value(kind) + self.geo_object = cast_as_value(geo_object) self.score = score super().__init__() @@ -722,7 +722,7 @@ class SearchMoreLikeThis(SearchExpression): """ def __init__(self, documents, score=None): - self.documents = 
self.cast_as_value(documents) + self.documents = cast_as_value(documents) self.score = score super().__init__() @@ -929,12 +929,12 @@ def __init__( exact=None, filter=None, ): - self.path = self.cast_as_field(path) - self.query_vector = self.cast_as_value(query_vector) - self.limit = self.cast_as_value(limit) - self.num_candidates = self.cast_as_value(num_candidates) - self.exact = self.cast_as_value(exact) - self.filter = self.cast_as_value(filter) + self.path = cast_as_field(path) + self.query_vector = cast_as_value(query_vector) + self.limit = cast_as_value(limit) + self.num_candidates = cast_as_value(num_candidates) + self.exact = cast_as_value(exact) + self.filter = cast_as_value(filter) super().__init__() def __invert__(self): @@ -1001,8 +1001,11 @@ def as_mql(self, compiler, connection): return {"$vectorSearch": params} -class SearchScoreOption: +class SearchScoreOption(Expression): """Class to mutate scoring on a search operation""" def __init__(self, definitions=None): self.definitions = definitions + + def as_mql(self, compiler, connection): + return self.definitions From fd0877248a806145258f0766ca8af48c85e7a6ee Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sat, 12 Jul 2025 14:23:56 -0300 Subject: [PATCH 29/37] Add skip flag. 
--- tests/queries_/test_search.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/tests/queries_/test_search.py b/tests/queries_/test_search.py index 03436d1c7..24ed57105 100644 --- a/tests/queries_/test_search.py +++ b/tests/queries_/test_search.py @@ -4,7 +4,7 @@ from django.db import connection from django.db.utils import DatabaseError -from django.test import TransactionTestCase +from django.test import TransactionTestCase, skipUnlessDBFeature from pymongo.operations import SearchIndexModel from django_mongodb_backend.expressions.search import ( @@ -65,6 +65,7 @@ def _inner_wait_loop(predicate: Callable): return _inner_wait_loop +@skipUnlessDBFeature("supports_atlas_search") class SearchUtilsMixin(TransactionTestCase): available_apps = [] @@ -87,6 +88,7 @@ def _tear_down(self, model): wait_for_assertion = _wait_for_assertion(timeout=3) +@skipUnlessDBFeature("supports_atlas_search") class SearchEqualsTest(SearchUtilsMixin): def setUp(self): self.create_search_index( @@ -106,6 +108,7 @@ def test_search_equals(self): self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) +@skipUnlessDBFeature("supports_atlas_search") class SearchAutocompleteTest(SearchUtilsMixin): def setUp(self): self.create_search_index( @@ -170,6 +173,7 @@ def test_search_autocomplete_embedded_model(self): self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) +@skipUnlessDBFeature("supports_atlas_search") class SearchExistsTest(SearchUtilsMixin): def setUp(self): self.create_search_index( @@ -184,6 +188,7 @@ def test_search_exists(self): self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) +@skipUnlessDBFeature("supports_atlas_search") class SearchInTest(SearchUtilsMixin): def setUp(self): self.create_search_index( @@ -203,6 +208,7 @@ def test_search_in(self): self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) 
+@skipUnlessDBFeature("supports_atlas_search") class SearchPhraseTest(SearchUtilsMixin): def setUp(self): self.create_search_index( @@ -224,6 +230,7 @@ def test_search_phrase(self): self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) +@skipUnlessDBFeature("supports_atlas_search") class SearchRangeTest(SearchUtilsMixin): def setUp(self): self.create_search_index( @@ -243,6 +250,7 @@ def test_search_range(self): self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.number20])) +@skipUnlessDBFeature("supports_atlas_search") class SearchRegexTest(SearchUtilsMixin): def setUp(self): self.create_search_index( @@ -269,6 +277,7 @@ def test_search_regex(self): self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) +@skipUnlessDBFeature("supports_atlas_search") class SearchTextTest(SearchUtilsMixin): def setUp(self): self.create_search_index( @@ -298,6 +307,7 @@ def test_search_text_with_fuzzy_and_criteria(self): self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) +@skipUnlessDBFeature("supports_atlas_search") class SearchWildcardTest(SearchUtilsMixin): def setUp(self): self.create_search_index( @@ -322,6 +332,7 @@ def test_search_wildcard(self): self.wait_for_assertion(lambda: self.assertCountEqual([self.article], qs)) +@skipUnlessDBFeature("supports_atlas_search") class SearchGeoShapeTest(SearchUtilsMixin): def setUp(self): self.create_search_index( @@ -356,6 +367,7 @@ def test_search_geo_shape(self): self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) +@skipUnlessDBFeature("supports_atlas_search") class SearchGeoWithinTest(SearchUtilsMixin): def setUp(self): self.create_search_index( @@ -389,6 +401,7 @@ def test_search_geo_within(self): self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) +@skipUnlessDBFeature("supports_atlas_search") @unittest.expectedFailure class SearchMoreLikeThisTest(SearchUtilsMixin): def 
setUp(self): @@ -436,6 +449,7 @@ def test_search_more_like_this(self): ) +@skipUnlessDBFeature("supports_atlas_search") class CompoundSearchTest(SearchUtilsMixin): def setUp(self): self.create_search_index( @@ -505,6 +519,7 @@ def test_compound_operations(self): ) +@skipUnlessDBFeature("supports_atlas_search") class SearchVectorTest(SearchUtilsMixin): def setUp(self): self.create_search_index( From bbfd2b687fc94754f45a5fdb1c3b27ce91ff89f5 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sat, 12 Jul 2025 14:25:47 -0300 Subject: [PATCH 30/37] Edits --- django_mongodb_backend/expressions/builtins.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index 6f1289c5f..da95d5fe2 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -74,10 +74,6 @@ def col(self, compiler, connection, as_path=False): # noqa: ARG001 return f"{prefix}{self.target.column}" if as_path else f"${prefix}{self.target.column}" -def col_as_path(self, compiler, connection): - return col(self, compiler, connection).lstrip("$") - - def col_pairs(self, compiler, connection): cols = self.get_cols() if len(cols) > 1: From 4f99a4d7317bc836fc810d6d02f4908a86b1d9ad Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sat, 12 Jul 2025 20:19:21 -0300 Subject: [PATCH 31/37] Add search text lookup. 
--- django_mongodb_backend/compiler.py | 14 ++++++++++ django_mongodb_backend/expressions/search.py | 29 +++++++++++++++++++- tests/queries_/test_search.py | 4 +++ 3 files changed, 46 insertions(+), 1 deletion(-) diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index a6b373afd..abad851a9 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -128,6 +128,9 @@ def _prepare_search_query_for_aggregation_pipeline(self, order_by): self._prepare_search_expressions_for_pipeline( self.having, annotation_group_idx, replacements ) + self._prepare_search_expressions_for_pipeline( + self.get_where(), annotation_group_idx, replacements + ) return replacements def _prepare_annotations_for_aggregation_pipeline(self, order_by): @@ -291,6 +294,8 @@ def pre_sql_setup(self, with_col_aliases=False): for target, expr in self.query.annotation_select.items() } self.order_by_objs = [expr.replace_expressions(all_replacements) for expr, _ in order_by] + where_ = self.get_where().replace_expressions(all_replacements) + self.set_where(where_) return extra_select, order_by, group_by def execute_sql( @@ -692,6 +697,9 @@ def _get_ordering(self): def get_where(self): return getattr(self, "where", self.query.where) + def set_where(self, value): + self.where = value + def explain_query(self): # Validate format (none supported) and options. 
options = self.connection.ops.explain_query_prefix( @@ -778,6 +786,9 @@ def check_query(self): def get_where(self): return self.query.where + def set_where(self, value): + self.query.where = value + @cached_property def collection_name(self): return self.query.base_table @@ -849,6 +860,9 @@ def check_query(self): def get_where(self): return self.query.where + def set_where(self, value): + self.query.where = value + @cached_property def collection_name(self): return self.query.base_table diff --git a/django_mongodb_backend/expressions/search.py b/django_mongodb_backend/expressions/search.py index f826a17c6..1c7f8cdba 100644 --- a/django_mongodb_backend/expressions/search.py +++ b/django_mongodb_backend/expressions/search.py @@ -1,6 +1,9 @@ from django.db import NotSupportedError -from django.db.models import Expression, FloatField +from django.db.models import CharField, Expression, FloatField, TextField from django.db.models.expressions import F, Value +from django.db.models.lookups import Lookup + +from ..query_utils import process_lhs, process_rhs def cast_as_value(value): @@ -1009,3 +1012,27 @@ def __init__(self, definitions=None): def as_mql(self, compiler, connection): return self.definitions + + +class SearchTextLookup(Lookup): + lookup_name = "search" + + def __init__(self, lhs, rhs): + super().__init__(lhs, rhs) + self.lhs = SearchText(self.lhs, self.rhs) + self.rhs = Value(0) + + def __str__(self): + return f"SearchText({self.lhs}, {self.rhs})" + + def __repr__(self): + return f"SearchText({self.lhs}, {self.rhs})" + + def as_mql(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection) + value = process_rhs(self, compiler, connection) + return {"$gte": [lhs_mql, value]} + + +CharField.register_lookup(SearchTextLookup) +TextField.register_lookup(SearchTextLookup) diff --git a/tests/queries_/test_search.py b/tests/queries_/test_search.py index 24ed57105..eed4fe2f4 100644 --- a/tests/queries_/test_search.py +++ 
b/tests/queries_/test_search.py @@ -298,6 +298,10 @@ def test_search_text(self): qs = Article.objects.annotate(score=SearchText(path="body", query="lazy")) self.wait_for_assertion(lambda: self.assertCountEqual([self.article], qs)) + def test_search_lookup(self): + qs = Article.objects.filter(body__search="lazy") + self.wait_for_assertion(lambda: self.assertCountEqual([self.article], qs)) + def test_search_text_with_fuzzy_and_criteria(self): qs = Article.objects.annotate( score=SearchText( From d4cbe18e5a9efd6d9c25740fd31a515369991a2a Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sat, 12 Jul 2025 20:21:50 -0300 Subject: [PATCH 32/37] Remove unused method. --- django_mongodb_backend/compiler.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index abad851a9..1ddd2658a 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -209,9 +209,6 @@ def _get_group_id_expressions(self, order_by): ids = self.get_project_fields(tuple(columns), force_expression=True) return ids, replacements - def _build_search_pipeline(self, search_queries): - pass - def _build_aggregation_pipeline(self, ids, group): """Build the aggregation pipeline for grouping.""" pipeline = [] From a467a573c50198cbe0e8029da483ded810b9009b Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sat, 12 Jul 2025 20:26:53 -0300 Subject: [PATCH 33/37] Edits. 
--- django_mongodb_backend/query_utils.py | 2 +- tests/queries_/test_search.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/django_mongodb_backend/query_utils.py b/django_mongodb_backend/query_utils.py index c03a0f7ab..0bb292995 100644 --- a/django_mongodb_backend/query_utils.py +++ b/django_mongodb_backend/query_utils.py @@ -4,7 +4,7 @@ def is_direct_value(node): - return not hasattr(node, "as_sql") and not hasattr(node, "as_mql") + return not hasattr(node, "as_sql") def process_lhs(node, compiler, connection): diff --git a/tests/queries_/test_search.py b/tests/queries_/test_search.py index eed4fe2f4..f0d113ea2 100644 --- a/tests/queries_/test_search.py +++ b/tests/queries_/test_search.py @@ -296,11 +296,11 @@ def tearDown(self): def test_search_text(self): qs = Article.objects.annotate(score=SearchText(path="body", query="lazy")) - self.wait_for_assertion(lambda: self.assertCountEqual([self.article], qs)) + self.wait_for_assertion(lambda: self.assertCountEqual([self.article], qs.all())) def test_search_lookup(self): qs = Article.objects.filter(body__search="lazy") - self.wait_for_assertion(lambda: self.assertCountEqual([self.article], qs)) + self.wait_for_assertion(lambda: self.assertCountEqual([self.article], qs.all())) def test_search_text_with_fuzzy_and_criteria(self): qs = Article.objects.annotate( From a17102c5667fd17dc3943eb5d098e06dc2190b25 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sun, 13 Jul 2025 05:45:32 -0300 Subject: [PATCH 34/37] Fix replacements. 
--- django_mongodb_backend/compiler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 1ddd2658a..3e1ccc7fc 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -291,8 +291,9 @@ def pre_sql_setup(self, with_col_aliases=False): for target, expr in self.query.annotation_select.items() } self.order_by_objs = [expr.replace_expressions(all_replacements) for expr, _ in order_by] - where_ = self.get_where().replace_expressions(all_replacements) - self.set_where(where_) + if (where := self.get_where()) and search_replacements: + where = where.replace_expressions(search_replacements) + self.set_where(where) return extra_select, order_by, group_by def execute_sql( From 8764f42135329860be34890de03f720db14227db Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sun, 13 Jul 2025 21:26:42 -0300 Subject: [PATCH 35/37] Edits. --- django_mongodb_backend/compiler.py | 25 ++++- django_mongodb_backend/expressions/search.py | 99 ++++++++------------ tests/queries_/test_search.py | 66 ++++++++++++- 3 files changed, 127 insertions(+), 63 deletions(-) diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 3e1ccc7fc..2c3142abc 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -241,7 +241,28 @@ def _compound_searches_queries(self, search_replacements): if not search_replacements: return [] if len(search_replacements) > 1: - raise ValueError("Cannot perform more than one search operation.") + has_search = any(not isinstance(search, SearchVector) for search in search_replacements) + has_vector_search = any( + isinstance(search, SearchVector) for search in search_replacements + ) + if has_search and has_vector_search: + raise ValueError( + "Cannot combine a `$vectorSearch` with a `$search` operator. 
" + "If you need to combine them, consider restructuring your query logic or " + "running them as separate queries." + ) + if not has_search: + raise ValueError( + "Cannot combine two `$vectorSearch` operator. " + "If you need to combine them, consider restructuring your query logic or " + "running them as separate queries." + ) + raise ValueError( + "Only one $search operation is allowed per query. " + f"Received {len(search_replacements)} search expressions. " + "To combine multiple search expressions, use either a CompoundExpression for " + "fine-grained control or CombinedSearchExpression for simple logical combinations." + ) pipeline = [] for search, result_col in search_replacements.items(): score_function = ( @@ -252,7 +273,7 @@ def _compound_searches_queries(self, search_replacements): search.as_mql(self, self.connection), { "$addFields": { - result_col.as_mql(self, self.connection).removeprefix("$"): { + result_col.as_mql(self, self.connection, as_path=True): { "$meta": score_function } } diff --git a/django_mongodb_backend/expressions/search.py b/django_mongodb_backend/expressions/search.py index 1c7f8cdba..e983443d8 100644 --- a/django_mongodb_backend/expressions/search.py +++ b/django_mongodb_backend/expressions/search.py @@ -71,7 +71,7 @@ def __or__(self, other): return self._combine(other, Operator(Operator.OR)) def __ror__(self, other): - return self._combine(self, Operator(Operator.OR), other) + return self._combine(other, Operator(Operator.OR)) class SearchExpression(SearchCombinable, Expression): @@ -101,10 +101,14 @@ def get_source_expressions(self): return [] def _get_indexed_fields(self, mappings): - for field, definition in mappings.get("fields", {}).items(): - yield field - for path in self._get_indexed_fields(definition): - yield f"{field}.{path}" + if isinstance(mappings, list): + for definition in mappings: + yield from self._get_indexed_fields(definition) + else: + for field, definition in mappings.get("fields", {}).items(): + yield field 
+ for path in self._get_indexed_fields(definition): + yield f"{field}.{path}" def _get_query_index(self, fields, compiler): fields = set(fields) @@ -142,9 +146,7 @@ class SearchAutocomplete(SearchExpression): any-order token matching. score: Optional expression to adjust score relevance (e.g., `{"boost": {"value": 5}}`). - Notes: - * Requires an Atlas Search index with `autocomplete` mappings. - * The operator is injected under the `$search` stage in the aggregation pipeline. + Reference: https://www.mongodb.com/docs/atlas/atlas-search/autocomplete/ """ def __init__(self, path, query, fuzzy=None, token_order=None, score=None): @@ -193,10 +195,7 @@ class SearchEquals(SearchExpression): value: The exact value to match against. score: Optional expression to modify the relevance score. - Notes: - * The field must be indexed with a supported type for `equals`. - * Supports numeric, string, boolean, and date values. - * Score boosting can be applied using the `score` parameter. + Reference: https://www.mongodb.com/docs/atlas/atlas-search/equals/ """ def __init__(self, path, value, score=None): @@ -239,9 +238,7 @@ class SearchExists(SearchExpression): path: The document path to check (as string or expression). score: Optional expression to modify the relevance score. - Notes: - * The target field must be mapped in the Atlas Search index. - * This does not test for null—only for presence. + Reference: https://www.mongodb.com/docs/atlas/atlas-search/exists/ """ def __init__(self, path, score=None): @@ -268,6 +265,23 @@ def search_operator(self, compiler, connection): class SearchIn(SearchExpression): + """ + Atlas Search expression that matches documents where the field value is in a given list. + + This expression uses the **in** operator to match documents whose field + contains a value from the provided array of values. + + Example: + SearchIn("status", ["pending", "approved", "rejected"]) + + Args: + path: The document path to match against (as string or expression). 
+ value: A list of values to check for membership. + score: Optional expression to adjust the relevance score. + + Reference: https://www.mongodb.com/docs/atlas/atlas-search/in/ + """ + def __init__(self, path, value, score=None): self.path = cast_as_field(path) self.value = cast_as_value(value) @@ -297,7 +311,7 @@ class SearchPhrase(SearchExpression): """ Atlas Search expression that matches a phrase in the specified field. - This expression uses the **phrase** operator to search for exact or near-exact + This expression uses the **phrase** operator to search for exact or near exact sequences of terms. It supports optional slop (word distance) and synonym sets. Example: @@ -310,10 +324,7 @@ class SearchPhrase(SearchExpression): synonyms: Optional name of a synonym mapping defined in the Atlas index. score: Optional expression to modify the relevance score. - Notes: - * The field must be mapped as `"type": "string"` with appropriate analyzers. - * Slop allows flexibility in word positioning, like `"quick brown fox"` - matching `"quick fox"` if `slop=1`. + Reference: https://www.mongodb.com/docs/atlas/atlas-search/phrase/ """ def __init__(self, path, query, slop=None, synonyms=None, score=None): @@ -363,9 +374,7 @@ class SearchQueryString(SearchExpression): query: The Lucene-style query string. score: Optional expression to modify the relevance score. - Notes: - * The query string syntax must conform to Atlas Search rules. - * This operator is powerful but can be harder to validate or sanitize. + Reference: https://www.mongodb.com/docs/atlas/atlas-search/queryString/ """ def __init__(self, path, query, score=None): @@ -411,9 +420,7 @@ class SearchRange(SearchExpression): gte: Optional inclusive lower bound (`>=`). score: Optional expression to modify the relevance score. - Notes: - * At least one of `lt`, `lte`, `gt`, or `gte` must be provided. - * The field must be mapped in the Atlas Search index as a comparable type. 
+ Reference: https://www.mongodb.com/docs/atlas/atlas-search/range/ """ def __init__(self, path, lt=None, lte=None, gt=None, gte=None, score=None): @@ -467,10 +474,7 @@ class SearchRegex(SearchExpression): allow_analyzed_field: Whether to allow matching against analyzed fields (default is False). score: Optional expression to modify the relevance score. - Notes: - * Regular expressions must follow JavaScript regex syntax. - * By default, the field must be mapped as `"analyzer": "keyword"` - unless `allow_analyzed_field=True`. + Reference: https://www.mongodb.com/docs/atlas/atlas-search/regex/ """ def __init__(self, path, query, allow_analyzed_field=None, score=None): @@ -519,9 +523,7 @@ class SearchText(SearchExpression): synonyms: Optional name of a synonym mapping defined in the Atlas index. score: Optional expression to adjust relevance scoring. - Notes: - * The target field must be indexed for full-text search in Atlas. - * Fuzzy matching helps match terms with minor typos or variations. + Reference: https://www.mongodb.com/docs/atlas/atlas-search/text/ """ def __init__(self, path, query, fuzzy=None, match_criteria=None, synonyms=None, score=None): @@ -574,11 +576,7 @@ class SearchWildcard(SearchExpression): allow_analyzed_field: Whether to allow matching against analyzed fields (default is False). score: Optional expression to modify the relevance score. - Notes: - * Wildcard patterns follow standard syntax, where `*` matches any sequence of characters - and `?` matches a single character. - * By default, the field should be keyword or unanalyzed - unless `allow_analyzed_field=True`. + Reference: https://www.mongodb.com/docs/atlas/atlas-search/wildcard/ """ def __init__(self, path, query, allow_analyzed_field=None, score=None): @@ -625,9 +623,7 @@ class SearchGeoShape(SearchExpression): geometry: The GeoJSON geometry to compare against. score: Optional expression to modify the relevance score. 
- Notes: - * The field must be indexed as a geo shape type in Atlas Search. - * Geometry must conform to GeoJSON specification. + Reference: https://www.mongodb.com/docs/atlas/atlas-search/geoShape/ """ def __init__(self, path, relation, geometry, score=None): @@ -674,9 +670,7 @@ class SearchGeoWithin(SearchExpression): geo_object: The GeoJSON geometry defining the boundary. score: Optional expression to adjust the relevance score. - Notes: - * The geo field must be indexed appropriately in the Atlas Search index. - * The geometry must follow GeoJSON format. + Reference: https://www.mongodb.com/docs/atlas/atlas-search/geoWithin/ """ def __init__(self, path, kind, geo_object, score=None): @@ -719,9 +713,7 @@ class SearchMoreLikeThis(SearchExpression): documents: A list of example documents or expressions to find similar documents. score: Optional expression to modify the relevance scoring. - Notes: - * The documents should be representative examples to base similarity on. - * Supports various field types depending on the Atlas Search configuration. + Reference: https://www.mongodb.com/docs/atlas/atlas-search/morelikethis/ """ def __init__(self, documents, score=None): @@ -774,9 +766,7 @@ class CompoundExpression(SearchExpression): score: Optional expression to adjust scoring. minimum_should_match: Minimum number of `should` clauses that must match. - Notes: - * This is the most flexible way to build complex Atlas Search queries. - * Supports nesting of expressions to any depth. + Reference: https://www.mongodb.com/docs/atlas/atlas-search/compound/ """ def __init__( @@ -859,10 +849,6 @@ class CombinedSearchExpression(SearchExpression): lhs: The left-hand search expression. operator: The boolean operator as a string (e.g., "and", "or", "not"). rhs: The right-hand search expression. - - Notes: - * The operator must be supported by MongoDB Atlas Search boolean logic. - * This class enables building complex nested search queries. 
""" def __init__(self, lhs, operator, rhs): @@ -917,10 +903,7 @@ class SearchVector(SearchExpression): exact: Optional flag to enforce exact matching. filter: Optional filter expression to narrow candidate documents. - Notes: - * The vector field must be indexed as a vector type in Atlas Search. - * Parameters like `num_candidates` and `exact` control search - performance and accuracy trade-offs. + Reference: https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/ """ def __init__( diff --git a/tests/queries_/test_search.py b/tests/queries_/test_search.py index f0d113ea2..ac9d175ac 100644 --- a/tests/queries_/test_search.py +++ b/tests/queries_/test_search.py @@ -3,6 +3,7 @@ from time import monotonic, sleep from django.db import connection +from django.db.models import Q from django.db.utils import DatabaseError from django.test import TransactionTestCase, skipUnlessDBFeature from pymongo.operations import SearchIndexModel @@ -463,7 +464,7 @@ def setUp(self): "mappings": { "dynamic": False, "fields": { - "headline": {"type": "token"}, + "headline": [{"type": "token"}, {"type": "string"}], "body": {"type": "string"}, "number": {"type": "number"}, }, @@ -498,7 +499,7 @@ def tearDown(self): self._tear_down(Article) super().tearDown() - def test_compound_expression(self): + def test_expression(self): must_expr = SearchEquals(path="headline", value="space exploration") must_not_expr = SearchPhrase(path="body", query="icy moons") should_expr = SearchPhrase(path="body", query="exoplanets") @@ -513,7 +514,7 @@ def test_compound_expression(self): qs = Article.objects.annotate(score=compound).order_by("score") self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.exoplanet])) - def test_compound_operations(self): + def test_operations(self): expr = SearchEquals(path="headline", value="space exploration") & ~SearchEquals( path="number", value=3 ) @@ -522,6 +523,65 @@ def test_compound_operations(self): lambda: 
self.assertCountEqual(qs.all(), [self.mars_mission, self.exoplanet]) ) + def test_multiple_search(self): + msg = ( + "Only one $search operation is allowed per query. Received 2 search expressions. " + "To combine multiple search expressions, use either a CompoundExpression for " + "fine-grained control or CombinedSearchExpression for simple logical combinations." + ) + with self.assertRaisesMessage(ValueError, msg): + Article.objects.annotate( + score1=SearchEquals(path="headline", value="space exploration"), + score2=~SearchEquals(path="number", value=3), + ).order_by("score1", "score2").first() + + with self.assertRaisesMessage(ValueError, msg): + Article.objects.filter( + Q(headline__search="space exploration"), Q(headline__search="space exploration 2") + ).first() + + def test_multiple_type_search(self): + msg = ( + "Cannot combine a `$vectorSearch` with a `$search` operator. " + "If you need to combine them, consider " + "restructuring your query logic or running them as separate queries." + ) + with self.assertRaisesMessage(ValueError, msg): + Article.objects.annotate( + score1=SearchEquals(path="headline", value="space exploration"), + score2=SearchVector( + path="headline", + query_vector=[1, 2, 3], + num_candidates=5, + limit=2, + ), + ).order_by("score1", "score2").first() + + def test_multiple_vector_search(self): + msg = ( + "Cannot combine two `$vectorSearch` operator. If you need to combine them, " + "consider restructuring your query logic or running them as separate queries." 
+ ) + with self.assertRaisesMessage(ValueError, msg): + Article.objects.annotate( + score1=SearchVector( + path="headline", + query_vector=[1, 2, 3], + num_candidates=5, + limit=2, + ), + score2=SearchVector( + path="headline", + query_vector=[1, 2, 4], + num_candidates=5, + limit=2, + ), + ).order_by("score1", "score2").first() + + def test_search_and_filter(self): + qs = Article.objects.filter(headline__search="space exploration", number__gt=2) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.icy_moons])) + @skipUnlessDBFeature("supports_atlas_search") class SearchVectorTest(SearchUtilsMixin): From 47bc62ba3cf600d7f4782b0681c25d9d6a6d9d59 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sun, 13 Jul 2025 21:28:30 -0300 Subject: [PATCH 36/37] Edits. --- tests/queries_/test_search.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/queries_/test_search.py b/tests/queries_/test_search.py index ac9d175ac..5ba1d96e5 100644 --- a/tests/queries_/test_search.py +++ b/tests/queries_/test_search.py @@ -40,7 +40,7 @@ def _wait_for_assertion(timeout: float = 120, interval: float = 0.5) -> None: """ @staticmethod - def _inner_wait_loop(predicate: Callable): + def inner_wait_loop(predicate: Callable): """ Waits until the given predicate stops raising AssertionError or DatabaseError. 
@@ -63,7 +63,7 @@ def _inner_wait_loop(predicate: Callable): else: break - return _inner_wait_loop + return inner_wait_loop @skipUnlessDBFeature("supports_atlas_search") From e0f6ed19f542904467dc907188ca96d01ca2eabc Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Mon, 14 Jul 2025 23:48:37 -0300 Subject: [PATCH 37/37] Add score unit test --- django_mongodb_backend/expressions/search.py | 8 +- tests/queries_/test_search.py | 219 ++++++++++++++++++- 2 files changed, 216 insertions(+), 11 deletions(-) diff --git a/django_mongodb_backend/expressions/search.py b/django_mongodb_backend/expressions/search.py index e983443d8..c3f883bf2 100644 --- a/django_mongodb_backend/expressions/search.py +++ b/django_mongodb_backend/expressions/search.py @@ -260,7 +260,7 @@ def search_operator(self, compiler, connection): "path": self.path.as_mql(compiler, connection, as_path=True), } if self.score is not None: - params["score"] = self.score.definitions + params["score"] = self.score.as_mql(compiler, connection) return {"exists": params} @@ -601,7 +601,7 @@ def search_operator(self, compiler, connection): "query": self.query.value, } if self.score: - params["score"] = self.score.query.as_mql(compiler, connection) + params["score"] = self.score.as_mql(compiler, connection) if self.allow_analyzed_field is not None: params["allowAnalyzedField"] = self.allow_analyzed_field.value return {"wildcard": params} @@ -991,10 +991,10 @@ class SearchScoreOption(Expression): """Class to mutate scoring on a search operation""" def __init__(self, definitions=None): - self.definitions = definitions + self._definitions = definitions def as_mql(self, compiler, connection): - return self.definitions + return self._definitions class SearchTextLookup(Lookup): diff --git a/tests/queries_/test_search.py b/tests/queries_/test_search.py index 5ba1d96e5..fdcb5bc7d 100644 --- a/tests/queries_/test_search.py +++ b/tests/queries_/test_search.py @@ -20,6 +20,7 @@ SearchPhrase, SearchRange, SearchRegex, + 
SearchScoreOption, SearchText, SearchVector, SearchWildcard, @@ -70,18 +71,16 @@ def inner_wait_loop(predicate: Callable): class SearchUtilsMixin(TransactionTestCase): available_apps = [] - @staticmethod - def _get_collection(model): + def _get_collection(self, model): return connection.database.get_collection(model._meta.db_table) - @staticmethod - def create_search_index(model, index_name, definition, type="search"): - collection = SearchUtilsMixin._get_collection(model) + def create_search_index(self, model, index_name, definition, type="search"): + collection = self._get_collection(model) idx = SearchIndexModel(definition=definition, name=index_name, type=type) collection.create_search_index(idx) def _tear_down(self, model): - collection = SearchUtilsMixin._get_collection(model) + collection = self._get_collection(model) for search_indexes in collection.list_search_indexes(): collection.drop_search_index(search_indexes["name"]) collection.delete_many({}) @@ -95,7 +94,12 @@ def setUp(self): self.create_search_index( Article, "equals_headline_index", - {"mappings": {"dynamic": False, "fields": {"headline": {"type": "token"}}}}, + { + "mappings": { + "dynamic": False, + "fields": {"headline": {"type": "token"}, "number": {"type": "number"}}, + } + }, ) self.article = Article.objects.create(headline="cross", number=1, body="body") Article.objects.create(headline="other thing", number=2, body="body") @@ -108,6 +112,44 @@ def test_search_equals(self): qs = Article.objects.annotate(score=SearchEquals(path="headline", value="cross")) self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + def test_boost_score(self): + boost_score = SearchScoreOption({"boost": {"value": 3}}) + + qs = Article.objects.annotate( + score=SearchEquals(path="headline", value="cross", score=boost_score) + ) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + scored = qs.first() + self.assertGreaterEqual(scored.score, 3.0) + + def 
test_constant_score(self): + constant_score = SearchScoreOption({"constant": {"value": 10}}) + qs = Article.objects.annotate( + score=SearchEquals(path="headline", value="cross", score=constant_score) + ) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + scored = qs.first() + self.assertAlmostEqual(scored.score, 10.0, places=2) + + def test_function_score(self): + function_score = SearchScoreOption( + { + "function": { + "path": { + "value": "number", + "undefined": 0, + }, + } + } + ) + + qs = Article.objects.annotate( + score=SearchEquals(path="headline", value="cross", score=function_score) + ) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + scored = qs.first() + self.assertAlmostEqual(scored.score, 1.0, places=2) + @skipUnlessDBFeature("supports_atlas_search") class SearchAutocompleteTest(SearchUtilsMixin): @@ -173,6 +215,21 @@ def test_search_autocomplete_embedded_model(self): ) self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + def test_constant_score(self): + constant_score = SearchScoreOption({"constant": {"value": 10}}) + qs = Article.objects.annotate( + score=SearchAutocomplete( + path="headline", + query="crossing", + token_order="sequential", # noqa: S106 + fuzzy={"maxEdits": 2}, + score=constant_score, + ) + ) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + scored = qs.first() + self.assertAlmostEqual(scored.score, 10.0, places=2) + @skipUnlessDBFeature("supports_atlas_search") class SearchExistsTest(SearchUtilsMixin): @@ -184,10 +241,21 @@ def setUp(self): ) self.article = Article.objects.create(headline="ignored", number=3, body="something") + def tearDown(self): + self._tear_down(Article) + super().tearDown() + def test_search_exists(self): qs = Article.objects.annotate(score=SearchExists(path="body")) self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + def 
test_constant_score(self): + constant_score = SearchScoreOption({"constant": {"value": 10}}) + qs = Article.objects.annotate(score=SearchExists(path="body", score=constant_score)) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + scored = qs.first() + self.assertAlmostEqual(scored.score, 10.0, places=2) + @skipUnlessDBFeature("supports_atlas_search") class SearchInTest(SearchUtilsMixin): @@ -208,6 +276,15 @@ def test_search_in(self): qs = Article.objects.annotate(score=SearchIn(path="headline", value=["cross", "river"])) self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + def test_constant_score(self): + constant_score = SearchScoreOption({"constant": {"value": 10}}) + qs = Article.objects.annotate( + score=SearchIn(path="headline", value=["cross", "river"], score=constant_score) + ) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + scored = qs.first() + self.assertAlmostEqual(scored.score, 10.0, places=2) + @skipUnlessDBFeature("supports_atlas_search") class SearchPhraseTest(SearchUtilsMixin): @@ -230,6 +307,15 @@ def test_search_phrase(self): qs = Article.objects.annotate(score=SearchPhrase(path="body", query="quick brown")) self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + def test_constant_score(self): + constant_score = SearchScoreOption({"constant": {"value": 10}}) + qs = Article.objects.annotate( + score=SearchPhrase(path="body", query="quick brown", score=constant_score) + ) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + scored = qs.first() + self.assertAlmostEqual(scored.score, 10.0, places=2) + @skipUnlessDBFeature("supports_atlas_search") class SearchRangeTest(SearchUtilsMixin): @@ -250,6 +336,15 @@ def test_search_range(self): qs = Article.objects.annotate(score=SearchRange(path="number", gte=10, lt=30)) self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.number20])) 
+ def test_constant_score(self): + constant_score = SearchScoreOption({"constant": {"value": 10}}) + qs = Article.objects.annotate( + score=SearchRange(path="number", gte=10, lt=30, score=constant_score) + ) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.number20])) + scored = qs.first() + self.assertAlmostEqual(scored.score, 10.0, places=2) + @skipUnlessDBFeature("supports_atlas_search") class SearchRegexTest(SearchUtilsMixin): @@ -277,6 +372,17 @@ def test_search_regex(self): ) self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + def test_constant_score(self): + constant_score = SearchScoreOption({"constant": {"value": 10}}) + qs = Article.objects.annotate( + score=SearchRegex( + path="headline", query="hello.*", allow_analyzed_field=True, score=constant_score + ) + ) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + scored = qs.first() + self.assertAlmostEqual(scored.score, 10.0, places=2) + @skipUnlessDBFeature("supports_atlas_search") class SearchTextTest(SearchUtilsMixin): @@ -311,6 +417,21 @@ def test_search_text_with_fuzzy_and_criteria(self): ) self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + def test_constant_score(self): + constant_score = SearchScoreOption({"constant": {"value": 10}}) + qs = Article.objects.annotate( + score=SearchText( + path="body", + query="lazzy", + fuzzy={"maxEdits": 2}, + match_criteria="all", + score=constant_score, + ) + ) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + scored = qs.first() + self.assertAlmostEqual(scored.score, 10.0, places=2) + @skipUnlessDBFeature("supports_atlas_search") class SearchWildcardTest(SearchUtilsMixin): @@ -336,6 +457,15 @@ def test_search_wildcard(self): qs = Article.objects.annotate(score=SearchWildcard(path="headline", query="dark-*")) self.wait_for_assertion(lambda: self.assertCountEqual([self.article], qs)) + def 
test_constant_score(self): + constant_score = SearchScoreOption({"constant": {"value": 10}}) + qs = Article.objects.annotate( + score=SearchWildcard(path="headline", query="dark-*", score=constant_score) + ) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + scored = qs.first() + self.assertAlmostEqual(scored.score, 10.0, places=2) + @skipUnlessDBFeature("supports_atlas_search") class SearchGeoShapeTest(SearchUtilsMixin): @@ -371,6 +501,21 @@ def test_search_geo_shape(self): ) self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + def test_constant_score(self): + polygon = { + "type": "Polygon", + "coordinates": [[[30, 0], [50, 0], [50, 10], [30, 10], [30, 0]]], + } + constant_score = SearchScoreOption({"constant": {"value": 10}}) + qs = Article.objects.annotate( + score=SearchGeoShape( + path="location", relation="within", geometry=polygon, score=constant_score + ) + ) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + scored = qs.first() + self.assertAlmostEqual(scored.score, 10.0, places=2) + @skipUnlessDBFeature("supports_atlas_search") class SearchGeoWithinTest(SearchUtilsMixin): @@ -405,6 +550,24 @@ def test_search_geo_within(self): ) self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + def test_constant_score(self): + polygon = { + "type": "Polygon", + "coordinates": [[[30, 0], [50, 0], [50, 10], [30, 10], [30, 0]]], + } + constant_score = SearchScoreOption({"constant": {"value": 10}}) + qs = Article.objects.annotate( + score=SearchGeoWithin( + path="location", + kind="geometry", + geo_object=polygon, + score=constant_score, + ) + ) + self.wait_for_assertion(lambda: self.assertCountEqual(qs.all(), [self.article])) + scored = qs.first() + self.assertAlmostEqual(scored.score, 10.0, places=2) + @skipUnlessDBFeature("supports_atlas_search") @unittest.expectedFailure @@ -523,6 +686,48 @@ def test_operations(self): lambda: 
self.assertCountEqual(qs.all(), [self.mars_mission, self.exoplanet]) ) + def test_mixed_scores(self): + boost_score = SearchScoreOption({"boost": {"value": 5}}) + constant_score = SearchScoreOption({"constant": {"value": 20}}) + function_score = SearchScoreOption( + {"function": {"path": {"value": "number", "undefined": 0}}} + ) + + must_expr = SearchEquals(path="headline", value="space exploration", score=boost_score) + should_expr = SearchPhrase(path="body", query="exoplanets", score=constant_score) + must_not_expr = SearchPhrase(path="body", query="icy moons", score=function_score) + + compound = CompoundExpression( + must=[must_expr], + must_not=[must_not_expr], + should=[should_expr], + ) + qs = Article.objects.annotate(score=compound).order_by("-score") + self.wait_for_assertion( + lambda: self.assertListEqual(list(qs.all()), [self.exoplanet, self.mars_mission]) + ) + # Exoplanet should rank first because of the constant 20 bump. + self.assertEqual(qs.first(), self.exoplanet) + + def test_operationss_with_function_score(self): + function_score = SearchScoreOption( + {"function": {"path": {"value": "number", "undefined": 0}}} + ) + + expr = SearchEquals( + path="headline", + value="space exploration", + score=function_score, + ) & ~SearchEquals(path="number", value=3) + + qs = Article.objects.annotate(score=expr).order_by("-score") + + self.wait_for_assertion( + lambda: self.assertListEqual(list(qs), [self.exoplanet, self.mars_mission]) + ) + # Matches mars_mission (score≈1) and exoplanet (score≈2); exoplanet ranks first. + self.assertEqual(qs.first(), self.exoplanet) + def test_multiple_search(self): msg = ( "Only one $search operation is allowed per query. Received 2 search expressions. "