From e389b2ec0bab7f75968e0b458d919e035309b65c Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Tue, 9 Sep 2025 14:15:26 -0300 Subject: [PATCH 01/23] Functional approach solution --- django_mongodb_backend/base.py | 58 ++++- django_mongodb_backend/compiler.py | 18 +- .../expressions/builtins.py | 111 ++++++--- django_mongodb_backend/expressions/search.py | 10 +- django_mongodb_backend/features.py | 3 - django_mongodb_backend/fields/array.py | 83 +++++-- .../fields/embedded_model_array.py | 16 +- django_mongodb_backend/fields/json.py | 103 +++++--- .../polymorphic_embedded_model_array.py | 2 +- django_mongodb_backend/functions.py | 113 +++++---- django_mongodb_backend/gis/lookups.py | 2 +- django_mongodb_backend/indexes.py | 2 +- django_mongodb_backend/lookups.py | 28 ++- django_mongodb_backend/query.py | 65 +++--- django_mongodb_backend/query_utils.py | 25 +- tests/lookup_/tests.py | 10 +- tests/queries_/test_mql.py | 220 ++++++++++-------- 17 files changed, 572 insertions(+), 297 deletions(-) diff --git a/django_mongodb_backend/base.py b/django_mongodb_backend/base.py index f751c27fa..f882e1eec 100644 --- a/django_mongodb_backend/base.py +++ b/django_mongodb_backend/base.py @@ -20,7 +20,7 @@ from .features import DatabaseFeatures from .introspection import DatabaseIntrospection from .operations import DatabaseOperations -from .query_utils import regex_match +from .query_utils import regex_expr, regex_match from .schema import DatabaseSchemaEditor from .utils import OperationDebugWrapper from .validation import DatabaseValidation @@ -108,7 +108,12 @@ def _isnull_operator(a, b): } return is_null if b else {"$not": is_null} - mongo_operators = { + def _isnull_operator_match(a, b): + if b: + return {"$or": [{a: {"$exists": False}}, {a: None}]} + return {"$and": [{a: {"$exists": True}}, {a: {"$ne": None}}]} + + mongo_operators_expr = { "exact": lambda a, b: {"$eq": [a, b]}, "gt": lambda a, b: {"$gt": [a, b]}, "gte": lambda a, b: {"$gte": [a, b]}, @@ -118,7 +123,7 @@ def _isnull_operator(a, b): "lte": lambda a, b: { "$and": [{"$lte": [a, b]}, DatabaseWrapper._isnull_operator(a, False)] }, - "in": lambda a, b: {"$in": [a, b]}, + "in": lambda a, b: {"$in": (a, b)}, "isnull": _isnull_operator, "range": lambda a, b: { "$and": [ @@ -126,11 +131,48 @@ def _isnull_operator(a, b): {"$or": [DatabaseWrapper._isnull_operator(b[1], True), {"$lte": [a, b[1]]}]}, ] }, - "iexact": lambda a, b: regex_match(a, ("^", b, {"$literal": "$"}), insensitive=True), - "startswith": lambda a, b: regex_match(a, ("^", b)), - "istartswith": lambda a, b: regex_match(a, ("^", b), insensitive=True), - "endswith": lambda a, b: regex_match(a, (b, {"$literal": "$"})), - "iendswith": lambda a, b: regex_match(a, (b, {"$literal": "$"}), insensitive=True), + "iexact": lambda a, b: regex_expr(a, ("^", b, {"$literal": "$"}), insensitive=True), + "startswith": lambda a, b: regex_expr(a, ("^", b)), + "istartswith": lambda a, b: regex_expr(a, ("^", b), insensitive=True), + "endswith": lambda a, b: regex_expr(a, (b, {"$literal": "$"})), + "iendswith": lambda a, b: regex_expr(a, (b, {"$literal": "$"}), insensitive=True), + "contains": lambda a, b: regex_expr(a, b), + "icontains": lambda a, b: regex_expr(a, b, insensitive=True), + "regex": lambda a, b: regex_expr(a, b), + "iregex": lambda a, b: regex_expr(a, b, insensitive=True), + } + + def range_match(a, b): + ## TODO: MAKE A TEST TO TEST WHEN BOTH ENDS ARE NONE. WHAT SHALL I RETURN? + conditions = [] + if b[0] is not None: + conditions.append({a: {"$gte": b[0]}}) + if b[1] is not None: + conditions.append({a: {"$lte": b[1]}}) + if not conditions: + return {"$literal": True} + return {"$and": conditions} + + mongo_operators_match = { + "exact": lambda a, b: {a: b}, + "gt": lambda a, b: {a: {"$gt": b}}, + "gte": lambda a, b: {a: {"$gte": b}}, + # MongoDB considers null less than zero. Exclude null values to match + # SQL behavior. + "lt": lambda a, b: { + "$and": [{a: {"$lt": b}}, DatabaseWrapper._isnull_operator_match(a, False)] + }, + "lte": lambda a, b: { + "$and": [{a: {"$lte": b}}, DatabaseWrapper._isnull_operator_match(a, False)] + }, + "in": lambda a, b: {a: {"$in": tuple(b)}}, + "isnull": _isnull_operator_match, + "range": range_match, + "iexact": lambda a, b: regex_match(a, f"^{b}$", insensitive=True), + "startswith": lambda a, b: regex_match(a, f"^{b}"), + "istartswith": lambda a, b: regex_match(a, f"^{b}", insensitive=True), + "endswith": lambda a, b: regex_match(a, f"{b}$"), + "iendswith": lambda a, b: regex_match(a, f"{b}$", insensitive=True), "contains": lambda a, b: regex_match(a, b), "icontains": lambda a, b: regex_match(a, b, insensitive=True), "regex": lambda a, b: regex_match(a, b), diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 628a91e84..32f190a3b 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -327,14 +327,14 @@ def pre_sql_setup(self, with_col_aliases=False): pipeline = self._build_aggregation_pipeline(ids, group) if self.having: having = self.having.replace_expressions(all_replacements).as_mql( - self, self.connection + self, self.connection, as_path=True ) # Add HAVING subqueries. for query in self.subqueries or (): pipeline.extend(query.get_pipeline()) # Remove the added subqueries. self.subqueries = [] - pipeline.append({"$match": {"$expr": having}}) + pipeline.append({"$match": having}) self.aggregation_pipeline = pipeline self.annotations = { target: expr.replace_expressions(all_replacements) @@ -481,11 +481,11 @@ def build_query(self, columns=None): query.lookup_pipeline = self.get_lookup_pipeline() where = self.get_where() try: - expr = where.as_mql(self, self.connection) if where else {} + expr = where.as_mql(self, self.connection, as_path=True) if where else {} except FullResultSet: query.match_mql = {} else: - query.match_mql = {"$expr": expr} + query.match_mql = expr if extra_fields: query.extra_fields = self.get_project_fields(extra_fields, force_expression=True) query.subqueries = self.subqueries @@ -643,7 +643,9 @@ def get_combinator_queries(self): for alias, expr in self.columns: # Unfold foreign fields. if isinstance(expr, Col) and expr.alias != self.collection_name: - ids[expr.alias][expr.target.column] = expr.as_mql(self, self.connection) + ids[expr.alias][expr.target.column] = expr.as_mql( + self, self.connection, as_path=False + ) else: ids[alias] = f"${alias}" # Convert defaultdict to dict so it doesn't appear as @@ -707,16 +709,16 @@ def get_project_fields(self, columns=None, ordering=None, force_expression=False # For brevity/simplicity, project {"field_name": 1} # instead of {"field_name": "$field_name"}. if isinstance(expr, Col) and name == expr.target.column and not force_expression - else expr.as_mql(self, self.connection) + else expr.as_mql(self, self.connection, as_path=False) ) except EmptyResultSet: empty_result_set_value = getattr(expr, "empty_result_set_value", NotImplemented) value = ( False if empty_result_set_value is NotImplemented else empty_result_set_value ) - fields[collection][name] = Value(value).as_mql(self, self.connection) + fields[collection][name] = Value(value).as_mql(self, self.connection, as_path=False) except FullResultSet: - fields[collection][name] = Value(True).as_mql(self, self.connection) + fields[collection][name] = Value(True).as_mql(self, self.connection, as_path=False) # Annotations (stored in None) and the main collection's fields # should appear in the top-level of the fields dict. fields.update(fields.pop(None, {})) diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index 0bc939350..212b22e12 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -6,6 +6,7 @@ from django.core.exceptions import EmptyResultSet, FullResultSet from django.db import NotSupportedError from django.db.models.expressions import ( + BaseExpression, Case, Col, ColPairs, @@ -13,6 +14,7 @@ Exists, ExpressionList, ExpressionWrapper, + Func, NegatedExpression, OrderBy, RawSQL, @@ -23,17 +25,20 @@ Value, When, ) +from django.db.models.fields.json import KeyTransform from django.db.models.sql import Query -from ..query_utils import process_lhs +from django_mongodb_backend.fields.array import Array +from ..query_utils import is_direct_value, process_lhs -def case(self, compiler, connection): + +def case(self, compiler, connection, as_path=False): case_parts = [] for case in self.cases: case_mql = {} try: - case_mql["case"] = case.as_mql(compiler, connection) + case_mql["case"] = case.as_mql(compiler, connection, as_path=False) except EmptyResultSet: continue except FullResultSet: @@ -45,12 +50,16 @@ def case(self, compiler, connection): default_mql = self.default.as_mql(compiler, connection) if not case_parts: return default_mql - return { + expr = { "$switch": { "branches": case_parts, "default": default_mql, } } + if as_path: + return {"$expr": expr} + + return expr def col(self, compiler, connection, as_path=False): # noqa: ARG001 @@ -76,34 +85,34 @@ def col(self, compiler, connection, as_path=False): # noqa: ARG001 return f"{prefix}{self.target.column}" -def col_pairs(self, compiler, connection): +def col_pairs(self, compiler, connection, as_path=False): cols = self.get_cols() if len(cols) > 1: raise NotSupportedError("ColPairs is not supported.") - return cols[0].as_mql(compiler, connection) + return cols[0].as_mql(compiler, connection, as_path=as_path) -def combined_expression(self, compiler, connection): +def combined_expression(self, compiler, connection, as_path=False): expressions = [ - self.lhs.as_mql(compiler, connection), - self.rhs.as_mql(compiler, connection), + self.lhs.as_mql(compiler, connection, as_path=as_path), + self.rhs.as_mql(compiler, connection, as_path=as_path), ] return connection.ops.combine_expression(self.connector, expressions) -def expression_wrapper(self, compiler, connection): - return self.expression.as_mql(compiler, connection) +def expression_wrapper(self, compiler, connection, as_path=False): + return self.expression.as_mql(compiler, connection, as_path=as_path) -def negated_expression(self, compiler, connection): - return {"$not": expression_wrapper(self, compiler, connection)} +def negated_expression(self, compiler, connection, as_path=False): + return {"$not": expression_wrapper(self, compiler, connection, as_path=as_path)} def order_by(self, compiler, connection): return self.expression.as_mql(compiler, connection) -def query(self, compiler, connection, get_wrapping_pipeline=None): +def query(self, compiler, connection, get_wrapping_pipeline=None, as_path=False): subquery_compiler = self.get_compiler(connection=connection) subquery_compiler.pre_sql_setup(with_col_aliases=False) field_name, expr = subquery_compiler.columns[0] @@ -145,6 +154,8 @@ def query(self, compiler, connection, get_wrapping_pipeline=None): # Erase project_fields since the required value is projected above. subquery.project_fields = None compiler.subqueries.append(subquery) + if as_path: + return f"{table_output}.{field_name}" return f"${table_output}.{field_name}" @@ -152,7 +163,7 @@ def raw_sql(self, compiler, connection): # noqa: ARG001 raise NotSupportedError("RawSQL is not supported on MongoDB.") -def ref(self, compiler, connection): # noqa: ARG001 +def ref(self, compiler, connection, as_path=False): # noqa: ARG001 prefix = ( f"{self.source.alias}." if isinstance(self.source, Col) and self.source.alias != compiler.collection_name @@ -162,32 +173,47 @@ def ref(self, compiler, connection): # noqa: ARG001 refs, _ = compiler.columns[self.ordinal - 1] else: refs = self.refs - return f"${prefix}{refs}" + if not as_path: + prefix = f"${prefix}" + return f"{prefix}{refs}" -def star(self, compiler, connection): # noqa: ARG001 +def star(self, compiler, connection, **extra): # noqa: ARG001 return {"$literal": True} -def subquery(self, compiler, connection, get_wrapping_pipeline=None): - return self.query.as_mql(compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline) +def subquery(self, compiler, connection, get_wrapping_pipeline=None, as_path=False): + expr = self.query.as_mql( + compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline, as_path=False + ) + if as_path: + return {"$expr": expr} + return expr -def exists(self, compiler, connection, get_wrapping_pipeline=None): +def exists(self, compiler, connection, get_wrapping_pipeline=None, as_path=False): try: - lhs_mql = subquery(self, compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline) + lhs_mql = subquery( + self, + compiler, + connection, + get_wrapping_pipeline=get_wrapping_pipeline, + as_path=as_path, + ) except EmptyResultSet: return Value(False).as_mql(compiler, connection) - return connection.mongo_operators["isnull"](lhs_mql, False) + if as_path: + return {"$expr": connection.mongo_operators_match["isnull"](lhs_mql, False)} + return connection.mongo_operators_expr["isnull"](lhs_mql, False) -def when(self, compiler, connection): - return self.condition.as_mql(compiler, connection) +def when(self, compiler, connection, as_path=False): + return self.condition.as_mql(compiler, connection, as_path=as_path) -def value(self, compiler, connection): # noqa: ARG001 +def value(self, compiler, connection, as_path=False): # noqa: ARG001 value = self.value - if isinstance(value, (list, int)): + if isinstance(value, (list, int)) and not as_path: # Wrap lists & numbers in $literal to prevent ambiguity when Value # appears in $project. return {"$literal": value} @@ -209,6 +235,36 @@ def value(self, compiler, connection): # noqa: ARG001 return value +@staticmethod +def _is_constant_value(value): + if isinstance(value, list | Array): + iterable = value.get_source_expressions() if isinstance(value, Array) else value + return all(_is_constant_value(e) for e in iterable) + if is_direct_value(value): + return True + return isinstance(value, Func | Value) and not ( + value.contains_aggregate + or value.contains_over_clause + or value.contains_column_references + or value.contains_subquery + ) + + +@staticmethod +def _is_simple_column(lhs): + while isinstance(lhs, KeyTransform): + if "." in getattr(lhs, "key_name", ""): + return False + lhs = lhs.lhs + col = lhs.source if isinstance(lhs, Ref) else lhs + # Foreign columns from parent cannot be addressed as single match + return isinstance(col, Col) and col.alias is not None + + +def _is_simple_expression(self): + return self.is_simple_column(self.lhs) and self.is_constant_value(self.rhs) + + def register_expressions(): Case.as_mql = case Col.as_mql = col @@ -227,3 +283,6 @@ def register_expressions(): Subquery.as_mql = subquery When.as_mql = when Value.as_mql = value + BaseExpression.is_simple_expression = _is_simple_expression + BaseExpression.is_simple_column = _is_simple_column + BaseExpression.is_constant_value = _is_constant_value diff --git a/django_mongodb_backend/expressions/search.py b/django_mongodb_backend/expressions/search.py index 3783f5943..5b74d9232 100644 --- a/django_mongodb_backend/expressions/search.py +++ b/django_mongodb_backend/expressions/search.py @@ -933,10 +933,12 @@ def __str__(self): def __repr__(self): return f"SearchText({self.lhs}, {self.rhs})" - def as_mql(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) - value = process_rhs(self, compiler, connection) - return {"$gte": [lhs_mql, value]} + def as_mql(self, compiler, connection, as_path=False): + lhs_mql = process_lhs(self, compiler, connection, as_path=as_path) + value = process_rhs(self, compiler, connection, as_path=as_path) + if as_path: + return {lhs_mql: {"$gte": value}} + return {"$expr": {"$gte": [lhs_mql, value]}} CharField.register_lookup(SearchTextLookup) diff --git a/django_mongodb_backend/features.py b/django_mongodb_backend/features.py index 7e29c5003..18a048bf6 100644 --- a/django_mongodb_backend/features.py +++ b/django_mongodb_backend/features.py @@ -90,9 +90,6 @@ class DatabaseFeatures(GISFeatures, BaseDatabaseFeatures): "auth_tests.test_views.LoginTest.test_login_session_without_hash_session_key", # GenericRelation.value_to_string() assumes integer pk. "contenttypes_tests.test_fields.GenericRelationTests.test_value_to_string", - # icontains doesn't work on ArrayField: - # Unsupported conversion from array to string in $convert - "model_fields_.test_arrayfield.QueryingTests.test_icontains", # ArrayField's contained_by lookup crashes with Exists: "both operands " # of $setIsSubset must be arrays. Second argument is of type: null" # https://jira.mongodb.org/browse/SERVER-99186 diff --git a/django_mongodb_backend/fields/array.py b/django_mongodb_backend/fields/array.py index da64ee2e8..66149c1dc 100644 --- a/django_mongodb_backend/fields/array.py +++ b/django_mongodb_backend/fields/array.py @@ -230,8 +230,11 @@ def formfield(self, **kwargs): class Array(Func): - def as_mql(self, compiler, connection): - return [expr.as_mql(compiler, connection) for expr in self.get_source_expressions()] + def as_mql(self, compiler, connection, as_path=False): + return [ + expr.as_mql(compiler, connection, as_path=as_path) + for expr in self.get_source_expressions() + ] class ArrayRHSMixin: @@ -251,9 +254,24 @@ def __init__(self, lhs, rhs): class ArrayContains(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup): lookup_name = "contains" - def as_mql(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) - value = process_rhs(self, compiler, connection) + def as_mql(self, compiler, connection, as_path=False): + if as_path and self.is_simple_expression(): + lhs_mql = process_lhs(self, compiler, connection, as_path=as_path) + value = process_rhs(self, compiler, connection, as_path=as_path) + if value is None: + return False + return {lhs_mql: {"$all": value}} + lhs_mql = process_lhs(self, compiler, connection, as_path=False) + value = process_rhs(self, compiler, connection, as_path=False) + expr = { + "$and": [ + {"$ne": [lhs_mql, None]}, + {"$ne": [value, None]}, + {"$setIsSubset": [value, lhs_mql]}, + ] + } + if as_path: + return {"$expr": expr} return { "$and": [ {"$ne": [lhs_mql, None]}, @@ -267,16 +285,19 @@ def as_mql(self, compiler, connection): class ArrayContainedBy(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup): lookup_name = "contained_by" - def as_mql(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) - value = process_rhs(self, compiler, connection) - return { + def as_mql(self, compiler, connection, as_path=False): + lhs_mql = process_lhs(self, compiler, connection, as_path=False) + value = process_rhs(self, compiler, connection, as_path=False) + expr = { "$and": [ {"$ne": [lhs_mql, None]}, {"$ne": [value, None]}, {"$setIsSubset": [lhs_mql, value]}, ] } + if as_path: + return {"$expr": expr} + return expr @ArrayField.register_lookup @@ -323,12 +344,23 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr) }, ] - def as_mql(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) - value = process_rhs(self, compiler, connection) - return { - "$and": [{"$ne": [lhs_mql, None]}, {"$size": {"$setIntersection": [value, lhs_mql]}}] + def as_mql(self, compiler, connection, as_path=False): + if as_path and self.is_simple_expression(): + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + value = process_rhs(self, compiler, connection, as_path=True) + return {lhs_mql: {"$in": value}} + + lhs_mql = process_lhs(self, compiler, connection, as_path=False) + value = process_rhs(self, compiler, connection, as_path=False) + expr = { + "$and": [ + {"$ne": [lhs_mql, None]}, + {"$size": {"$setIntersection": [value, lhs_mql]}}, + ] } + if as_path: + return {"$expr": expr} + return expr @ArrayField.register_lookup @@ -336,9 +368,12 @@ class ArrayLenTransform(Transform): lookup_name = "len" output_field = IntegerField() - def as_mql(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) - return {"$cond": {"if": {"$isArray": lhs_mql}, "then": {"$size": lhs_mql}, "else": None}} + def as_mql(self, compiler, connection, as_path=False): + lhs_mql = process_lhs(self, compiler, connection, as_path=False) + expr = {"$cond": {"if": {"$isArray": lhs_mql}, "then": {"$size": lhs_mql}, "else": None}} + if as_path: + return {"$expr": expr} + return expr @ArrayField.register_lookup @@ -363,9 +398,15 @@ def __init__(self, index, base_field, *args, **kwargs): self.index = index self.base_field = base_field - def as_mql(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) - return {"$arrayElemAt": [lhs_mql, self.index]} + def as_mql(self, compiler, connection, as_path=False): + if as_path and self.is_simple_column(self.lhs): + lhs_mql = process_lhs(self, compiler, connection, as_path=as_path) + return f"{lhs_mql}.{self.index}" + lhs_mql = process_lhs(self, compiler, connection, as_path=False) + expr = {"$arrayElemAt": [lhs_mql, self.index]} + if as_path: + return {"$expr": expr} + return expr @property def output_field(self): @@ -387,7 +428,7 @@ def __init__(self, start, end, *args, **kwargs): self.start = start self.end = end - def as_mql(self, compiler, connection): + def as_mql(self, compiler, connection, as_path=False): lhs_mql = process_lhs(self, compiler, connection) return {"$slice": [lhs_mql, self.start, self.end]} diff --git a/django_mongodb_backend/fields/embedded_model_array.py b/django_mongodb_backend/fields/embedded_model_array.py index e880931c9..fc0872be9 100644 --- a/django_mongodb_backend/fields/embedded_model_array.py +++ b/django_mongodb_backend/fields/embedded_model_array.py @@ -76,7 +76,7 @@ def _get_lookup(self, lookup_name): return lookup class EmbeddedModelArrayFieldLookups(Lookup): - def as_mql(self, compiler, connection): + def as_mql(self, compiler, connection, as_path=False): raise ValueError( "Lookups aren't supported on EmbeddedModelArrayField. " "Try querying one of its embedded fields instead." @@ -115,7 +115,7 @@ def get_lookup(self, name): class EmbeddedModelArrayFieldBuiltinLookup(Lookup): - def process_rhs(self, compiler, connection): + def process_rhs(self, compiler, connection, as_path=False): value = self.rhs if not self.get_db_prep_lookup_value_is_iterable: value = [value] @@ -129,17 +129,17 @@ def process_rhs(self, compiler, connection): for v in value ] - def as_mql(self, compiler, connection): + def as_mql(self, compiler, connection, as_path=False): # Querying a subfield within the array elements (via nested # KeyTransform). Replicate MongoDB's implicit ANY-match by mapping over # the array and applying $in on the subfield. lhs_mql = process_lhs(self, compiler, connection) inner_lhs_mql = lhs_mql["$ifNull"][0]["$map"]["in"] values = process_rhs(self, compiler, connection) - lhs_mql["$ifNull"][0]["$map"]["in"] = connection.mongo_operators[self.lookup_name]( + lhs_mql["$ifNull"][0]["$map"]["in"] = connection.mongo_operators_expr[self.lookup_name]( inner_lhs_mql, values ) - return {"$anyElementTrue": lhs_mql} + return {"$expr": {"$anyElementTrue": lhs_mql}} @_EmbeddedModelArrayOutputField.register_lookup @@ -274,7 +274,11 @@ def get_transform(self, name): f"{suggestion}" ) - def as_mql(self, compiler, connection): + def as_mql(self, compiler, connection, as_path=False): + if as_path: + inner_lhs_mql = self._lhs.as_mql(compiler, connection, as_path=True) + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + return f"{inner_lhs_mql}.{lhs_mql}" inner_lhs_mql = self._lhs.as_mql(compiler, connection) lhs_mql = process_lhs(self, compiler, connection) return { diff --git a/django_mongodb_backend/fields/json.py b/django_mongodb_backend/fields/json.py index 1a7ecb615..dd582ae8c 100644 --- a/django_mongodb_backend/fields/json.py +++ b/django_mongodb_backend/fields/json.py @@ -1,3 +1,5 @@ +from itertools import chain + from django.db import NotSupportedError from django.db.models.fields.json import ( ContainedBy, @@ -8,6 +10,7 @@ HasKeys, JSONExact, KeyTransform, + KeyTransformExact, KeyTransformIn, KeyTransformIsNull, KeyTransformNumericLookupMixin, @@ -17,8 +20,10 @@ from ..query_utils import process_lhs, process_rhs -def build_json_mql_path(lhs, key_transforms): +def build_json_mql_path(lhs, key_transforms, as_path=False): # Build the MQL path using the collected key transforms. + if as_path and lhs: + return ".".join(chain([lhs], key_transforms)) result = lhs for key in key_transforms: get_field = {"$getField": {"input": result, "field": key}} @@ -37,16 +42,21 @@ def build_json_mql_path(lhs, key_transforms): return result -def contained_by(self, compiler, connection): # noqa: ARG001 +def contained_by(self, compiler, connection, as_path=False): # noqa: ARG001 raise NotSupportedError("contained_by lookup is not supported on this database backend.") -def data_contains(self, compiler, connection): # noqa: ARG001 +def data_contains(self, compiler, connection, as_path=False): # noqa: ARG001 raise NotSupportedError("contains lookup is not supported on this database backend.") -def _has_key_predicate(path, root_column, negated=False): +def _has_key_predicate(path, root_column=None, negated=False, as_path=False): """Return MQL to check for the existence of `path`.""" + if as_path: + # if not negated: + return {path: {"$exists": not negated}} + # return {"$and": [{path: {"$exists": True}}, {path: {"$ne": None}}]} + # return {"$or": [{path: {"$exists": False}}, {path: None}]} result = { "$and": [ # The path must exist (i.e. not be "missing"). @@ -61,24 +71,28 @@ def _has_key_predicate(path, root_column, negated=False): return result -def has_key_lookup(self, compiler, connection): +def has_key_lookup(self, compiler, connection, as_path=False): """Return MQL to check for the existence of a key.""" rhs = self.rhs - lhs = process_lhs(self, compiler, connection) if not isinstance(rhs, (list, tuple)): rhs = [rhs] + as_path = as_path and self.is_simple_expression() and all("." not in v for v in rhs) + lhs = process_lhs(self, compiler, connection, as_path=as_path) paths = [] # Transform any "raw" keys into KeyTransforms to allow consistent handling # in the code that follows. + for key in rhs: rhs_json_path = key if isinstance(key, KeyTransform) else KeyTransform(key, self.lhs) - paths.append(rhs_json_path.as_mql(compiler, connection)) + paths.append(rhs_json_path.as_mql(compiler, connection, as_path=as_path)) keys = [] for path in paths: - keys.append(_has_key_predicate(path, lhs)) - if self.mongo_operator is None: - return keys[0] - return {self.mongo_operator: keys} + keys.append(_has_key_predicate(path, lhs, as_path=as_path)) + + result = keys[0] if self.mongo_operator is None else {self.mongo_operator: keys} + if not as_path: + result = {"$expr": result} + return result _process_rhs = JSONExact.process_rhs @@ -93,7 +107,7 @@ def json_exact_process_rhs(self, compiler, connection): ) -def key_transform(self, compiler, connection): +def key_transform(self, compiler, connection, as_path=False): """ Return MQL for this KeyTransform (JSON path). @@ -104,19 +118,27 @@ def key_transform(self, compiler, connection): """ key_transforms = [self.key_name] previous = self.lhs - # Collect all key transforms in order. while isinstance(previous, KeyTransform): key_transforms.insert(0, previous.key_name) previous = previous.lhs - lhs_mql = previous.as_mql(compiler, connection) - return build_json_mql_path(lhs_mql, key_transforms) + if as_path and self.is_simple_column(self.lhs): + lhs_mql = previous.as_mql(compiler, connection, as_path=True) + return build_json_mql_path(lhs_mql, key_transforms, as_path=True) + # Collect all key transforms in order. + lhs_mql = previous.as_mql(compiler, connection, as_path=False) + if as_path: + return {"$expr": build_json_mql_path(lhs_mql, key_transforms, as_path=False)} + return build_json_mql_path(lhs_mql, key_transforms, as_path=False) -def key_transform_in(self, compiler, connection): +def key_transform_in(self, compiler, connection, as_path=False): """ Return MQL to check if a JSON path exists and that its values are in the set of specified values (rhs). """ + if as_path and self.is_simple_expression(): + return builtin_lookup(self, compiler, connection, as_path=True) + lhs_mql = process_lhs(self, compiler, connection) # Traverse to the root column. previous = self.lhs @@ -125,11 +147,14 @@ def key_transform_in(self, compiler, connection): root_column = previous.as_mql(compiler, connection) value = process_rhs(self, compiler, connection) # Construct the expression to check if lhs_mql values are in rhs values. - expr = connection.mongo_operators[self.lookup_name](lhs_mql, value) - return {"$and": [_has_key_predicate(lhs_mql, root_column), expr]} + expr = connection.mongo_operators_expr[self.lookup_name](lhs_mql, value) + expr = {"$and": [_has_key_predicate(lhs_mql, root_column), expr]} + if as_path: + return {"$expr": expr} + return expr -def key_transform_is_null(self, compiler, connection): +def key_transform_is_null(self, compiler, connection, as_path=False): """ Return MQL to check the nullability of a key. @@ -139,26 +164,51 @@ def key_transform_is_null(self, compiler, connection): Reference: https://code.djangoproject.com/ticket/32252 """ - lhs_mql = process_lhs(self, compiler, connection) - rhs_mql = process_rhs(self, compiler, connection) + if as_path and self.is_simple_expression(): + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + rhs_mql = process_rhs(self, compiler, connection) + return _has_key_predicate(lhs_mql, None, negated=rhs_mql, as_path=True) # Get the root column. previous = self.lhs while isinstance(previous, KeyTransform): previous = previous.lhs root_column = previous.as_mql(compiler, connection) - return _has_key_predicate(lhs_mql, root_column, negated=rhs_mql) + expr = _has_key_predicate(lhs_mql, root_column, negated=rhs_mql) + if as_path: + return {"$expr": expr} + return expr -def key_transform_numeric_lookup_mixin(self, compiler, connection): +def key_transform_numeric_lookup_mixin(self, compiler, connection, as_path=False): """ Return MQL to check if the field exists (i.e., is not "missing" or "null") and that the field matches the given numeric lookup expression. """ - expr = builtin_lookup(self, compiler, connection) - lhs = process_lhs(self, compiler, connection) + if as_path and self.is_simple_expression(): + return builtin_lookup(self, compiler, connection, as_path=True) + + lhs = process_lhs(self, compiler, connection, as_path=False) + expr = builtin_lookup(self, compiler, connection, as_path=False) # Check if the type of lhs is not "missing" or "null". not_missing_or_null = {"$not": {"$in": [{"$type": lhs}, ["missing", "null"]]}} - return {"$and": [expr, not_missing_or_null]} + expr = {"$and": [expr, not_missing_or_null]} + if as_path: + return {"$expr": expr} + return expr + + +def key_transform_exact(self, compiler, connection, as_path=False): + if as_path and self.is_simple_expression(): + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + return { + "$and": [ + builtin_lookup(self, compiler, connection, as_path=True), + _has_key_predicate(lhs_mql, None, as_path=True), + ] + } + if as_path: + return {"$expr": builtin_lookup(self, compiler, connection, as_path=False)} + return builtin_lookup(self, compiler, connection, as_path=False) def register_json_field(): @@ -173,3 +223,4 @@ def register_json_field(): KeyTransformIn.as_mql = key_transform_in KeyTransformIsNull.as_mql = key_transform_is_null KeyTransformNumericLookupMixin.as_mql = key_transform_numeric_lookup_mixin + KeyTransformExact.as_mql = key_transform_exact diff --git a/django_mongodb_backend/fields/polymorphic_embedded_model_array.py b/django_mongodb_backend/fields/polymorphic_embedded_model_array.py index 6325ca4fc..80bbb8dd2 100644 --- a/django_mongodb_backend/fields/polymorphic_embedded_model_array.py +++ b/django_mongodb_backend/fields/polymorphic_embedded_model_array.py @@ -80,7 +80,7 @@ def _get_lookup(self, lookup_name): return lookup class EmbeddedModelArrayFieldLookups(Lookup): - def as_mql(self, compiler, connection): + def as_mql(self, compiler, connection, as_path=False): raise ValueError( "Lookups aren't supported on PolymorphicEmbeddedModelArrayField. " "Try querying one of its embedded fields instead." diff --git a/django_mongodb_backend/functions.py b/django_mongodb_backend/functions.py index c45800a0a..7997ea250 100644 --- a/django_mongodb_backend/functions.py +++ b/django_mongodb_backend/functions.py @@ -65,9 +65,9 @@ } -def cast(self, compiler, connection): +def cast(self, compiler, connection, as_path=False): output_type = connection.data_types[self.output_field.get_internal_type()] - lhs_mql = process_lhs(self, compiler, connection)[0] + lhs_mql = process_lhs(self, compiler, connection, as_path=False)[0] if max_length := self.output_field.max_length: lhs_mql = {"$substrCP": [lhs_mql, 0, max_length]} # Skip the conversion for "object" as it doesn't need to be transformed for @@ -77,103 +77,138 @@ def cast(self, compiler, connection): lhs_mql = {"$convert": {"input": lhs_mql, "to": output_type}} if decimal_places := getattr(self.output_field, "decimal_places", None): lhs_mql = {"$trunc": [lhs_mql, decimal_places]} + + if as_path: + return {"$expr": lhs_mql} return lhs_mql -def concat(self, compiler, connection): - return self.get_source_expressions()[0].as_mql(compiler, connection) +def concat(self, compiler, connection, as_path=False): + return self.get_source_expressions()[0].as_mql(compiler, connection, as_path=as_path) -def concat_pair(self, compiler, connection): +def concat_pair(self, compiler, connection, as_path=False): # null on either side results in null for expression, wrap with coalesce. coalesced = self.coalesce() - return super(ConcatPair, coalesced).as_mql(compiler, connection) + if as_path: + return {"$expr": super(ConcatPair, coalesced).as_mql(compiler, connection, as_path=False)} + return super(ConcatPair, coalesced).as_mql(compiler, connection, as_path=False) -def cot(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) +def cot(self, compiler, connection, as_path=False): + lhs_mql = process_lhs(self, compiler, connection, as_path=False) + if as_path: + return {"$expr": {"$divide": [1, {"$tan": lhs_mql}]}} return {"$divide": [1, {"$tan": lhs_mql}]} -def extract(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) +def extract(self, compiler, connection, as_path=False): + lhs_mql = process_lhs(self, compiler, connection, as_path=as_path) operator = EXTRACT_OPERATORS.get(self.lookup_name) if operator is None: raise NotSupportedError(f"{self.__class__.__name__} is not supported.") if timezone := self.get_tzname(): lhs_mql = {"date": lhs_mql, "timezone": timezone} - return {f"${operator}": lhs_mql} + expr = {f"${operator}": lhs_mql} + if as_path: + return {"$expr": expr} + return expr -def func(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) +def func(self, compiler, connection, as_path=False): + lhs_mql = process_lhs(self, compiler, connection, as_path=False) if self.function is None: raise NotSupportedError(f"{self} may need an as_mql() method.") operator = MONGO_OPERATORS.get(self.__class__, self.function.lower()) + if as_path: + return {"$expr": {f"${operator}": lhs_mql}} return {f"${operator}": lhs_mql} -def left(self, compiler, connection): - return self.get_substr().as_mql(compiler, connection) +def left(self, compiler, connection, as_path=False): + return self.get_substr().as_mql(compiler, connection, as_path=as_path) -def length(self, compiler, connection): +def length(self, compiler, connection, as_path=False): # Check for null first since $strLenCP only accepts strings. - lhs_mql = process_lhs(self, compiler, connection) - return {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$strLenCP": lhs_mql}}} + lhs_mql = process_lhs(self, compiler, connection, as_path=False) + expr = {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$strLenCP": lhs_mql}}} + if as_path: + return {"$expr": expr} + return expr -def log(self, compiler, connection): +def log(self, compiler, connection, as_path=False): # This function is usually log(base, num) but on MongoDB it's log(num, base). clone = self.copy() clone.set_source_expressions(self.get_source_expressions()[::-1]) - return func(clone, compiler, connection) + return func(clone, compiler, connection, as_path=as_path) -def now(self, compiler, connection): # noqa: ARG001 +def now(self, compiler, connection, as_path=False): # noqa: ARG001 return "$$NOW" -def null_if(self, compiler, connection): +def null_if(self, compiler, connection, as_path=False): """Return None if expr1==expr2 else expr1.""" - expr1, expr2 = (expr.as_mql(compiler, connection) for expr in self.get_source_expressions()) - return {"$cond": {"if": {"$eq": [expr1, expr2]}, "then": None, "else": expr1}} + expr1, expr2 = ( + expr.as_mql(compiler, connection, as_path=False) for expr in self.get_source_expressions() + ) + expr = {"$cond": {"if": {"$eq": [expr1, expr2]}, "then": None, "else": expr1}} + if as_path: + return {"$expr": expr} + return expr def preserve_null(operator): # If the argument is null, the function should return null, not # $toLower/Upper's behavior of returning an empty string. - def wrapped(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) - return { + def wrapped(self, compiler, connection, as_path=False): + if as_path and self.is_constant_value(self.lhs): + if self.lhs is None: + return None + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + return lhs_mql.upper() + lhs_mql = process_lhs(self, compiler, connection, as_path=False) + inner_expression = { "$cond": { - "if": connection.mongo_operators["isnull"](lhs_mql, True), + "if": connection.mongo_operators_expr["isnull"](lhs_mql, True), "then": None, "else": {f"${operator}": lhs_mql}, } } + # we need to wrap this, because it will be handled in a no expression tree. + # needed in MongoDB 6. + if as_path: + return {"$expr": inner_expression} + return inner_expression return wrapped -def replace(self, compiler, connection): - expression, text, replacement = process_lhs(self, compiler, connection) +def replace(self, compiler, connection, as_path=False): + expression, text, replacement = process_lhs(self, compiler, connection, as_path=as_path) return {"$replaceAll": {"input": expression, "find": text, "replacement": replacement}} -def round_(self, compiler, connection): +def round_(self, compiler, connection, as_path=False): # noqa: ARG001 # Round needs its own function because it's a special case that inherits # from Transform but has two arguments. - return {"$round": [expr.as_mql(compiler, connection) for expr in self.get_source_expressions()]} + return { + "$round": [ + expr.as_mql(compiler, connection, as_path=False) + for expr in self.get_source_expressions() + ] + } -def str_index(self, compiler, connection): +def str_index(self, compiler, connection, as_path=False): # noqa: ARG001 lhs = process_lhs(self, compiler, connection) # StrIndex should be 0-indexed (not found) but it's -1-indexed on MongoDB. return {"$add": [{"$indexOfCP": lhs}, 1]} -def substr(self, compiler, connection): +def substr(self, compiler, connection, as_path=False): # noqa: ARG001 lhs = process_lhs(self, compiler, connection) # The starting index is zero-indexed on MongoDB rather than one-indexed. lhs[1] = {"$add": [lhs[1], -1]} @@ -185,14 +220,14 @@ def substr(self, compiler, connection): def trim(operator): - def wrapped(self, compiler, connection): + def wrapped(self, compiler, connection, as_path=False): # noqa: ARG001 lhs = process_lhs(self, compiler, connection) return {f"${operator}": {"input": lhs}} return wrapped -def trunc(self, compiler, connection): +def trunc(self, compiler, connection, as_path=False): # noqa: ARG001 lhs_mql = process_lhs(self, compiler, connection) lhs_mql = {"date": lhs_mql, "unit": self.kind, "startOfWeek": "mon"} if timezone := self.get_tzname(): @@ -230,7 +265,7 @@ def trunc_convert_value(self, value, expression, connection): return _trunc_convert_value(self, value, expression, connection) -def trunc_date(self, compiler, connection): +def trunc_date(self, compiler, connection, **extra): # noqa: ARG001 # Cast to date rather than truncate to date. lhs_mql = process_lhs(self, compiler, connection) tzname = self.get_tzname() @@ -251,11 +286,11 @@ def trunc_date(self, compiler, connection): } -def trunc_time(self, compiler, connection): +def trunc_time(self, compiler, connection, as_path=False): # noqa: ARG001 tzname = self.get_tzname() if tzname and tzname != "UTC": raise NotSupportedError(f"TruncTime with tzinfo ({tzname}) isn't supported on MongoDB.") - lhs_mql = process_lhs(self, compiler, connection) + lhs_mql = process_lhs(self, compiler, connection, as_path=False) return { "$dateFromString": { "dateString": { diff --git a/django_mongodb_backend/gis/lookups.py b/django_mongodb_backend/gis/lookups.py index 29c2e1e96..b81eeea9c 100644 --- a/django_mongodb_backend/gis/lookups.py +++ b/django_mongodb_backend/gis/lookups.py @@ -2,7 +2,7 @@ from django.db import NotSupportedError -def gis_lookup(self, compiler, connection): # noqa: ARG001 +def gis_lookup(self, compiler, connection, as_path=False): # noqa: ARG001 raise NotSupportedError(f"MongoDB does not support the {self.lookup_name} lookup.") diff --git a/django_mongodb_backend/indexes.py b/django_mongodb_backend/indexes.py index 99c1ef5f3..f61b3524b 100644 --- a/django_mongodb_backend/indexes.py +++ b/django_mongodb_backend/indexes.py @@ -34,7 +34,7 @@ def _get_condition_mql(self, model, schema_editor): def builtin_lookup_idx(self, compiler, connection): lhs_mql = self.lhs.target.column - value = process_rhs(self, compiler, connection) + value = process_rhs(self, compiler, connection, as_path=True) try: operator = MONGO_INDEX_OPERATORS[self.lookup_name] except KeyError: diff --git a/django_mongodb_backend/lookups.py b/django_mongodb_backend/lookups.py index 8dda2bab3..85c25e229 100644 --- a/django_mongodb_backend/lookups.py +++ b/django_mongodb_backend/lookups.py @@ -11,10 +11,17 @@ from .query_utils import process_lhs, process_rhs -def builtin_lookup(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) +def builtin_lookup(self, compiler, connection, as_path=False): + if as_path and self.is_simple_expression(): + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + value = process_rhs(self, compiler, connection, as_path=True) + return connection.mongo_operators_match[self.lookup_name](lhs_mql, value) + value = process_rhs(self, compiler, connection) - return connection.mongo_operators[self.lookup_name](lhs_mql, value) + lhs_mql = process_lhs(self, compiler, connection, as_path=False) + if as_path: + return {"$expr": connection.mongo_operators_expr[self.lookup_name](lhs_mql, value)} + return connection.mongo_operators_expr[self.lookup_name](lhs_mql, value) _field_resolve_expression_parameter = FieldGetDbPrepValueIterableMixin.resolve_expression_parameter @@ -33,14 +40,14 @@ def field_resolve_expression_parameter(self, compiler, connection, sql, param): return sql, sql_params -def in_(self, compiler, connection): +def in_(self, compiler, connection, **extra): db_rhs = getattr(self.rhs, "_db", None) if db_rhs is not None and db_rhs != connection.alias: raise ValueError( "Subqueries aren't allowed across different databases. Force " "the inner query to be evaluated using `list(inner_query)`." ) - return builtin_lookup(self, compiler, connection) + return builtin_lookup(self, compiler, connection, **extra) def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr): # noqa: ARG001 @@ -75,11 +82,16 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr) ] -def is_null(self, compiler, connection): +def is_null(self, compiler, connection, as_path=False): if not isinstance(self.rhs, bool): raise ValueError("The QuerySet value for an isnull lookup must be True or False.") - lhs_mql = process_lhs(self, compiler, connection) - return connection.mongo_operators["isnull"](lhs_mql, self.rhs) + if as_path and self.is_simple_expression(): + lhs_mql = process_lhs(self, compiler, connection, as_path=as_path) + return connection.mongo_operators_match["isnull"](lhs_mql, self.rhs) + lhs_mql = process_lhs(self, compiler, connection, as_path=False) + if as_path: + return {"$expr": connection.mongo_operators_expr["isnull"](lhs_mql, self.rhs)} + return connection.mongo_operators_expr["isnull"](lhs_mql, self.rhs) # from https://www.pcre.org/current/doc/html/pcre2pattern.html#SEC4 diff --git a/django_mongodb_backend/query.py b/django_mongodb_backend/query.py index c86b8721b..a013c3375 100644 --- a/django_mongodb_backend/query.py +++ b/django_mongodb_backend/query.py @@ -11,8 +11,6 @@ from django.db.models.sql.where import AND, OR, XOR, ExtraWhere, NothingNode, WhereNode from pymongo.errors import BulkWriteError, DuplicateKeyError, PyMongoError -from .query_conversion.query_optimizer import convert_expr_to_match - def wrap_database_errors(func): @wraps(func) @@ -89,7 +87,8 @@ def get_pipeline(self): for query in self.subqueries or (): pipeline.extend(query.get_pipeline()) if self.match_mql: - pipeline.extend(convert_expr_to_match(self.match_mql)) + pipeline.append({"$match": self.match_mql}) + # pipeline.append({"$match": {"$expr": self.match_mql}}) if self.aggregation_pipeline: pipeline.extend(self.aggregation_pipeline) if self.project_fields: @@ -166,12 +165,13 @@ def _get_reroot_replacements(expression): for col, parent_pos in columns: target = col.target.clone() target.remote_field = col.target.remote_field - column_target = Col(compiler.collection_name, target) if parent_pos is not None: + column_target = Col(None, target) target_col = f"${parent_template}{parent_pos}" column_target.target.db_column = target_col column_target.target.set_attributes_from_name(target_col) else: + column_target = Col(compiler.collection_name, target) column_target.target = col.target replacements[col] = column_target return replacements @@ -194,7 +194,7 @@ def _get_reroot_replacements(expression): if extra: replacements = _get_reroot_replacements(extra) extra_conditions.append( - extra.replace_expressions(replacements).as_mql(compiler, connection) + extra.replace_expressions(replacements).as_mql(compiler, connection, as_path=True) ) # pushed_filter_expression is a Where expression from the outer WHERE # clause that involves fields from the joined (right-hand) table and @@ -208,9 +208,26 @@ def _get_reroot_replacements(expression): rerooted_replacement = _get_reroot_replacements(pushed_filter_expression) extra_conditions.append( pushed_filter_expression.replace_expressions(rerooted_replacement).as_mql( - compiler, connection + compiler, connection, as_path=True ) ) + + # Match the conditions: + # self.table_name.field1 = parent_table.field1 + # AND + # self.table_name.field2 = parent_table.field2 + # AND + # ... + condition = { + "$expr": { + "$and": [ + {"$eq": [f"$${parent_template}{i}", field]} for i, field in enumerate(rhs_fields) + ] + } + } + if extra_conditions: + condition = {"$and": [condition, *extra_conditions]} + lookup_pipeline = [ { "$lookup": { @@ -222,25 +239,7 @@ def _get_reroot_replacements(expression): f"{parent_template}{i}": parent_field for i, parent_field in enumerate(lhs_fields) }, - "pipeline": [ - { - # Match the conditions: - # self.table_name.field1 = parent_table.field1 - # AND - # self.table_name.field2 = parent_table.field2 - # AND - # ... - "$match": { - "$expr": { - "$and": [ - {"$eq": [f"$${parent_template}{i}", field]} - for i, field in enumerate(rhs_fields) - ] - + extra_conditions - } - } - } - ], + "pipeline": [{"$match": condition}], # Rename the output as table_alias. "as": self.table_alias, } @@ -274,7 +273,7 @@ def _get_reroot_replacements(expression): return lookup_pipeline -def where_node(self, compiler, connection): +def where_node(self, compiler, connection, as_path=False): if self.connector == AND: full_needed, empty_needed = len(self.children), 1 else: @@ -297,14 +296,16 @@ def where_node(self, compiler, connection): if len(self.children) > 2: rhs_sum = Mod(rhs_sum, 2) rhs = Exact(1, rhs_sum) - return self.__class__([lhs, rhs], AND, self.negated).as_mql(compiler, connection) + return self.__class__([lhs, rhs], AND, self.negated).as_mql( + compiler, connection, as_path=as_path + ) else: operator = "$or" children_mql = [] for child in self.children: try: - mql = child.as_mql(compiler, connection) + mql = child.as_mql(compiler, connection, as_path=as_path) except EmptyResultSet: empty_needed -= 1 except FullResultSet: @@ -331,13 +332,17 @@ def where_node(self, compiler, connection): raise FullResultSet if self.negated and mql: - mql = {"$not": mql} + mql = {"$nor": [mql]} if as_path else {"$not": [mql]} return mql +def nothing_node(self, compiler, connection, as_path=None): # noqa: ARG001 + return self.as_sql(compiler, connection) + + def register_nodes(): ExtraWhere.as_mql = extra_where Join.as_mql = join - NothingNode.as_mql = NothingNode.as_sql + NothingNode.as_mql = nothing_node WhereNode.as_mql = where_node diff --git a/django_mongodb_backend/query_utils.py b/django_mongodb_backend/query_utils.py index 4b744241e..d5943f61c 100644 --- a/django_mongodb_backend/query_utils.py +++ b/django_mongodb_backend/query_utils.py @@ -7,7 +7,7 @@ def is_direct_value(node): return not hasattr(node, "as_sql") -def process_lhs(node, compiler, connection): +def process_lhs(node, compiler, connection, **extra): if not hasattr(node, "lhs"): # node is a Func or Expression, possibly with multiple source expressions. result = [] @@ -15,27 +15,30 @@ def process_lhs(node, compiler, connection): if expr is None: continue try: - result.append(expr.as_mql(compiler, connection)) + result.append(expr.as_mql(compiler, connection, **extra)) except FullResultSet: - result.append(Value(True).as_mql(compiler, connection)) + result.append(Value(True).as_mql(compiler, connection, **extra)) if isinstance(node, Aggregate): return result[0] return result # node is a Transform with just one source expression, aliased as "lhs". if is_direct_value(node.lhs): return node - return node.lhs.as_mql(compiler, connection) + return node.lhs.as_mql(compiler, connection, **extra) -def process_rhs(node, compiler, connection): +def process_rhs(node, compiler, connection, as_path=False): rhs = node.rhs if hasattr(rhs, "as_mql"): if getattr(rhs, "subquery", False) and hasattr(node, "get_subquery_wrapping_pipeline"): value = rhs.as_mql( - compiler, connection, get_wrapping_pipeline=node.get_subquery_wrapping_pipeline + compiler, + connection, + get_wrapping_pipeline=node.get_subquery_wrapping_pipeline, + as_path=as_path, ) else: - value = rhs.as_mql(compiler, connection) + value = rhs.as_mql(compiler, connection, as_path=as_path) else: _, value = node.process_rhs(compiler, connection) lookup_name = node.lookup_name @@ -47,7 +50,13 @@ def process_rhs(node, compiler, connection): return value -def regex_match(field, regex_vals, insensitive=False): +def regex_expr(field, regex_vals, insensitive=False): regex = {"$concat": regex_vals} if isinstance(regex_vals, tuple) else regex_vals options = "i" if insensitive else "" return {"$regexMatch": {"input": field, "regex": regex, "options": options}} + + +def regex_match(field, regex, insensitive=False): + options = "i" if insensitive else "" + # return {"$regexMatch": {"input": field, "regex": regex, "options": options}} + return {field: {"$regex": regex, "$options": options}} diff --git a/tests/lookup_/tests.py b/tests/lookup_/tests.py index 6fce89942..d6e5da2a6 100644 --- a/tests/lookup_/tests.py +++ b/tests/lookup_/tests.py @@ -29,15 +29,7 @@ def test_mql(self): self.assertAggregateQuery( query, "lookup__book", - [ - { - "$match": { - "$expr": { - "$regexMatch": {"input": "$title", "regex": "Moby Dick", "options": ""} - } - } - } - ], + [{"$match": {"title": {"$regex": "Moby Dick", "$options": ""}}}], ) diff --git a/tests/queries_/test_mql.py b/tests/queries_/test_mql.py index ffd1e2e32..e8837bf8a 100644 --- a/tests/queries_/test_mql.py +++ b/tests/queries_/test_mql.py @@ -11,9 +11,7 @@ class MQLTests(MongoTestCaseMixin, TestCase): def test_all(self): with self.assertNumQueries(1) as ctx: list(Author.objects.all()) - self.assertAggregateQuery( - ctx.captured_queries[0]["sql"], "queries__author", [{"$match": {}}] - ) + self.assertAggregateQuery(ctx.captured_queries[0]["sql"], "queries__author", []) def test_join(self): with self.assertNumQueries(1) as ctx: @@ -29,12 +27,14 @@ def test_join(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$_id"]}, - {"$eq": ["$name", "Bob"]}, - ] - } + "$and": [ + { + "$expr": { + "$and": [{"$eq": ["$$parent__field__0", "$_id"]}] + } + }, + {"name": "Bob"}, + ] } } ], @@ -62,12 +62,14 @@ def test_filter_on_local_and_related_fields(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$_id"]}, - {"$eq": ["$name", "John"]}, - ] - } + "$and": [ + { + "$expr": { + "$and": [{"$eq": ["$$parent__field__0", "$_id"]}] + } + }, + {"name": "John"}, + ] } } ], @@ -123,22 +125,19 @@ def test_filter_on_self_join_fields(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$_id"]}, - { - "$and": [ - { - "$eq": [ - "$group_id", - ObjectId("6891ff7822e475eddc20f159"), - ] - }, - {"$eq": ["$name", "parent"]}, - ] - }, - ] - } + "$and": [ + { + "$expr": { + "$and": [{"$eq": ["$$parent__field__0", "$_id"]}] + } + }, + { + "$and": [ + {"group_id": ObjectId("6891ff7822e475eddc20f159")}, + {"name": "parent"}, + ] + }, + ] } } ], @@ -171,17 +170,16 @@ def test_filter_on_reverse_foreignkey_relation(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$order_id"]}, - { - "$eq": [ - "$status", - ObjectId("6891ff7822e475eddc20f159"), + "$and": [ + { + "$expr": { + "$and": [ + {"$eq": ["$$parent__field__0", "$order_id"]} ] - }, - ] - } + } + }, + {"status": ObjectId("6891ff7822e475eddc20f159")}, + ] } } ], @@ -215,17 +213,16 @@ def test_filter_on_local_and_nested_join_fields(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$order_id"]}, - { - "$eq": [ - "$status", - ObjectId("6891ff7822e475eddc20f159"), + "$and": [ + { + "$expr": { + "$and": [ + {"$eq": ["$$parent__field__0", "$order_id"]} ] - }, - ] - } + } + }, + {"status": ObjectId("6891ff7822e475eddc20f159")}, + ] } } ], @@ -240,12 +237,14 @@ def test_filter_on_local_and_nested_join_fields(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$_id"]}, - {"$eq": ["$name", "My Order"]}, - ] - } + "$and": [ + { + "$expr": { + "$and": [{"$eq": ["$$parent__field__0", "$_id"]}] + } + }, + {"name": "My Order"}, + ] } } ], @@ -276,6 +275,7 @@ def test_negated_related_filter_is_not_pushable(self): [ { "$lookup": { + "as": "queries__author", "from": "queries__author", "let": {"parent__field__0": "$author_id"}, "pipeline": [ @@ -285,11 +285,10 @@ def test_negated_related_filter_is_not_pushable(self): } } ], - "as": "queries__author", } }, {"$unwind": "$queries__author"}, - {"$match": {"$expr": {"$not": {"$eq": ["$queries__author.name", "John"]}}}}, + {"$match": {"$nor": [{"queries__author.name": "John"}]}}, ], ) @@ -341,21 +340,25 @@ def test_push_equality_between_parent_and_child_fields(self): [ { "$lookup": { + "as": "queries__orderitem", "from": "queries__orderitem", "let": {"parent__field__0": "$_id", "parent__field__1": "$_id"}, "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$order_id"]}, - {"$eq": ["$status", "$$parent__field__1"]}, - ] - } + "$and": [ + { + "$expr": { + "$and": [ + {"$eq": ["$$parent__field__0", "$order_id"]} + ] + } + }, + {"$expr": {"$eq": ["$status", "$$parent__field__1"]}}, + ] } } ], - "as": "queries__orderitem", } }, {"$unwind": "$queries__orderitem"}, @@ -398,12 +401,14 @@ def test_simple_related_filter_is_pushed(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$_id"]}, - {"$eq": ["$name", "Alice"]}, - ] - } + "$and": [ + { + "$expr": { + "$and": [{"$eq": ["$$parent__field__0", "$_id"]}] + } + }, + {"name": "Alice"}, + ] } } ], @@ -416,6 +421,7 @@ def test_simple_related_filter_is_pushed(self): ) def test_subquery_join_is_pushed(self): + # TODO; isn't fully OPTIMIZED with self.assertNumQueries(1) as ctx: list(Library.objects.filter(~models.Q(readers__name="Alice"))) @@ -436,12 +442,21 @@ def test_subquery_join_is_pushed(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$_id"]}, - {"$eq": ["$name", "Alice"]}, - ] - } + "$and": [ + { + "$expr": { + "$and": [ + { + "$eq": [ + "$$parent__field__0", + "$_id", + ] + } + ] + } + }, + {"name": "Alice"}, + ] } } ], @@ -480,21 +495,28 @@ def test_subquery_join_is_pushed(self): }, { "$match": { - "$expr": { - "$not": { - "$eq": [ - { - "$not": { - "$or": [ - {"$eq": [{"$type": "$__subquery0.a"}, "missing"]}, - {"$eq": ["$__subquery0.a", None]}, - ] - } - }, - True, - ] + "$nor": [ + { + "$expr": { + "$eq": [ + { + "$not": { + "$or": [ + { + "$eq": [ + {"$type": "$__subquery0.a"}, + "missing", + ] + }, + {"$eq": ["$__subquery0.a", None]}, + ] + } + }, + True, + ] + } } - } + ] } }, ], @@ -531,12 +553,14 @@ def test_filter_on_local_and_related_fields(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$_id"]}, - {"$eq": ["$name", "Alice"]}, - ] - } + "$and": [ + { + "$expr": { + "$and": [{"$eq": ["$$parent__field__0", "$_id"]}] + } + }, + {"name": "Alice"}, + ] } } ], From 3d7bdf3da006b0c504393820f83f6d685d42fee7 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Fri, 19 Sep 2025 22:20:55 -0300 Subject: [PATCH 02/23] Object-oriented approach solution --- django_mongodb_backend/aggregates.py | 23 +-- django_mongodb_backend/compiler.py | 2 +- .../expressions/builtins.py | 105 +++++------- django_mongodb_backend/expressions/search.py | 15 +- django_mongodb_backend/fields/array.py | 80 ++++------ .../fields/embedded_model.py | 29 +++- .../fields/embedded_model_array.py | 33 +++- django_mongodb_backend/fields/json.py | 135 ++++++++++------ django_mongodb_backend/functions.py | 149 +++++++++--------- django_mongodb_backend/lookups.py | 62 +++++--- django_mongodb_backend/query.py | 1 + django_mongodb_backend/query_utils.py | 45 +++++- .../test_embedded_model_array.py | 4 +- 13 files changed, 373 insertions(+), 310 deletions(-) diff --git a/django_mongodb_backend/aggregates.py b/django_mongodb_backend/aggregates.py index 31f4b29ba..67ccbb499 100644 --- a/django_mongodb_backend/aggregates.py +++ b/django_mongodb_backend/aggregates.py @@ -8,14 +8,7 @@ MONGO_AGGREGATIONS = {Count: "sum"} -def aggregate( - self, - compiler, - connection, - operator=None, - resolve_inner_expression=False, - **extra_context, # noqa: ARG001 -): +def aggregate(self, compiler, connection, operator=None, resolve_inner_expression=False): if self.filter: node = self.copy() node.filter = None @@ -31,7 +24,7 @@ def aggregate( return {f"${operator}": lhs_mql} -def count(self, compiler, connection, resolve_inner_expression=False, **extra_context): # noqa: ARG001 +def count(self, compiler, connection, resolve_inner_expression=False): """ When resolve_inner_expression=True, return the MQL that resolves as a value. This is used to count different elements, so the inner values are @@ -64,16 +57,16 @@ def count(self, compiler, connection, resolve_inner_expression=False, **extra_co return {"$add": [{"$size": lhs_mql}, exits_null]} -def stddev_variance(self, compiler, connection, **extra_context): +def stddev_variance(self, compiler, connection): if self.function.endswith("_SAMP"): operator = "stdDevSamp" elif self.function.endswith("_POP"): operator = "stdDevPop" - return aggregate(self, compiler, connection, operator=operator, **extra_context) + return aggregate(self, compiler, connection, operator=operator) def register_aggregates(): - Aggregate.as_mql = aggregate - Count.as_mql = count - StdDev.as_mql = stddev_variance - Variance.as_mql = stddev_variance + Aggregate.as_mql_expr = aggregate + Count.as_mql_expr = count + StdDev.as_mql_expr = stddev_variance + Variance.as_mql_expr = stddev_variance diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 32f190a3b..2eb6dfd99 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -709,7 +709,7 @@ def get_project_fields(self, columns=None, ordering=None, force_expression=False # For brevity/simplicity, project {"field_name": 1} # instead of {"field_name": "$field_name"}. if isinstance(expr, Col) and name == expr.target.column and not force_expression - else expr.as_mql(self, self.connection, as_path=False) + else expr.as_mql(self, self.connection) ) except EmptyResultSet: empty_result_set_value = getattr(expr, "empty_result_set_value", NotImplemented) diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index 212b22e12..4255e1fe3 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -14,7 +14,6 @@ Exists, ExpressionList, ExpressionWrapper, - Func, NegatedExpression, OrderBy, RawSQL, @@ -25,15 +24,12 @@ Value, When, ) -from django.db.models.fields.json import KeyTransform from django.db.models.sql import Query -from django_mongodb_backend.fields.array import Array +from ..query_utils import process_lhs -from ..query_utils import is_direct_value, process_lhs - -def case(self, compiler, connection, as_path=False): +def case(self, compiler, connection): case_parts = [] for case in self.cases: case_mql = {} @@ -50,16 +46,12 @@ def case(self, compiler, connection, as_path=False): default_mql = self.default.as_mql(compiler, connection) if not case_parts: return default_mql - expr = { + return { "$switch": { "branches": case_parts, "default": default_mql, } } - if as_path: - return {"$expr": expr} - - return expr def col(self, compiler, connection, as_path=False): # noqa: ARG001 @@ -100,12 +92,12 @@ def combined_expression(self, compiler, connection, as_path=False): return connection.ops.combine_expression(self.connector, expressions) -def expression_wrapper(self, compiler, connection, as_path=False): - return self.expression.as_mql(compiler, connection, as_path=as_path) +def expression_wrapper_expr(self, compiler, connection): + return self.expression.as_mql(compiler, connection, as_path=False) -def negated_expression(self, compiler, connection, as_path=False): - return {"$not": expression_wrapper(self, compiler, connection, as_path=as_path)} +def negated_expression_expr(self, compiler, connection): + return {"$not": expression_wrapper_expr(self, compiler, connection)} def order_by(self, compiler, connection): @@ -178,32 +170,26 @@ def ref(self, compiler, connection, as_path=False): # noqa: ARG001 return f"{prefix}{refs}" -def star(self, compiler, connection, **extra): # noqa: ARG001 +@property +def ref_is_simple_column(self): + return isinstance(self.source, Col) and self.source.alias is not None + + +def star(self, compiler, connection, as_path=False): # noqa: ARG001 return {"$literal": True} -def subquery(self, compiler, connection, get_wrapping_pipeline=None, as_path=False): - expr = self.query.as_mql( +def subquery(self, compiler, connection, get_wrapping_pipeline=None): + return self.query.as_mql( compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline, as_path=False ) - if as_path: - return {"$expr": expr} - return expr -def exists(self, compiler, connection, get_wrapping_pipeline=None, as_path=False): +def exists(self, compiler, connection, get_wrapping_pipeline=None): try: - lhs_mql = subquery( - self, - compiler, - connection, - get_wrapping_pipeline=get_wrapping_pipeline, - as_path=as_path, - ) + lhs_mql = subquery(self, compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline) except EmptyResultSet: return Value(False).as_mql(compiler, connection) - if as_path: - return {"$expr": connection.mongo_operators_match["isnull"](lhs_mql, False)} return connection.mongo_operators_expr["isnull"](lhs_mql, False) @@ -235,54 +221,37 @@ def value(self, compiler, connection, as_path=False): # noqa: ARG001 return value -@staticmethod -def _is_constant_value(value): - if isinstance(value, list | Array): - iterable = value.get_source_expressions() if isinstance(value, Array) else value - return all(_is_constant_value(e) for e in iterable) - if is_direct_value(value): - return True - return isinstance(value, Func | Value) and not ( - value.contains_aggregate - or value.contains_over_clause - or value.contains_column_references - or value.contains_subquery - ) - - -@staticmethod -def _is_simple_column(lhs): - while isinstance(lhs, KeyTransform): - if "." in getattr(lhs, "key_name", ""): - return False - lhs = lhs.lhs - col = lhs.source if isinstance(lhs, Ref) else lhs - # Foreign columns from parent cannot be addressed as single match - return isinstance(col, Col) and col.alias is not None - +def base_expression(self, compiler, connection, as_path=False, **extra): + if ( + as_path + and hasattr(self, "as_mql_path") + and getattr(self, "is_simple_expression", lambda: False)() + ): + return self.as_mql_path(compiler, connection, **extra) -def _is_simple_expression(self): - return self.is_simple_column(self.lhs) and self.is_constant_value(self.rhs) + expr = self.as_mql_expr(compiler, connection, **extra) + return {"$expr": expr} if as_path else expr def register_expressions(): - Case.as_mql = case + BaseExpression.as_mql = base_expression + BaseExpression.is_simple_column = False + Case.as_mql_expr = case Col.as_mql = col + Col.is_simple_column = True ColPairs.as_mql = col_pairs - CombinedExpression.as_mql = combined_expression - Exists.as_mql = exists + CombinedExpression.as_mql_expr = combined_expression + Exists.as_mql_expr = exists ExpressionList.as_mql = process_lhs - ExpressionWrapper.as_mql = expression_wrapper - NegatedExpression.as_mql = negated_expression - OrderBy.as_mql = order_by + ExpressionWrapper.as_mql_expr = expression_wrapper_expr + NegatedExpression.as_mql_expr = negated_expression_expr + OrderBy.as_mql_expr = order_by Query.as_mql = query RawSQL.as_mql = raw_sql Ref.as_mql = ref + Ref.is_simple_column = ref_is_simple_column ResolvedOuterRef.as_mql = ResolvedOuterRef.as_sql Star.as_mql = star - Subquery.as_mql = subquery + Subquery.as_mql_expr = subquery When.as_mql = when Value.as_mql = value - BaseExpression.is_simple_expression = _is_simple_expression - BaseExpression.is_simple_column = _is_simple_column - BaseExpression.is_constant_value = _is_constant_value diff --git a/django_mongodb_backend/expressions/search.py b/django_mongodb_backend/expressions/search.py index 5b74d9232..aba5e0cfa 100644 --- a/django_mongodb_backend/expressions/search.py +++ b/django_mongodb_backend/expressions/search.py @@ -933,12 +933,15 @@ def __str__(self): def __repr__(self): return f"SearchText({self.lhs}, {self.rhs})" - def as_mql(self, compiler, connection, as_path=False): - lhs_mql = process_lhs(self, compiler, connection, as_path=as_path) - value = process_rhs(self, compiler, connection, as_path=as_path) - if as_path: - return {lhs_mql: {"$gte": value}} - return {"$expr": {"$gte": [lhs_mql, value]}} + def as_mql_expr(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_path=False) + value = process_rhs(self, compiler, connection, as_path=False) + return {"$gte": [lhs_mql, value]} + + def as_mql_path(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + value = process_rhs(self, compiler, connection, as_path=True) + return {lhs_mql: {"$gte": value}} CharField.register_lookup(SearchTextLookup) diff --git a/django_mongodb_backend/fields/array.py b/django_mongodb_backend/fields/array.py index 66149c1dc..4d2a42cb5 100644 --- a/django_mongodb_backend/fields/array.py +++ b/django_mongodb_backend/fields/array.py @@ -230,9 +230,9 @@ def formfield(self, **kwargs): class Array(Func): - def as_mql(self, compiler, connection, as_path=False): + def as_mql_expr(self, compiler, connection): return [ - expr.as_mql(compiler, connection, as_path=as_path) + expr.as_mql(compiler, connection, as_path=False) for expr in self.get_source_expressions() ] @@ -254,24 +254,16 @@ def __init__(self, lhs, rhs): class ArrayContains(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup): lookup_name = "contains" - def as_mql(self, compiler, connection, as_path=False): - if as_path and self.is_simple_expression(): - lhs_mql = process_lhs(self, compiler, connection, as_path=as_path) - value = process_rhs(self, compiler, connection, as_path=as_path) - if value is None: - return False - return {lhs_mql: {"$all": value}} + def as_mql_path(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + value = process_rhs(self, compiler, connection, as_path=True) + if value is None: + return False + return {lhs_mql: {"$all": value}} + + def as_mql_expr(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection, as_path=False) value = process_rhs(self, compiler, connection, as_path=False) - expr = { - "$and": [ - {"$ne": [lhs_mql, None]}, - {"$ne": [value, None]}, - {"$setIsSubset": [value, lhs_mql]}, - ] - } - if as_path: - return {"$expr": expr} return { "$and": [ {"$ne": [lhs_mql, None]}, @@ -285,19 +277,16 @@ def as_mql(self, compiler, connection, as_path=False): class ArrayContainedBy(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup): lookup_name = "contained_by" - def as_mql(self, compiler, connection, as_path=False): + def as_mql_expr(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection, as_path=False) value = process_rhs(self, compiler, connection, as_path=False) - expr = { + return { "$and": [ {"$ne": [lhs_mql, None]}, {"$ne": [value, None]}, {"$setIsSubset": [lhs_mql, value]}, ] } - if as_path: - return {"$expr": expr} - return expr @ArrayField.register_lookup @@ -344,23 +333,20 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr) }, ] - def as_mql(self, compiler, connection, as_path=False): - if as_path and self.is_simple_expression(): - lhs_mql = process_lhs(self, compiler, connection, as_path=True) - value = process_rhs(self, compiler, connection, as_path=True) - return {lhs_mql: {"$in": value}} + def as_mql_path(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + value = process_rhs(self, compiler, connection, as_path=True) + return {lhs_mql: {"$in": value}} + def as_mql_expr(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection, as_path=False) value = process_rhs(self, compiler, connection, as_path=False) - expr = { + return { "$and": [ {"$ne": [lhs_mql, None]}, {"$size": {"$setIntersection": [value, lhs_mql]}}, ] } - if as_path: - return {"$expr": expr} - return expr @ArrayField.register_lookup @@ -368,12 +354,9 @@ class ArrayLenTransform(Transform): lookup_name = "len" output_field = IntegerField() - def as_mql(self, compiler, connection, as_path=False): + def as_mql_expr(self, compiler, connection, as_path=False): lhs_mql = process_lhs(self, compiler, connection, as_path=False) - expr = {"$cond": {"if": {"$isArray": lhs_mql}, "then": {"$size": lhs_mql}, "else": None}} - if as_path: - return {"$expr": expr} - return expr + return {"$cond": {"if": {"$isArray": lhs_mql}, "then": {"$size": lhs_mql}, "else": None}} @ArrayField.register_lookup @@ -398,15 +381,20 @@ def __init__(self, index, base_field, *args, **kwargs): self.index = index self.base_field = base_field - def as_mql(self, compiler, connection, as_path=False): - if as_path and self.is_simple_column(self.lhs): - lhs_mql = process_lhs(self, compiler, connection, as_path=as_path) - return f"{lhs_mql}.{self.index}" + def is_simple_expression(self): + return self.is_simple_column + + @property + def is_simple_column(self): + return self.lhs.is_simple_column + + def as_mql_path(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + return f"{lhs_mql}.{self.index}" + + def as_mql_expr(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection, as_path=False) - expr = {"$arrayElemAt": [lhs_mql, self.index]} - if as_path: - return {"$expr": expr} - return expr + return {"$arrayElemAt": [lhs_mql, self.index]} @property def output_field(self): @@ -428,7 +416,7 @@ def __init__(self, start, end, *args, **kwargs): self.start = start self.end = end - def as_mql(self, compiler, connection, as_path=False): + def as_mql_expr(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection) return {"$slice": [lhs_mql, self.start, self.end]} diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py index fbc1d53a1..a6f06e096 100644 --- a/django_mongodb_backend/fields/embedded_model.py +++ b/django_mongodb_backend/fields/embedded_model.py @@ -5,6 +5,7 @@ from django.db import models from django.db.models.fields.related import lazy_related_operation from django.db.models.lookups import Transform +from django.utils.functional import cached_property from .. import forms @@ -166,6 +167,18 @@ def __init__(self, field, *args, **kwargs): def get_lookup(self, name): return self.field.get_lookup(name) + def is_simple_expression(self): + return self.is_simple_column + + @cached_property + def is_simple_column(self): + previous = self + while isinstance(previous, KeyTransform): + if not previous.key_name.isalnum(): + return False + previous = previous.lhs + return previous.is_simple_column + def get_transform(self, name): """ Validate that `name` is either a field of an embedded model or a @@ -185,16 +198,22 @@ def get_transform(self, name): f"{suggestion}" ) - def as_mql(self, compiler, connection, as_path=False): + def as_mql_path(self, compiler, connection): + previous = self + key_transforms = [] + while isinstance(previous, EmbeddedModelTransform): + key_transforms.insert(0, previous.key_name) + previous = previous.lhs + mql = previous.as_mql(compiler, connection, as_path=True) + mql_path = ".".join(key_transforms) + return f"{mql}.{mql_path}" + + def as_mql_expr(self, compiler, connection): previous = self columns = [] while isinstance(previous, EmbeddedModelTransform): columns.insert(0, previous.field.column) previous = previous.lhs - if as_path: - mql = previous.as_mql(compiler, connection, as_path=True) - mql_path = ".".join(columns) - return f"{mql}.{mql_path}" mql = previous.as_mql(compiler, connection) for column in columns: mql = {"$getField": {"input": mql, "field": column}} diff --git a/django_mongodb_backend/fields/embedded_model_array.py b/django_mongodb_backend/fields/embedded_model_array.py index fc0872be9..958359642 100644 --- a/django_mongodb_backend/fields/embedded_model_array.py +++ b/django_mongodb_backend/fields/embedded_model_array.py @@ -5,6 +5,7 @@ from django.db.models.expressions import Col from django.db.models.fields.related import lazy_related_operation from django.db.models.lookups import Lookup, Transform +from django.utils.functional import cached_property from .. import forms from ..query_utils import process_lhs, process_rhs @@ -129,7 +130,7 @@ def process_rhs(self, compiler, connection, as_path=False): for v in value ] - def as_mql(self, compiler, connection, as_path=False): + def as_mql_expr(self, compiler, connection): # Querying a subfield within the array elements (via nested # KeyTransform). Replicate MongoDB's implicit ANY-match by mapping over # the array and applying $in on the subfield. @@ -139,7 +140,7 @@ def as_mql(self, compiler, connection, as_path=False): lhs_mql["$ifNull"][0]["$map"]["in"] = connection.mongo_operators_expr[self.lookup_name]( inner_lhs_mql, values ) - return {"$expr": {"$anyElementTrue": lhs_mql}} + return {"$anyElementTrue": lhs_mql} @_EmbeddedModelArrayOutputField.register_lookup @@ -226,6 +227,7 @@ class EmbeddedModelArrayFieldLessThanOrEqual( class EmbeddedModelArrayFieldTransform(Transform): field_class_name = "EmbeddedModelArrayField" + PREFIX_ITERABLE = "item" def __init__(self, field, *args, **kwargs): super().__init__(*args, **kwargs) @@ -242,6 +244,18 @@ def __call__(self, this, *args, **kwargs): self._lhs = self._sub_transform(self._lhs, *args, **kwargs) return self + def is_simple_expression(self): + return self.is_simple_column + + @cached_property + def is_simple_column(self): + previous = self + while isinstance(previous, EmbeddedModelArrayFieldTransform): + if not previous.key_name.isalnum(): + return False + previous = previous.lhs + return previous.is_simple_column and self._lhs.is_simple_column + def get_lookup(self, name): return self.output_field.get_lookup(name) @@ -274,11 +288,14 @@ def get_transform(self, name): f"{suggestion}" ) - def as_mql(self, compiler, connection, as_path=False): - if as_path: - inner_lhs_mql = self._lhs.as_mql(compiler, connection, as_path=True) - lhs_mql = process_lhs(self, compiler, connection, as_path=True) - return f"{inner_lhs_mql}.{lhs_mql}" + def as_mql_path(self, compiler, connection): + inner_lhs_mql = self._lhs.as_mql(compiler, connection, as_path=True).removeprefix( + f"${self.PREFIX_ITERABLE}." + ) + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + return f"{lhs_mql}.{inner_lhs_mql}" + + def as_mql_expr(self, compiler, connection): inner_lhs_mql = self._lhs.as_mql(compiler, connection) lhs_mql = process_lhs(self, compiler, connection) return { @@ -286,7 +303,7 @@ def as_mql(self, compiler, connection, as_path=False): { "$map": { "input": lhs_mql, - "as": "item", + "as": self.PREFIX_ITERABLE, "in": inner_lhs_mql, } }, diff --git a/django_mongodb_backend/fields/json.py b/django_mongodb_backend/fields/json.py index dd582ae8c..f9a80999c 100644 --- a/django_mongodb_backend/fields/json.py +++ b/django_mongodb_backend/fields/json.py @@ -1,3 +1,4 @@ +from functools import partialmethod from itertools import chain from django.db import NotSupportedError @@ -16,7 +17,7 @@ KeyTransformNumericLookupMixin, ) -from ..lookups import builtin_lookup +from ..lookups import builtin_lookup_expr, builtin_lookup_path from ..query_utils import process_lhs, process_rhs @@ -71,17 +72,20 @@ def _has_key_predicate(path, root_column=None, negated=False, as_path=False): return result +def has_key_check_simple_expression(self): + rhs = [self.rhs] if not isinstance(self.rhs, (list, tuple)) else self.rhs + return self.is_simple_column and all(key.isalnum() for key in rhs) + + def has_key_lookup(self, compiler, connection, as_path=False): """Return MQL to check for the existence of a key.""" rhs = self.rhs if not isinstance(rhs, (list, tuple)): rhs = [rhs] - as_path = as_path and self.is_simple_expression() and all("." not in v for v in rhs) lhs = process_lhs(self, compiler, connection, as_path=as_path) paths = [] # Transform any "raw" keys into KeyTransforms to allow consistent handling # in the code that follows. - for key in rhs: rhs_json_path = key if isinstance(key, KeyTransform) else KeyTransform(key, self.lhs) paths.append(rhs_json_path.as_mql(compiler, connection, as_path=as_path)) @@ -89,10 +93,7 @@ def has_key_lookup(self, compiler, connection, as_path=False): for path in paths: keys.append(_has_key_predicate(path, lhs, as_path=as_path)) - result = keys[0] if self.mongo_operator is None else {self.mongo_operator: keys} - if not as_path: - result = {"$expr": result} - return result + return keys[0] if self.mongo_operator is None else {self.mongo_operator: keys} _process_rhs = JSONExact.process_rhs @@ -121,14 +122,9 @@ def key_transform(self, compiler, connection, as_path=False): while isinstance(previous, KeyTransform): key_transforms.insert(0, previous.key_name) previous = previous.lhs - if as_path and self.is_simple_column(self.lhs): - lhs_mql = previous.as_mql(compiler, connection, as_path=True) - return build_json_mql_path(lhs_mql, key_transforms, as_path=True) # Collect all key transforms in order. - lhs_mql = previous.as_mql(compiler, connection, as_path=False) - if as_path: - return {"$expr": build_json_mql_path(lhs_mql, key_transforms, as_path=False)} - return build_json_mql_path(lhs_mql, key_transforms, as_path=False) + lhs_mql = previous.as_mql(compiler, connection, as_path=as_path) + return build_json_mql_path(lhs_mql, key_transforms, as_path=as_path) def key_transform_in(self, compiler, connection, as_path=False): @@ -137,7 +133,7 @@ def key_transform_in(self, compiler, connection, as_path=False): set of specified values (rhs). """ if as_path and self.is_simple_expression(): - return builtin_lookup(self, compiler, connection, as_path=True) + return builtin_lookup_path(self, compiler, connection) lhs_mql = process_lhs(self, compiler, connection) # Traverse to the root column. @@ -154,7 +150,24 @@ def key_transform_in(self, compiler, connection, as_path=False): return expr -def key_transform_is_null(self, compiler, connection, as_path=False): +def key_transform_in_path(self, compiler, connection): + return builtin_lookup_path(self, compiler, connection) + + +def key_transform_in_expr(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection) + # Traverse to the root column. + previous = self.lhs + while isinstance(previous, KeyTransform): + previous = previous.lhs + root_column = previous.as_mql(compiler, connection) + value = process_rhs(self, compiler, connection) + # Construct the expression to check if lhs_mql values are in rhs values. + expr = connection.mongo_operators_expr[self.lookup_name](lhs_mql, value) + return {"$and": [_has_key_predicate(lhs_mql, root_column), expr]} + + +def key_transform_is_null_path(self, compiler, connection): """ Return MQL to check the nullability of a key. @@ -164,51 +177,62 @@ def key_transform_is_null(self, compiler, connection, as_path=False): Reference: https://code.djangoproject.com/ticket/32252 """ - if as_path and self.is_simple_expression(): - lhs_mql = process_lhs(self, compiler, connection, as_path=True) - rhs_mql = process_rhs(self, compiler, connection) - return _has_key_predicate(lhs_mql, None, negated=rhs_mql, as_path=True) - # Get the root column. + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + rhs_mql = process_rhs(self, compiler, connection, as_path=True) + return _has_key_predicate(lhs_mql, None, negated=rhs_mql, as_path=True) + + +def key_transform_is_null_expr(self, compiler, connection): previous = self.lhs while isinstance(previous, KeyTransform): previous = previous.lhs root_column = previous.as_mql(compiler, connection) - expr = _has_key_predicate(lhs_mql, root_column, negated=rhs_mql) - if as_path: - return {"$expr": expr} - return expr + lhs_mql = process_lhs(self, compiler, connection, as_path=False) + rhs_mql = process_rhs(self, compiler, connection) + return _has_key_predicate(lhs_mql, root_column, negated=rhs_mql) -def key_transform_numeric_lookup_mixin(self, compiler, connection, as_path=False): +def key_transform_numeric_lookup_mixin_path(self, compiler, connection): """ Return MQL to check if the field exists (i.e., is not "missing" or "null") and that the field matches the given numeric lookup expression. """ - if as_path and self.is_simple_expression(): - return builtin_lookup(self, compiler, connection, as_path=True) + return builtin_lookup_path(self, compiler, connection) + +def key_transform_numeric_lookup_mixin_expr(self, compiler, connection): + """ + Return MQL to check if the field exists (i.e., is not "missing" or "null") + and that the field matches the given numeric lookup expression. + """ lhs = process_lhs(self, compiler, connection, as_path=False) - expr = builtin_lookup(self, compiler, connection, as_path=False) + expr = builtin_lookup_expr(self, compiler, connection) # Check if the type of lhs is not "missing" or "null". not_missing_or_null = {"$not": {"$in": [{"$type": lhs}, ["missing", "null"]]}} - expr = {"$and": [expr, not_missing_or_null]} - if as_path: - return {"$expr": expr} - return expr + return {"$and": [expr, not_missing_or_null]} -def key_transform_exact(self, compiler, connection, as_path=False): - if as_path and self.is_simple_expression(): - lhs_mql = process_lhs(self, compiler, connection, as_path=True) - return { - "$and": [ - builtin_lookup(self, compiler, connection, as_path=True), - _has_key_predicate(lhs_mql, None, as_path=True), - ] - } - if as_path: - return {"$expr": builtin_lookup(self, compiler, connection, as_path=False)} - return builtin_lookup(self, compiler, connection, as_path=False) +def key_transform_exact_path(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + return { + "$and": [ + builtin_lookup_path(self, compiler, connection), + _has_key_predicate(lhs_mql, None, as_path=True), + ] + } + + +def key_transform_exact_expr(self, compiler, connection): + return builtin_lookup_expr(self, compiler, connection) + + +def keytransform_is_simple_column(self): + previous = self + while isinstance(previous, KeyTransform): + if not previous.key_name.isalnum(): + return False + previous = previous.lhs + return previous.is_simple_column def register_json_field(): @@ -216,11 +240,20 @@ def register_json_field(): DataContains.as_mql = data_contains HasAnyKeys.mongo_operator = "$or" HasKey.mongo_operator = None - HasKeyLookup.as_mql = has_key_lookup + HasKeyLookup.as_mql_path = partialmethod(has_key_lookup, as_path=True) + HasKeyLookup.as_mql_expr = partialmethod(has_key_lookup, as_path=False) + HasKeyLookup.is_simple_expression = has_key_check_simple_expression HasKeys.mongo_operator = "$and" JSONExact.process_rhs = json_exact_process_rhs - KeyTransform.as_mql = key_transform - KeyTransformIn.as_mql = key_transform_in - KeyTransformIsNull.as_mql = key_transform_is_null - KeyTransformNumericLookupMixin.as_mql = key_transform_numeric_lookup_mixin - KeyTransformExact.as_mql = key_transform_exact + KeyTransform.is_simple_column = property(keytransform_is_simple_column) + KeyTransform.is_simple_expression = keytransform_is_simple_column + KeyTransform.as_mql_path = partialmethod(key_transform, as_path=True) + KeyTransform.as_mql_expr = partialmethod(key_transform, as_path=False) + KeyTransformIn.as_mql_path = key_transform_in_path + KeyTransformIn.as_mql_expr = key_transform_in_expr + KeyTransformIsNull.as_mql_path = key_transform_is_null_path + KeyTransformIsNull.as_mql_expr = key_transform_is_null_expr + KeyTransformNumericLookupMixin.as_mql_path = key_transform_numeric_lookup_mixin_path + KeyTransformNumericLookupMixin.as_mql_expr = key_transform_numeric_lookup_mixin_expr + KeyTransformExact.as_mql_expr = key_transform_exact_expr + KeyTransformExact.as_mql_path = key_transform_exact_path diff --git a/django_mongodb_backend/functions.py b/django_mongodb_backend/functions.py index 7997ea250..a2c18a986 100644 --- a/django_mongodb_backend/functions.py +++ b/django_mongodb_backend/functions.py @@ -65,7 +65,7 @@ } -def cast(self, compiler, connection, as_path=False): +def cast(self, compiler, connection): output_type = connection.data_types[self.output_field.get_internal_type()] lhs_mql = process_lhs(self, compiler, connection, as_path=False)[0] if max_length := self.output_field.max_length: @@ -78,27 +78,21 @@ def cast(self, compiler, connection, as_path=False): if decimal_places := getattr(self.output_field, "decimal_places", None): lhs_mql = {"$trunc": [lhs_mql, decimal_places]} - if as_path: - return {"$expr": lhs_mql} return lhs_mql -def concat(self, compiler, connection, as_path=False): - return self.get_source_expressions()[0].as_mql(compiler, connection, as_path=as_path) +def concat(self, compiler, connection): + return self.get_source_expressions()[0].as_mql(compiler, connection, as_path=False) -def concat_pair(self, compiler, connection, as_path=False): +def concat_pair(self, compiler, connection): # null on either side results in null for expression, wrap with coalesce. coalesced = self.coalesce() - if as_path: - return {"$expr": super(ConcatPair, coalesced).as_mql(compiler, connection, as_path=False)} - return super(ConcatPair, coalesced).as_mql(compiler, connection, as_path=False) + return super(ConcatPair, coalesced).as_mql_expr(compiler, connection) -def cot(self, compiler, connection, as_path=False): +def cot(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection, as_path=False) - if as_path: - return {"$expr": {"$divide": [1, {"$tan": lhs_mql}]}} return {"$divide": [1, {"$tan": lhs_mql}]} @@ -109,10 +103,7 @@ def extract(self, compiler, connection, as_path=False): raise NotSupportedError(f"{self.__class__.__name__} is not supported.") if timezone := self.get_tzname(): lhs_mql = {"date": lhs_mql, "timezone": timezone} - expr = {f"${operator}": lhs_mql} - if as_path: - return {"$expr": expr} - return expr + return {f"${operator}": lhs_mql} def func(self, compiler, connection, as_path=False): @@ -125,73 +116,69 @@ def func(self, compiler, connection, as_path=False): return {f"${operator}": lhs_mql} -def left(self, compiler, connection, as_path=False): - return self.get_substr().as_mql(compiler, connection, as_path=as_path) +def func_path(self, compiler, connection): # noqa: ARG001 + raise NotSupportedError(f"{self} May need an as_mql_path() method.") + +def func_expr(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_path=False) + if self.function is None: + raise NotSupportedError(f"{self} may need an as_mql() method.") + operator = MONGO_OPERATORS.get(self.__class__, self.function.lower()) + return {f"${operator}": lhs_mql} -def length(self, compiler, connection, as_path=False): + +def left(self, compiler, connection): + return self.get_substr().as_mql(compiler, connection, as_path=False) + + +def length(self, compiler, connection): # Check for null first since $strLenCP only accepts strings. lhs_mql = process_lhs(self, compiler, connection, as_path=False) - expr = {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$strLenCP": lhs_mql}}} - if as_path: - return {"$expr": expr} - return expr + return {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$strLenCP": lhs_mql}}} -def log(self, compiler, connection, as_path=False): +def log(self, compiler, connection): # This function is usually log(base, num) but on MongoDB it's log(num, base). clone = self.copy() clone.set_source_expressions(self.get_source_expressions()[::-1]) - return func(clone, compiler, connection, as_path=as_path) + return func(clone, compiler, connection, as_path=False) -def now(self, compiler, connection, as_path=False): # noqa: ARG001 +def now(self, compiler, connection): # noqa: ARG001 return "$$NOW" -def null_if(self, compiler, connection, as_path=False): +def null_if(self, compiler, connection): """Return None if expr1==expr2 else expr1.""" expr1, expr2 = ( expr.as_mql(compiler, connection, as_path=False) for expr in self.get_source_expressions() ) - expr = {"$cond": {"if": {"$eq": [expr1, expr2]}, "then": None, "else": expr1}} - if as_path: - return {"$expr": expr} - return expr + return {"$cond": {"if": {"$eq": [expr1, expr2]}, "then": None, "else": expr1}} def preserve_null(operator): # If the argument is null, the function should return null, not # $toLower/Upper's behavior of returning an empty string. - def wrapped(self, compiler, connection, as_path=False): - if as_path and self.is_constant_value(self.lhs): - if self.lhs is None: - return None - lhs_mql = process_lhs(self, compiler, connection, as_path=True) - return lhs_mql.upper() + def wrapped(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection, as_path=False) - inner_expression = { + return { "$cond": { "if": connection.mongo_operators_expr["isnull"](lhs_mql, True), "then": None, "else": {f"${operator}": lhs_mql}, } } - # we need to wrap this, because it will be handled in a no expression tree. - # needed in MongoDB 6. - if as_path: - return {"$expr": inner_expression} - return inner_expression return wrapped -def replace(self, compiler, connection, as_path=False): - expression, text, replacement = process_lhs(self, compiler, connection, as_path=as_path) +def replace(self, compiler, connection): + expression, text, replacement = process_lhs(self, compiler, connection, as_path=False) return {"$replaceAll": {"input": expression, "find": text, "replacement": replacement}} -def round_(self, compiler, connection, as_path=False): # noqa: ARG001 +def round_(self, compiler, connection): # Round needs its own function because it's a special case that inherits # from Transform but has two arguments. return { @@ -202,13 +189,13 @@ def round_(self, compiler, connection, as_path=False): # noqa: ARG001 } -def str_index(self, compiler, connection, as_path=False): # noqa: ARG001 - lhs = process_lhs(self, compiler, connection) +def str_index(self, compiler, connection): + lhs = process_lhs(self, compiler, connection, as_path=False) # StrIndex should be 0-indexed (not found) but it's -1-indexed on MongoDB. return {"$add": [{"$indexOfCP": lhs}, 1]} -def substr(self, compiler, connection, as_path=False): # noqa: ARG001 +def substr(self, compiler, connection): lhs = process_lhs(self, compiler, connection) # The starting index is zero-indexed on MongoDB rather than one-indexed. lhs[1] = {"$add": [lhs[1], -1]} @@ -220,14 +207,14 @@ def substr(self, compiler, connection, as_path=False): # noqa: ARG001 def trim(operator): - def wrapped(self, compiler, connection, as_path=False): # noqa: ARG001 + def wrapped(self, compiler, connection): lhs = process_lhs(self, compiler, connection) return {f"${operator}": {"input": lhs}} return wrapped -def trunc(self, compiler, connection, as_path=False): # noqa: ARG001 +def trunc(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection) lhs_mql = {"date": lhs_mql, "unit": self.kind, "startOfWeek": "mon"} if timezone := self.get_tzname(): @@ -265,9 +252,9 @@ def trunc_convert_value(self, value, expression, connection): return _trunc_convert_value(self, value, expression, connection) -def trunc_date(self, compiler, connection, **extra): # noqa: ARG001 +def trunc_date(self, compiler, connection): # Cast to date rather than truncate to date. - lhs_mql = process_lhs(self, compiler, connection) + lhs_mql = process_lhs(self, compiler, connection, as_path=False) tzname = self.get_tzname() if tzname and tzname != "UTC": raise NotSupportedError(f"TruncDate with tzinfo ({tzname}) isn't supported on MongoDB.") @@ -286,7 +273,7 @@ def trunc_date(self, compiler, connection, **extra): # noqa: ARG001 } -def trunc_time(self, compiler, connection, as_path=False): # noqa: ARG001 +def trunc_time(self, compiler, connection): tzname = self.get_tzname() if tzname and tzname != "UTC": raise NotSupportedError(f"TruncTime with tzinfo ({tzname}) isn't supported on MongoDB.") @@ -306,29 +293,35 @@ def trunc_time(self, compiler, connection, as_path=False): # noqa: ARG001 } +def is_simple_expression(self): # noqa: ARG001 + return False + + def register_functions(): - Cast.as_mql = cast - Concat.as_mql = concat - ConcatPair.as_mql = concat_pair - Cot.as_mql = cot - Extract.as_mql = extract - Func.as_mql = func - JSONArray.as_mql = process_lhs - Left.as_mql = left - Length.as_mql = length - Log.as_mql = log - Lower.as_mql = preserve_null("toLower") - LTrim.as_mql = trim("ltrim") - Now.as_mql = now - NullIf.as_mql = null_if - Replace.as_mql = replace - Round.as_mql = round_ - RTrim.as_mql = trim("rtrim") - StrIndex.as_mql = str_index - Substr.as_mql = substr - Trim.as_mql = trim("trim") - TruncBase.as_mql = trunc + Cast.as_mql_expr = cast + Concat.as_mql_expr = concat + ConcatPair.as_mql_expr = concat_pair + Cot.as_mql_expr = cot + Extract.as_mql_expr = extract + Func.as_mql_path = func_path + Func.as_mql_expr = func_expr + JSONArray.as_mql_expr = process_lhs + Left.as_mql_expr = left + Length.as_mql_expr = length + Log.as_mql_expr = log + Lower.as_mql_expr = preserve_null("toLower") + LTrim.as_mql_expr = trim("ltrim") + Now.as_mql_expr = now + NullIf.as_mql_expr = null_if + Replace.as_mql_expr = replace + Round.as_mql_expr = round_ + RTrim.as_mql_expr = trim("rtrim") + StrIndex.as_mql_expr = str_index + Substr.as_mql_expr = substr + Trim.as_mql_expr = trim("trim") + TruncBase.as_mql_expr = trunc TruncBase.convert_value = trunc_convert_value - TruncDate.as_mql = trunc_date - TruncTime.as_mql = trunc_time - Upper.as_mql = preserve_null("toUpper") + TruncDate.as_mql_expr = trunc_date + TruncTime.as_mql_expr = trunc_time + Upper.as_mql_expr = preserve_null("toUpper") + Func.is_simple_expression = is_simple_expression diff --git a/django_mongodb_backend/lookups.py b/django_mongodb_backend/lookups.py index 85c25e229..51158bc63 100644 --- a/django_mongodb_backend/lookups.py +++ b/django_mongodb_backend/lookups.py @@ -4,23 +4,23 @@ BuiltinLookup, FieldGetDbPrepValueIterableMixin, IsNull, + Lookup, PatternLookup, UUIDTextMixin, ) -from .query_utils import process_lhs, process_rhs +from .query_utils import is_simple_expression, process_lhs, process_rhs -def builtin_lookup(self, compiler, connection, as_path=False): - if as_path and self.is_simple_expression(): - lhs_mql = process_lhs(self, compiler, connection, as_path=True) - value = process_rhs(self, compiler, connection, as_path=True) - return connection.mongo_operators_match[self.lookup_name](lhs_mql, value) +def builtin_lookup_path(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + value = process_rhs(self, compiler, connection, as_path=True) + return connection.mongo_operators_match[self.lookup_name](lhs_mql, value) - value = process_rhs(self, compiler, connection) + +def builtin_lookup_expr(self, compiler, connection): + value = process_rhs(self, compiler, connection, as_path=False) lhs_mql = process_lhs(self, compiler, connection, as_path=False) - if as_path: - return {"$expr": connection.mongo_operators_expr[self.lookup_name](lhs_mql, value)} return connection.mongo_operators_expr[self.lookup_name](lhs_mql, value) @@ -40,14 +40,17 @@ def field_resolve_expression_parameter(self, compiler, connection, sql, param): return sql, sql_params -def in_(self, compiler, connection, **extra): - db_rhs = getattr(self.rhs, "_db", None) - if db_rhs is not None and db_rhs != connection.alias: - raise ValueError( - "Subqueries aren't allowed across different databases. Force " - "the inner query to be evaluated using `list(inner_query)`." - ) - return builtin_lookup(self, compiler, connection, **extra) +def wrap_in(function): + def inner(self, compiler, connection): + db_rhs = getattr(self.rhs, "_db", None) + if db_rhs is not None and db_rhs != connection.alias: + raise ValueError( + "Subqueries aren't allowed across different databases. Force " + "the inner query to be evaluated using `list(inner_query)`." + ) + return function(self, compiler, connection) + + return inner def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr): # noqa: ARG001 @@ -82,15 +85,17 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr) ] -def is_null(self, compiler, connection, as_path=False): +def is_null_path(self, compiler, connection): + if not isinstance(self.rhs, bool): + raise ValueError("The QuerySet value for an isnull lookup must be True or False.") + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + return connection.mongo_operators_match["isnull"](lhs_mql, self.rhs) + + +def is_null_expr(self, compiler, connection): if not isinstance(self.rhs, bool): raise ValueError("The QuerySet value for an isnull lookup must be True or False.") - if as_path and self.is_simple_expression(): - lhs_mql = process_lhs(self, compiler, connection, as_path=as_path) - return connection.mongo_operators_match["isnull"](lhs_mql, self.rhs) lhs_mql = process_lhs(self, compiler, connection, as_path=False) - if as_path: - return {"$expr": connection.mongo_operators_expr["isnull"](lhs_mql, self.rhs)} return connection.mongo_operators_expr["isnull"](lhs_mql, self.rhs) @@ -134,12 +139,17 @@ def uuid_text_mixin(self, compiler, connection): # noqa: ARG001 def register_lookups(): - BuiltinLookup.as_mql = builtin_lookup + Lookup.is_simple_expression = is_simple_expression + BuiltinLookup.as_mql_path = builtin_lookup_path + BuiltinLookup.as_mql_expr = builtin_lookup_expr FieldGetDbPrepValueIterableMixin.resolve_expression_parameter = ( field_resolve_expression_parameter ) - In.as_mql = RelatedIn.as_mql = in_ + In.as_mql_path = RelatedIn.as_mql_path = wrap_in(builtin_lookup_path) + In.as_mql_expr = RelatedIn.as_mql_expr = wrap_in(builtin_lookup_expr) In.get_subquery_wrapping_pipeline = get_subquery_wrapping_pipeline - IsNull.as_mql = is_null + IsNull.as_mql_path = is_null_path + IsNull.as_mql_expr = is_null_expr PatternLookup.prep_lookup_value_mongo = pattern_lookup_prep_lookup_value + # Patching the main method, it is not supported yet. UUIDTextMixin.as_mql = uuid_text_mixin diff --git a/django_mongodb_backend/query.py b/django_mongodb_backend/query.py index a013c3375..cb602784d 100644 --- a/django_mongodb_backend/query.py +++ b/django_mongodb_backend/query.py @@ -167,6 +167,7 @@ def _get_reroot_replacements(expression): target.remote_field = col.target.remote_field if parent_pos is not None: column_target = Col(None, target) + column_target.is_simple_column = False target_col = f"${parent_template}{parent_pos}" column_target.target.db_column = target_col column_target.target.set_attributes_from_name(target_col) diff --git a/django_mongodb_backend/query_utils.py b/django_mongodb_backend/query_utils.py index d5943f61c..68a0704d3 100644 --- a/django_mongodb_backend/query_utils.py +++ b/django_mongodb_backend/query_utils.py @@ -1,13 +1,14 @@ from django.core.exceptions import FullResultSet from django.db.models.aggregates import Aggregate -from django.db.models.expressions import Value +from django.db.models.expressions import CombinedExpression, Value +from django.db.models.sql.query import Query def is_direct_value(node): return not hasattr(node, "as_sql") -def process_lhs(node, compiler, connection, **extra): +def process_lhs(node, compiler, connection, as_path=False): if not hasattr(node, "lhs"): # node is a Func or Expression, possibly with multiple source expressions. result = [] @@ -15,16 +16,16 @@ def process_lhs(node, compiler, connection, **extra): if expr is None: continue try: - result.append(expr.as_mql(compiler, connection, **extra)) + result.append(expr.as_mql(compiler, connection, as_path=as_path)) except FullResultSet: - result.append(Value(True).as_mql(compiler, connection, **extra)) + result.append(Value(True).as_mql(compiler, connection, as_path=as_path)) if isinstance(node, Aggregate): return result[0] return result # node is a Transform with just one source expression, aliased as "lhs". if is_direct_value(node.lhs): return node - return node.lhs.as_mql(compiler, connection, **extra) + return node.lhs.as_mql(compiler, connection, as_path=as_path) def process_rhs(node, compiler, connection, as_path=False): @@ -60,3 +61,37 @@ def regex_match(field, regex, insensitive=False): options = "i" if insensitive else "" # return {"$regexMatch": {"input": field, "regex": regex, "options": options}} return {field: {"$regex": regex, "$options": options}} + + +def is_constant_value(value): + if isinstance(value, CombinedExpression): + # Temporary: treat all CombinedExpressions as non-constant until + # constant cases are handled + return False + if isinstance(value, list): + return all(map(is_constant_value, value)) + if is_direct_value(value): + return True + if hasattr(value, "get_source_expressions"): + # Temporary: similar limitation as above, sub-expressions should be + # resolved in the future + simple_sub_expressions = all(map(is_constant_value, value.get_source_expressions())) + else: + simple_sub_expressions = True + return ( + simple_sub_expressions + and isinstance(value, Value) + and not ( + isinstance(value, Query) + or value.contains_aggregate + or value.contains_over_clause + or value.contains_column_references + or value.contains_subquery + ) + ) + + +def is_simple_expression(self): + simple_column = getattr(self.lhs, "is_simple_column", False) + constant_value = is_constant_value(self.rhs) + return simple_column and constant_value diff --git a/tests/model_fields_/test_embedded_model_array.py b/tests/model_fields_/test_embedded_model_array.py index 5ae396e2a..de8713d08 100644 --- a/tests/model_fields_/test_embedded_model_array.py +++ b/tests/model_fields_/test_embedded_model_array.py @@ -207,7 +207,9 @@ def test_filter_unsupported_lookups_in_json(self): kwargs = {f"main_section__artifacts__metadata__origin__{lookup}": ["Pergamon", "Egypt"]} with CaptureQueriesContext(connection) as captured_queries: self.assertCountEqual(Exhibit.objects.filter(**kwargs), []) - self.assertIn(f"'field': '{lookup}'", captured_queries[0]["sql"]) + self.assertIn( + f"'main_section.artifacts.metadata.origin.{lookup}':", captured_queries[0]["sql"] + ) def test_len(self): self.assertCountEqual(Exhibit.objects.filter(sections__len=10), []) From 0e93dbe84ea74711360684e6dd61cd611cf67b6d Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Fri, 26 Sep 2025 00:16:30 -0300 Subject: [PATCH 03/23] Edits. --- django_mongodb_backend/expressions/builtins.py | 2 +- django_mongodb_backend/fields/embedded_model_array.py | 4 ++++ tests/model_fields_/test_embedded_model.py | 9 +++++++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index 4255e1fe3..724dc1a0e 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -172,7 +172,7 @@ def ref(self, compiler, connection, as_path=False): # noqa: ARG001 @property def ref_is_simple_column(self): - return isinstance(self.source, Col) and self.source.alias is not None + return self.source.is_simple_column def star(self, compiler, connection, as_path=False): # noqa: ARG001 diff --git a/django_mongodb_backend/fields/embedded_model_array.py b/django_mongodb_backend/fields/embedded_model_array.py index 958359642..2f94e2a4d 100644 --- a/django_mongodb_backend/fields/embedded_model_array.py +++ b/django_mongodb_backend/fields/embedded_model_array.py @@ -8,6 +8,7 @@ from django.utils.functional import cached_property from .. import forms +from ..lookups import builtin_lookup_path from ..query_utils import process_lhs, process_rhs from . import EmbeddedModelField from .array import ArrayField, ArrayLenTransform @@ -142,6 +143,9 @@ def as_mql_expr(self, compiler, connection): ) return {"$anyElementTrue": lhs_mql} + def as_mql_path(self, compiler, connection): + return builtin_lookup_path(self, compiler, connection) + @_EmbeddedModelArrayOutputField.register_lookup class EmbeddedModelArrayFieldIn(EmbeddedModelArrayFieldBuiltinLookup, lookups.In): diff --git a/tests/model_fields_/test_embedded_model.py b/tests/model_fields_/test_embedded_model.py index 1a219613f..501620382 100644 --- a/tests/model_fields_/test_embedded_model.py +++ b/tests/model_fields_/test_embedded_model.py @@ -244,6 +244,15 @@ def test_nested(self): ) self.assertCountEqual(Book.objects.filter(author__address__city="NYC"), [obj]) + def test_annotate(self): + obj = Book.objects.create( + author=Author(name="Shakespeare", age=55, address=Address(city="NYC", state="NY")) + ) + book_from_ny = ( + Book.objects.annotate(city=F("author__address__city")).filter(city="NYC").first() + ) + self.assertCountEqual(book_from_ny.city, obj.author.address.city) + class ArrayFieldTests(TestCase): @classmethod From a414d935b4960b560d973d29fc1df4462e485186 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Fri, 26 Sep 2025 00:26:51 -0300 Subject: [PATCH 04/23] Add example generated query test. --- tests/model_fields_/test_embedded_model_array.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/model_fields_/test_embedded_model_array.py b/tests/model_fields_/test_embedded_model_array.py index de8713d08..f976e9b6d 100644 --- a/tests/model_fields_/test_embedded_model_array.py +++ b/tests/model_fields_/test_embedded_model_array.py @@ -317,6 +317,13 @@ def test_foreign_field_exact(self): qs = Tour.objects.filter(exhibit__sections__number=1) self.assertCountEqual(qs, [self.egypt_tour, self.wonders_tour]) + def test_foreign_field_exact_expr(self): + """Querying from a foreign key to an EmbeddedModelArrayField.""" + with self.assertNumQueries(1) as ctx: + qs = Tour.objects.filter(exhibit__sections__number=Value(2) - Value(1)) + self.assertCountEqual(qs, [self.egypt_tour, self.wonders_tour]) + self.assertIn("anyElementTrue", ctx.captured_queries[0]["sql"]) + def test_foreign_field_with_slice(self): qs = Tour.objects.filter(exhibit__sections__0_2__number__in=[1, 2]) self.assertCountEqual(qs, [self.wonders_tour, self.egypt_tour]) From e2020546767b4fb57a58a758d79452ebf93e66ce Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Fri, 26 Sep 2025 01:26:03 -0300 Subject: [PATCH 05/23] Refactor. --- django_mongodb_backend/compiler.py | 4 +--- django_mongodb_backend/lookups.py | 8 +++++++- django_mongodb_backend/query.py | 4 +--- django_mongodb_backend/query_utils.py | 6 ------ 4 files changed, 9 insertions(+), 13 deletions(-) diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 2eb6dfd99..66d715258 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -643,9 +643,7 @@ def get_combinator_queries(self): for alias, expr in self.columns: # Unfold foreign fields. if isinstance(expr, Col) and expr.alias != self.collection_name: - ids[expr.alias][expr.target.column] = expr.as_mql( - self, self.connection, as_path=False - ) + ids[expr.alias][expr.target.column] = expr.as_mql(self, self.connection) else: ids[alias] = f"${alias}" # Convert defaultdict to dict so it doesn't appear as diff --git a/django_mongodb_backend/lookups.py b/django_mongodb_backend/lookups.py index 51158bc63..88f611bf5 100644 --- a/django_mongodb_backend/lookups.py +++ b/django_mongodb_backend/lookups.py @@ -9,7 +9,7 @@ UUIDTextMixin, ) -from .query_utils import is_simple_expression, process_lhs, process_rhs +from .query_utils import is_constant_value, process_lhs, process_rhs def builtin_lookup_path(self, compiler, connection): @@ -138,6 +138,12 @@ def uuid_text_mixin(self, compiler, connection): # noqa: ARG001 raise NotSupportedError("Pattern lookups on UUIDField are not supported.") +def is_simple_expression(self): + simple_column = getattr(self.lhs, "is_simple_column", False) + constant_value = is_constant_value(self.rhs) + return simple_column and constant_value + + def register_lookups(): Lookup.is_simple_expression = is_simple_expression BuiltinLookup.as_mql_path = builtin_lookup_path diff --git a/django_mongodb_backend/query.py b/django_mongodb_backend/query.py index cb602784d..ef3312cf1 100644 --- a/django_mongodb_backend/query.py +++ b/django_mongodb_backend/query.py @@ -88,7 +88,6 @@ def get_pipeline(self): pipeline.extend(query.get_pipeline()) if self.match_mql: pipeline.append({"$match": self.match_mql}) - # pipeline.append({"$match": {"$expr": self.match_mql}}) if self.aggregation_pipeline: pipeline.extend(self.aggregation_pipeline) if self.project_fields: @@ -165,14 +164,13 @@ def _get_reroot_replacements(expression): for col, parent_pos in columns: target = col.target.clone() target.remote_field = col.target.remote_field + column_target = Col(compiler.collection_name, target) if parent_pos is not None: - column_target = Col(None, target) column_target.is_simple_column = False target_col = f"${parent_template}{parent_pos}" column_target.target.db_column = target_col column_target.target.set_attributes_from_name(target_col) else: - column_target = Col(compiler.collection_name, target) column_target.target = col.target replacements[col] = column_target return replacements diff --git a/django_mongodb_backend/query_utils.py b/django_mongodb_backend/query_utils.py index 68a0704d3..cb1a5711a 100644 --- a/django_mongodb_backend/query_utils.py +++ b/django_mongodb_backend/query_utils.py @@ -89,9 +89,3 @@ def is_constant_value(value): or value.contains_subquery ) ) - - -def is_simple_expression(self): - simple_column = getattr(self.lhs, "is_simple_column", False) - constant_value = is_constant_value(self.rhs) - return simple_column and constant_value From ec40c7dc6d5cbd994ccbcf869c1c5478e9fba23d Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Fri, 26 Sep 2025 21:47:24 -0300 Subject: [PATCH 06/23] Update django_mongodb_backend/functions.py Co-authored-by: Tim Graham --- django_mongodb_backend/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/django_mongodb_backend/functions.py b/django_mongodb_backend/functions.py index a2c18a986..17354bfae 100644 --- a/django_mongodb_backend/functions.py +++ b/django_mongodb_backend/functions.py @@ -117,7 +117,7 @@ def func(self, compiler, connection, as_path=False): def func_path(self, compiler, connection): # noqa: ARG001 - raise NotSupportedError(f"{self} May need an as_mql_path() method.") + raise NotSupportedError(f"{self} may need an as_mql_path() method.") def func_expr(self, compiler, connection): From 2aabf8d5a30ec4daf7a1875676105ebad9c61494 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Fri, 26 Sep 2025 23:04:40 -0300 Subject: [PATCH 07/23] Edits. --- django_mongodb_backend/aggregates.py | 13 +- django_mongodb_backend/base.py | 5 +- .../expressions/builtins.py | 26 +-- django_mongodb_backend/fields/array.py | 32 ++-- .../fields/embedded_model.py | 20 +- .../fields/embedded_model_array.py | 17 +- django_mongodb_backend/fields/json.py | 77 ++++---- django_mongodb_backend/functions.py | 3 +- django_mongodb_backend/lookups.py | 8 +- .../query_conversion/__init__.py | 0 .../query_conversion/expression_converters.py | 172 ------------------ .../query_conversion/query_optimizer.py | 73 -------- django_mongodb_backend/query_utils.py | 1 - 13 files changed, 101 insertions(+), 346 deletions(-) delete mode 100644 django_mongodb_backend/query_conversion/__init__.py delete mode 100644 django_mongodb_backend/query_conversion/expression_converters.py delete mode 100644 django_mongodb_backend/query_conversion/query_optimizer.py diff --git a/django_mongodb_backend/aggregates.py b/django_mongodb_backend/aggregates.py index 67ccbb499..2d1dd6afe 100644 --- a/django_mongodb_backend/aggregates.py +++ b/django_mongodb_backend/aggregates.py @@ -8,7 +8,14 @@ MONGO_AGGREGATIONS = {Count: "sum"} -def aggregate(self, compiler, connection, operator=None, resolve_inner_expression=False): +def aggregate( + self, + compiler, + connection, + operator=None, + resolve_inner_expression=False, + **extra_context, # noqa: ARG001 +): if self.filter: node = self.copy() node.filter = None @@ -24,7 +31,7 @@ def aggregate(self, compiler, connection, operator=None, resolve_inner_expressio return {f"${operator}": lhs_mql} -def count(self, compiler, connection, resolve_inner_expression=False): +def count(self, compiler, connection, resolve_inner_expression=False, **extra_context): # noqa: ARG001 """ When resolve_inner_expression=True, return the MQL that resolves as a value. This is used to count different elements, so the inner values are @@ -57,7 +64,7 @@ def count(self, compiler, connection, resolve_inner_expression=False): return {"$add": [{"$size": lhs_mql}, exits_null]} -def stddev_variance(self, compiler, connection): +def stddev_variance(self, compiler, connection, **extra_context): # noqa: ARG001 if self.function.endswith("_SAMP"): operator = "stdDevSamp" elif self.function.endswith("_POP"): diff --git a/django_mongodb_backend/base.py b/django_mongodb_backend/base.py index f882e1eec..58a843cf6 100644 --- a/django_mongodb_backend/base.py +++ b/django_mongodb_backend/base.py @@ -113,7 +113,7 @@ def _isnull_operator_match(a, b): return {"$or": [{a: {"$exists": False}}, {a: None}]} return {"$and": [{a: {"$exists": True}}, {a: {"$ne": None}}]} - mongo_operators_expr = { + mongo_expr_operators = { "exact": lambda a, b: {"$eq": [a, b]}, "gt": lambda a, b: {"$gt": [a, b]}, "gte": lambda a, b: {"$gte": [a, b]}, @@ -153,7 +153,8 @@ def range_match(a, b): return {"$literal": True} return {"$and": conditions} - mongo_operators_match = { + # match, path, find? don't know which name use. + mongo_match_operators = { "exact": lambda a, b: {a: b}, "gt": lambda a, b: {a: {"$gt": b}}, "gte": lambda a, b: {a: {"$gte": b}}, diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index 724dc1a0e..1556dc9c6 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -29,6 +29,18 @@ from ..query_utils import process_lhs +def base_expression(self, compiler, connection, as_path=False, **extra): + if ( + as_path + and hasattr(self, "as_mql_path") + and getattr(self, "is_simple_expression", lambda: False)() + ): + return self.as_mql_path(compiler, connection, **extra) + + expr = self.as_mql_expr(compiler, connection, **extra) + return {"$expr": expr} if as_path else expr + + def case(self, compiler, connection): case_parts = [] for case in self.cases: @@ -190,7 +202,7 @@ def exists(self, compiler, connection, get_wrapping_pipeline=None): lhs_mql = subquery(self, compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline) except EmptyResultSet: return Value(False).as_mql(compiler, connection) - return connection.mongo_operators_expr["isnull"](lhs_mql, False) + return connection.mongo_expr_operators["isnull"](lhs_mql, False) def when(self, compiler, connection, as_path=False): @@ -221,18 +233,6 @@ def value(self, compiler, connection, as_path=False): # noqa: ARG001 return value -def base_expression(self, compiler, connection, as_path=False, **extra): - if ( - as_path - and hasattr(self, "as_mql_path") - and getattr(self, "is_simple_expression", lambda: False)() - ): - return self.as_mql_path(compiler, connection, **extra) - - expr = self.as_mql_expr(compiler, connection, **extra) - return {"$expr": expr} if as_path else expr - - def register_expressions(): BaseExpression.as_mql = base_expression BaseExpression.is_simple_column = False diff --git a/django_mongodb_backend/fields/array.py b/django_mongodb_backend/fields/array.py index 4d2a42cb5..3ffb85f75 100644 --- a/django_mongodb_backend/fields/array.py +++ b/django_mongodb_backend/fields/array.py @@ -254,13 +254,6 @@ def __init__(self, lhs, rhs): class ArrayContains(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup): lookup_name = "contains" - def as_mql_path(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection, as_path=True) - value = process_rhs(self, compiler, connection, as_path=True) - if value is None: - return False - return {lhs_mql: {"$all": value}} - def as_mql_expr(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection, as_path=False) value = process_rhs(self, compiler, connection, as_path=False) @@ -272,6 +265,13 @@ def as_mql_expr(self, compiler, connection): ] } + def as_mql_path(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + value = process_rhs(self, compiler, connection, as_path=True) + if value is None: + return False + return {lhs_mql: {"$all": value}} + @ArrayField.register_lookup class ArrayContainedBy(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup): @@ -333,11 +333,6 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr) }, ] - def as_mql_path(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection, as_path=True) - value = process_rhs(self, compiler, connection, as_path=True) - return {lhs_mql: {"$in": value}} - def as_mql_expr(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection, as_path=False) value = process_rhs(self, compiler, connection, as_path=False) @@ -348,6 +343,11 @@ def as_mql_expr(self, compiler, connection): ] } + def as_mql_path(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + value = process_rhs(self, compiler, connection, as_path=True) + return {lhs_mql: {"$in": value}} + @ArrayField.register_lookup class ArrayLenTransform(Transform): @@ -388,14 +388,14 @@ def is_simple_expression(self): def is_simple_column(self): return self.lhs.is_simple_column - def as_mql_path(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection, as_path=True) - return f"{lhs_mql}.{self.index}" - def as_mql_expr(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection, as_path=False) return {"$arrayElemAt": [lhs_mql, self.index]} + def as_mql_path(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + return f"{lhs_mql}.{self.index}" + @property def output_field(self): return self.base_field diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py index a6f06e096..e66d08fb2 100644 --- a/django_mongodb_backend/fields/embedded_model.py +++ b/django_mongodb_backend/fields/embedded_model.py @@ -198,16 +198,6 @@ def get_transform(self, name): f"{suggestion}" ) - def as_mql_path(self, compiler, connection): - previous = self - key_transforms = [] - while isinstance(previous, EmbeddedModelTransform): - key_transforms.insert(0, previous.key_name) - previous = previous.lhs - mql = previous.as_mql(compiler, connection, as_path=True) - mql_path = ".".join(key_transforms) - return f"{mql}.{mql_path}" - def as_mql_expr(self, compiler, connection): previous = self columns = [] @@ -219,6 +209,16 @@ def as_mql_expr(self, compiler, connection): mql = {"$getField": {"input": mql, "field": column}} return mql + def as_mql_path(self, compiler, connection): + previous = self + key_transforms = [] + while isinstance(previous, EmbeddedModelTransform): + key_transforms.insert(0, previous.key_name) + previous = previous.lhs + mql = previous.as_mql(compiler, connection, as_path=True) + mql_path = ".".join(key_transforms) + return f"{mql}.{mql_path}" + @property def output_field(self): return self._field diff --git a/django_mongodb_backend/fields/embedded_model_array.py b/django_mongodb_backend/fields/embedded_model_array.py index 2f94e2a4d..228f35869 100644 --- a/django_mongodb_backend/fields/embedded_model_array.py +++ b/django_mongodb_backend/fields/embedded_model_array.py @@ -138,7 +138,7 @@ def as_mql_expr(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection) inner_lhs_mql = lhs_mql["$ifNull"][0]["$map"]["in"] values = process_rhs(self, compiler, connection) - lhs_mql["$ifNull"][0]["$map"]["in"] = connection.mongo_operators_expr[self.lookup_name]( + lhs_mql["$ifNull"][0]["$map"]["in"] = connection.mongo_expr_operators[self.lookup_name]( inner_lhs_mql, values ) return {"$anyElementTrue": lhs_mql} @@ -231,7 +231,6 @@ class EmbeddedModelArrayFieldLessThanOrEqual( class EmbeddedModelArrayFieldTransform(Transform): field_class_name = "EmbeddedModelArrayField" - PREFIX_ITERABLE = "item" def __init__(self, field, *args, **kwargs): super().__init__(*args, **kwargs) @@ -292,13 +291,6 @@ def get_transform(self, name): f"{suggestion}" ) - def as_mql_path(self, compiler, connection): - inner_lhs_mql = self._lhs.as_mql(compiler, connection, as_path=True).removeprefix( - f"${self.PREFIX_ITERABLE}." - ) - lhs_mql = process_lhs(self, compiler, connection, as_path=True) - return f"{lhs_mql}.{inner_lhs_mql}" - def as_mql_expr(self, compiler, connection): inner_lhs_mql = self._lhs.as_mql(compiler, connection) lhs_mql = process_lhs(self, compiler, connection) @@ -307,7 +299,7 @@ def as_mql_expr(self, compiler, connection): { "$map": { "input": lhs_mql, - "as": self.PREFIX_ITERABLE, + "as": "item", "in": inner_lhs_mql, } }, @@ -315,6 +307,11 @@ def as_mql_expr(self, compiler, connection): ] } + def as_mql_path(self, compiler, connection): + inner_lhs_mql = self._lhs.as_mql(compiler, connection, as_path=True).removeprefix("$item.") + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + return f"{lhs_mql}.{inner_lhs_mql}" + @property def output_field(self): return _EmbeddedModelArrayOutputField(self._lhs.output_field) diff --git a/django_mongodb_backend/fields/json.py b/django_mongodb_backend/fields/json.py index f9a80999c..1c8b4c5a3 100644 --- a/django_mongodb_backend/fields/json.py +++ b/django_mongodb_backend/fields/json.py @@ -92,8 +92,9 @@ def has_key_lookup(self, compiler, connection, as_path=False): keys = [] for path in paths: keys.append(_has_key_predicate(path, lhs, as_path=as_path)) - - return keys[0] if self.mongo_operator is None else {self.mongo_operator: keys} + if self.mongo_operator is None: + return keys[0] + return {self.mongo_operator: keys} _process_rhs = JSONExact.process_rhs @@ -119,14 +120,28 @@ def key_transform(self, compiler, connection, as_path=False): """ key_transforms = [self.key_name] previous = self.lhs + # Collect all key transforms in order. while isinstance(previous, KeyTransform): key_transforms.insert(0, previous.key_name) previous = previous.lhs - # Collect all key transforms in order. lhs_mql = previous.as_mql(compiler, connection, as_path=as_path) return build_json_mql_path(lhs_mql, key_transforms, as_path=as_path) +def key_transform_exact_expr(self, compiler, connection): + return builtin_lookup_expr(self, compiler, connection) + + +def key_transform_exact_path(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + return { + "$and": [ + builtin_lookup_path(self, compiler, connection), + _has_key_predicate(lhs_mql, None, as_path=True), + ] + } + + def key_transform_in(self, compiler, connection, as_path=False): """ Return MQL to check if a JSON path exists and that its values are in the @@ -143,17 +158,13 @@ def key_transform_in(self, compiler, connection, as_path=False): root_column = previous.as_mql(compiler, connection) value = process_rhs(self, compiler, connection) # Construct the expression to check if lhs_mql values are in rhs values. - expr = connection.mongo_operators_expr[self.lookup_name](lhs_mql, value) + expr = connection.mongo_expr_operators[self.lookup_name](lhs_mql, value) expr = {"$and": [_has_key_predicate(lhs_mql, root_column), expr]} if as_path: return {"$expr": expr} return expr -def key_transform_in_path(self, compiler, connection): - return builtin_lookup_path(self, compiler, connection) - - def key_transform_in_expr(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection) # Traverse to the root column. @@ -163,23 +174,12 @@ def key_transform_in_expr(self, compiler, connection): root_column = previous.as_mql(compiler, connection) value = process_rhs(self, compiler, connection) # Construct the expression to check if lhs_mql values are in rhs values. - expr = connection.mongo_operators_expr[self.lookup_name](lhs_mql, value) + expr = connection.mongo_expr_operators[self.lookup_name](lhs_mql, value) return {"$and": [_has_key_predicate(lhs_mql, root_column), expr]} -def key_transform_is_null_path(self, compiler, connection): - """ - Return MQL to check the nullability of a key. - - If `isnull=True`, the query matches objects where the key is missing or the - root column is null. If `isnull=False`, the query negates the result to - match objects where the key exists. - - Reference: https://code.djangoproject.com/ticket/32252 - """ - lhs_mql = process_lhs(self, compiler, connection, as_path=True) - rhs_mql = process_rhs(self, compiler, connection, as_path=True) - return _has_key_predicate(lhs_mql, None, negated=rhs_mql, as_path=True) +def key_transform_in_path(self, compiler, connection): + return builtin_lookup_path(self, compiler, connection) def key_transform_is_null_expr(self, compiler, connection): @@ -192,12 +192,19 @@ def key_transform_is_null_expr(self, compiler, connection): return _has_key_predicate(lhs_mql, root_column, negated=rhs_mql) -def key_transform_numeric_lookup_mixin_path(self, compiler, connection): +def key_transform_is_null_path(self, compiler, connection): """ - Return MQL to check if the field exists (i.e., is not "missing" or "null") - and that the field matches the given numeric lookup expression. + Return MQL to check the nullability of a key. + + If `isnull=True`, the query matches objects where the key is missing or the + root column is null. If `isnull=False`, the query negates the result to + match objects where the key exists. + + Reference: https://code.djangoproject.com/ticket/32252 """ - return builtin_lookup_path(self, compiler, connection) + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + rhs_mql = process_rhs(self, compiler, connection, as_path=True) + return _has_key_predicate(lhs_mql, None, negated=rhs_mql, as_path=True) def key_transform_numeric_lookup_mixin_expr(self, compiler, connection): @@ -212,18 +219,8 @@ def key_transform_numeric_lookup_mixin_expr(self, compiler, connection): return {"$and": [expr, not_missing_or_null]} -def key_transform_exact_path(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection, as_path=True) - return { - "$and": [ - builtin_lookup_path(self, compiler, connection), - _has_key_predicate(lhs_mql, None, as_path=True), - ] - } - - -def key_transform_exact_expr(self, compiler, connection): - return builtin_lookup_expr(self, compiler, connection) +def key_transform_numeric_lookup_mixin_path(self, compiler, connection): + return builtin_lookup_path(self, compiler, connection) def keytransform_is_simple_column(self): @@ -249,11 +246,11 @@ def register_json_field(): KeyTransform.is_simple_expression = keytransform_is_simple_column KeyTransform.as_mql_path = partialmethod(key_transform, as_path=True) KeyTransform.as_mql_expr = partialmethod(key_transform, as_path=False) + KeyTransformExact.as_mql_expr = key_transform_exact_expr + KeyTransformExact.as_mql_path = key_transform_exact_path KeyTransformIn.as_mql_path = key_transform_in_path KeyTransformIn.as_mql_expr = key_transform_in_expr KeyTransformIsNull.as_mql_path = key_transform_is_null_path KeyTransformIsNull.as_mql_expr = key_transform_is_null_expr KeyTransformNumericLookupMixin.as_mql_path = key_transform_numeric_lookup_mixin_path KeyTransformNumericLookupMixin.as_mql_expr = key_transform_numeric_lookup_mixin_expr - KeyTransformExact.as_mql_expr = key_transform_exact_expr - KeyTransformExact.as_mql_path = key_transform_exact_path diff --git a/django_mongodb_backend/functions.py b/django_mongodb_backend/functions.py index 17354bfae..56a027b7b 100644 --- a/django_mongodb_backend/functions.py +++ b/django_mongodb_backend/functions.py @@ -77,7 +77,6 @@ def cast(self, compiler, connection): lhs_mql = {"$convert": {"input": lhs_mql, "to": output_type}} if decimal_places := getattr(self.output_field, "decimal_places", None): lhs_mql = {"$trunc": [lhs_mql, decimal_places]} - return lhs_mql @@ -164,7 +163,7 @@ def wrapped(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection, as_path=False) return { "$cond": { - "if": connection.mongo_operators_expr["isnull"](lhs_mql, True), + "if": connection.mongo_expr_operators["isnull"](lhs_mql, True), "then": None, "else": {f"${operator}": lhs_mql}, } diff --git a/django_mongodb_backend/lookups.py b/django_mongodb_backend/lookups.py index 88f611bf5..83d0ea18f 100644 --- a/django_mongodb_backend/lookups.py +++ b/django_mongodb_backend/lookups.py @@ -15,13 +15,13 @@ def builtin_lookup_path(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection, as_path=True) value = process_rhs(self, compiler, connection, as_path=True) - return connection.mongo_operators_match[self.lookup_name](lhs_mql, value) + return connection.mongo_match_operators[self.lookup_name](lhs_mql, value) def builtin_lookup_expr(self, compiler, connection): value = process_rhs(self, compiler, connection, as_path=False) lhs_mql = process_lhs(self, compiler, connection, as_path=False) - return connection.mongo_operators_expr[self.lookup_name](lhs_mql, value) + return connection.mongo_expr_operators[self.lookup_name](lhs_mql, value) _field_resolve_expression_parameter = FieldGetDbPrepValueIterableMixin.resolve_expression_parameter @@ -89,14 +89,14 @@ def is_null_path(self, compiler, connection): if not isinstance(self.rhs, bool): raise ValueError("The QuerySet value for an isnull lookup must be True or False.") lhs_mql = process_lhs(self, compiler, connection, as_path=True) - return connection.mongo_operators_match["isnull"](lhs_mql, self.rhs) + return connection.mongo_match_operators["isnull"](lhs_mql, self.rhs) def is_null_expr(self, compiler, connection): if not isinstance(self.rhs, bool): raise ValueError("The QuerySet value for an isnull lookup must be True or False.") lhs_mql = process_lhs(self, compiler, connection, as_path=False) - return connection.mongo_operators_expr["isnull"](lhs_mql, self.rhs) + return connection.mongo_expr_operators["isnull"](lhs_mql, self.rhs) # from https://www.pcre.org/current/doc/html/pcre2pattern.html#SEC4 diff --git a/django_mongodb_backend/query_conversion/__init__.py b/django_mongodb_backend/query_conversion/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/django_mongodb_backend/query_conversion/expression_converters.py b/django_mongodb_backend/query_conversion/expression_converters.py deleted file mode 100644 index dc9b7df5b..000000000 --- a/django_mongodb_backend/query_conversion/expression_converters.py +++ /dev/null @@ -1,172 +0,0 @@ -class BaseConverter: - """Base class for $expr to $match converters.""" - - @classmethod - def convert(cls, expr): - raise NotImplementedError("Subclasses must implement this method.") - - @classmethod - def is_simple_value(cls, value): - """Is the value is a simple type (not a dict)?""" - if value is None: - return True - if isinstance(value, str) and value.startswith("$"): - return False - if isinstance(value, (list, tuple, set)): - return all(cls.is_simple_value(v) for v in value) - # TODO: Support `$getField` conversion. - return not isinstance(value, dict) - - -class BinaryConverter(BaseConverter): - """ - Base class for converting binary operations. - - For example: - "$expr": { - {"$gt": ["$price", 100]} - } - is converted to: - {"price": {"$gt": 100}} - """ - - operator: str - - @classmethod - def convert(cls, args): - if isinstance(args, list) and len(args) == 2: - field_expr, value = args - # Check if first argument is a simple field reference. - if ( - isinstance(field_expr, str) - and field_expr.startswith("$") - and cls.is_simple_value(value) - ): - field_name = field_expr[1:] # Remove the $ prefix. - if cls.operator == "$eq": - return {field_name: value} - return {field_name: {cls.operator: value}} - return None - - -class EqConverter(BinaryConverter): - """ - Convert $eq operation to a $match query. - - For example: - "$expr": { - {"$eq": ["$status", "active"]} - } - is converted to: - {"status": "active"} - """ - - operator = "$eq" - - -class GtConverter(BinaryConverter): - operator = "$gt" - - -class GteConverter(BinaryConverter): - operator = "$gte" - - -class LtConverter(BinaryConverter): - operator = "$lt" - - -class LteConverter(BinaryConverter): - operator = "$lte" - - -class InConverter(BaseConverter): - """ - Convert $in operation to a $match query. - - For example: - "$expr": { - {"$in": ["$category", ["electronics", "books"]]} - } - is converted to: - {"category": {"$in": ["electronics", "books"]}} - """ - - @classmethod - def convert(cls, in_args): - if isinstance(in_args, list) and len(in_args) == 2: - field_expr, values = in_args - # Check if first argument is a simple field reference. - if isinstance(field_expr, str) and field_expr.startswith("$"): - field_name = field_expr[1:] # Remove the $ prefix. - if isinstance(values, (list, tuple, set)) and all( - cls.is_simple_value(v) for v in values - ): - return {field_name: {"$in": values}} - return None - - -class LogicalConverter(BaseConverter): - """ - Base class for converting logical operations to a $match query. - - For example: - "$expr": { - "$or": [ - {"$eq": ["$status", "active"]}, - {"$in": ["$category", ["electronics", "books"]]}, - ] - } - is converted to: - "$or": [ - {"status": "active"}, - {"category": {"$in": ["electronics", "books"]}}, - ] - """ - - @classmethod - def convert(cls, combined_conditions): - if isinstance(combined_conditions, list): - optimized_conditions = [] - for condition in combined_conditions: - if isinstance(condition, dict) and len(condition) == 1: - if optimized_condition := convert_expression(condition): - optimized_conditions.append(optimized_condition) - else: - # Any failure should stop optimization. - return None - if optimized_conditions: - return {cls._logical_op: optimized_conditions} - return None - - -class OrConverter(LogicalConverter): - _logical_op = "$or" - - -class AndConverter(LogicalConverter): - _logical_op = "$and" - - -OPTIMIZABLE_OPS = { - "$eq": EqConverter, - "$in": InConverter, - "$and": AndConverter, - "$or": OrConverter, - "$gt": GtConverter, - "$gte": GteConverter, - "$lt": LtConverter, - "$lte": LteConverter, -} - - -def convert_expression(expr): - """ - Optimize MQL by converting an $expr condition to $match. Return the $match - MQL, or None if not optimizable. - """ - if isinstance(expr, dict) and len(expr) == 1: - op = next(iter(expr.keys())) - if op in OPTIMIZABLE_OPS: - return OPTIMIZABLE_OPS[op].convert(expr[op]) - return None diff --git a/django_mongodb_backend/query_conversion/query_optimizer.py b/django_mongodb_backend/query_conversion/query_optimizer.py deleted file mode 100644 index 368c89504..000000000 --- a/django_mongodb_backend/query_conversion/query_optimizer.py +++ /dev/null @@ -1,73 +0,0 @@ -from .expression_converters import convert_expression - - -def convert_expr_to_match(query): - """ - Optimize an MQL query by converting conditions into a list of $match - stages. - """ - if "$expr" not in query: - return [query] - if query["$expr"] == {}: - return [{"$match": {}}] - return _process_expression(query["$expr"]) - - -def _process_expression(expr): - """Process an expression and extract optimizable conditions.""" - match_conditions = [] - remaining_conditions = [] - if isinstance(expr, dict): - has_and = "$and" in expr - has_or = "$or" in expr - # Do a top-level check for $and or $or because these should inform. - # If they fail, they should failover to a remaining conditions list. - # There's probably a better way to do this. - if has_and: - and_match_conditions = _process_logical_conditions("$and", expr["$and"]) - match_conditions.extend(and_match_conditions) - if has_or: - or_match_conditions = _process_logical_conditions("$or", expr["$or"]) - match_conditions.extend(or_match_conditions) - if not has_and and not has_or: - # Process single condition. - if optimized := convert_expression(expr): - match_conditions.append({"$match": optimized}) - else: - remaining_conditions.append({"$match": {"$expr": expr}}) - else: - # Can't optimize. - remaining_conditions.append({"$expr": expr}) - return match_conditions + remaining_conditions - - -def _process_logical_conditions(logical_op, logical_conditions): - """Process conditions within a logical array.""" - optimized_conditions = [] - match_conditions = [] - remaining_conditions = [] - for condition in logical_conditions: - _remaining_conditions = [] - if isinstance(condition, dict): - if optimized := convert_expression(condition): - optimized_conditions.append(optimized) - else: - _remaining_conditions.append(condition) - else: - _remaining_conditions.append(condition) - if _remaining_conditions: - # Any expressions that can't be optimized must remain in a $expr - # that preserves the logical operator. - if len(_remaining_conditions) > 1: - remaining_conditions.append({"$expr": {logical_op: _remaining_conditions}}) - else: - remaining_conditions.append({"$expr": _remaining_conditions[0]}) - if optimized_conditions: - optimized_conditions.extend(remaining_conditions) - if len(optimized_conditions) > 1: - match_conditions.append({"$match": {logical_op: optimized_conditions}}) - else: - match_conditions.append({"$match": optimized_conditions[0]}) - else: - match_conditions.append({"$match": {logical_op: remaining_conditions}}) - return match_conditions diff --git a/django_mongodb_backend/query_utils.py b/django_mongodb_backend/query_utils.py index cb1a5711a..75c3c099f 100644 --- a/django_mongodb_backend/query_utils.py +++ b/django_mongodb_backend/query_utils.py @@ -59,7 +59,6 @@ def regex_expr(field, regex_vals, insensitive=False): def regex_match(field, regex, insensitive=False): options = "i" if insensitive else "" - # return {"$regexMatch": {"input": field, "regex": regex, "options": options}} return {field: {"$regex": regex, "$options": options}} From c8baa5ade09b232fbef5917fc79bf4c6f899a1b2 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Fri, 26 Sep 2025 23:13:29 -0300 Subject: [PATCH 08/23] Remove query converter --- tests/expression_converter_/__init__.py | 0 .../test_match_conversion.py | 215 ---------------- .../test_op_expressions.py | 233 ------------------ 3 files changed, 448 deletions(-) delete mode 100644 tests/expression_converter_/__init__.py delete mode 100644 tests/expression_converter_/test_match_conversion.py delete mode 100644 tests/expression_converter_/test_op_expressions.py diff --git a/tests/expression_converter_/__init__.py b/tests/expression_converter_/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/expression_converter_/test_match_conversion.py b/tests/expression_converter_/test_match_conversion.py deleted file mode 100644 index e78e5c0cc..000000000 --- a/tests/expression_converter_/test_match_conversion.py +++ /dev/null @@ -1,215 +0,0 @@ -from django.test import SimpleTestCase - -from django_mongodb_backend.query_conversion.query_optimizer import convert_expr_to_match - - -class ConvertExprToMatchTests(SimpleTestCase): - def assertOptimizerEqual(self, input, expected): - result = convert_expr_to_match(input) - self.assertEqual(result, expected) - - def test_multiple_optimizable_conditions(self): - expr = { - "$expr": { - "$and": [ - {"$eq": ["$status", "active"]}, - {"$in": ["$category", ["electronics", "books"]]}, - {"$eq": ["$verified", True]}, - {"$gte": ["$price", 50]}, - ] - } - } - expected = [ - { - "$match": { - "$and": [ - {"status": "active"}, - {"category": {"$in": ["electronics", "books"]}}, - {"verified": True}, - {"price": {"$gte": 50}}, - ] - } - } - ] - self.assertOptimizerEqual(expr, expected) - - def test_mixed_optimizable_and_non_optimizable_conditions(self): - expr = { - "$expr": { - "$and": [ - {"$eq": ["$status", "active"]}, - {"$gt": ["$price", "$min_price"]}, # Not optimizable - {"$in": ["$category", ["electronics"]]}, - ] - } - } - expected = [ - { - "$match": { - "$and": [ - {"status": "active"}, - {"category": {"$in": ["electronics"]}}, - {"$expr": {"$gt": ["$price", "$min_price"]}}, - ], - } - } - ] - self.assertOptimizerEqual(expr, expected) - - def test_non_optimizable_condition(self): - expr = {"$expr": {"$gt": ["$price", "$min_price"]}} - expected = [ - { - "$match": { - "$expr": {"$gt": ["$price", "$min_price"]}, - } - } - ] - self.assertOptimizerEqual(expr, expected) - - def test_nested_logical_conditions(self): - expr = { - "$expr": { - "$or": [ - {"$eq": ["$status", "active"]}, - {"$in": ["$category", ["electronics", "books"]]}, - {"$and": [{"$eq": ["$verified", True]}, {"$lte": ["$price", 50]}]}, - ] - } - } - expected = [ - { - "$match": { - "$or": [ - {"status": "active"}, - {"category": {"$in": ["electronics", "books"]}}, - {"$and": [{"verified": True}, {"price": {"$lte": 50}}]}, - ] - } - } - ] - self.assertOptimizerEqual(expr, expected) - - def test_complex_nested_with_non_optimizable_parts(self): - expr = { - "$expr": { - "$and": [ - { - "$or": [ - {"$eq": ["$status", "active"]}, - {"$gt": ["$views", 1000]}, - ] - }, - {"$in": ["$category", ["electronics", "books"]]}, - {"$eq": ["$verified", True]}, - {"$gt": ["$price", "$min_price"]}, # Not optimizable - ] - } - } - expected = [ - { - "$match": { - "$and": [ - { - "$or": [ - {"status": "active"}, - {"views": {"$gt": 1000}}, - ] - }, - {"category": {"$in": ["electronics", "books"]}}, - {"verified": True}, - {"$expr": {"$gt": ["$price", "$min_price"]}}, - ] - } - } - ] - self.assertOptimizerEqual(expr, expected) - - def test_london_in_case(self): - expr = {"$expr": {"$in": ["$author_city", ["London"]]}} - expected = [{"$match": {"author_city": {"$in": ["London"]}}}] - self.assertOptimizerEqual(expr, expected) - - def test_deeply_nested_logical_operators(self): - expr = { - "$expr": { - "$and": [ - { - "$or": [ - {"$eq": ["$type", "premium"]}, - { - "$and": [ - {"$eq": ["$type", "standard"]}, - {"$in": ["$region", ["US", "CA"]]}, - ] - }, - ] - }, - {"$eq": ["$active", True]}, - ] - } - } - expected = [ - { - "$match": { - "$and": [ - { - "$or": [ - {"type": "premium"}, - { - "$and": [ - {"type": "standard"}, - {"region": {"$in": ["US", "CA"]}}, - ] - }, - ] - }, - {"active": True}, - ] - } - } - ] - self.assertOptimizerEqual(expr, expected) - - def test_deeply_nested_logical_operator_with_variable(self): - expr = { - "$expr": { - "$and": [ - { - "$or": [ - {"$eq": ["$type", "premium"]}, - { - "$and": [ - {"$eq": ["$type", "$$standard"]}, # Not optimizable - {"$in": ["$region", ["US", "CA"]]}, - ] - }, - ] - }, - {"$eq": ["$active", True]}, - ] - } - } - expected = [ - { - "$match": { - "$and": [ - {"active": True}, - { - "$expr": { - "$or": [ - {"$eq": ["$type", "premium"]}, - { - "$and": [ - {"$eq": ["$type", "$$standard"]}, - {"$in": ["$region", ["US", "CA"]]}, - ] - }, - ] - } - }, - ] - } - } - ] - self.assertOptimizerEqual(expr, expected) diff --git a/tests/expression_converter_/test_op_expressions.py b/tests/expression_converter_/test_op_expressions.py deleted file mode 100644 index ce4caf2d4..000000000 --- a/tests/expression_converter_/test_op_expressions.py +++ /dev/null @@ -1,233 +0,0 @@ -import datetime -from uuid import UUID - -from bson import Decimal128 -from django.test import SimpleTestCase - -from django_mongodb_backend.query_conversion.expression_converters import convert_expression - - -class ConversionTestCase(SimpleTestCase): - CONVERTIBLE_TYPES = { - "int": 42, - "float": 3.14, - "decimal128": Decimal128("3.14"), - "boolean": True, - "NoneType": None, - "string": "string", - "datetime": datetime.datetime.now(datetime.timezone.utc), - "duration": datetime.timedelta(days=5, hours=3), - "uuid": UUID("12345678123456781234567812345678"), - } - - def assertConversionEqual(self, input, expected): - result = convert_expression(input) - self.assertEqual(result, expected) - - def assertNotOptimizable(self, input): - result = convert_expression(input) - self.assertIsNone(result) - - def _test_conversion_various_types(self, conversion_test): - for _type, val in self.CONVERTIBLE_TYPES.items(): - with self.subTest(_type=_type, val=val): - conversion_test(val) - - -class ExpressionTests(ConversionTestCase): - def test_non_dict(self): - self.assertNotOptimizable(["$status", "active"]) - - def test_empty_dict(self): - self.assertNotOptimizable({}) - - -class EqTests(ConversionTestCase): - def test_conversion(self): - self.assertConversionEqual({"$eq": ["$status", "active"]}, {"status": "active"}) - - def test_no_conversion_non_string_field(self): - self.assertNotOptimizable({"$eq": [123, "active"]}) - - def test_no_conversion_dict_value(self): - self.assertNotOptimizable({"$eq": ["$status", {"$gt": 5}]}) - - def _test_conversion_valid_type(self, _type): - self.assertConversionEqual({"$eq": ["$age", _type]}, {"age": _type}) - - def _test_conversion_valid_array_type(self, _type): - self.assertConversionEqual({"$eq": ["$age", _type]}, {"age": _type}) - - def test_conversion_various_types(self): - self._test_conversion_various_types(self._test_conversion_valid_type) - - def test_conversion_various_array_types(self): - self._test_conversion_various_types(self._test_conversion_valid_array_type) - - -class InTests(ConversionTestCase): - def test_conversion(self): - expr = {"$in": ["$category", ["electronics", "books", "clothing"]]} - expected = {"category": {"$in": ["electronics", "books", "clothing"]}} - self.assertConversionEqual(expr, expected) - - def test_no_conversion_non_string_field(self): - self.assertNotOptimizable({"$in": [123, ["electronics", "books"]]}) - - def test_no_conversion_dict_value(self): - self.assertNotOptimizable({"$in": ["$status", [{"bad": "val"}]]}) - - def _test_conversion_valid_type(self, _type): - self.assertConversionEqual({"$in": ["$age", [_type]]}, {"age": {"$in": [_type]}}) - - def test_conversion_various_types(self): - for _type, val in self.CONVERTIBLE_TYPES.items(): - with self.subTest(_type=_type, val=val): - self._test_conversion_valid_type(val) - - -class LogicalTests(ConversionTestCase): - def test_and(self): - expr = { - "$and": [ - {"$eq": ["$status", "active"]}, - {"$in": ["$category", ["electronics", "books"]]}, - {"$eq": ["$verified", True]}, - ] - } - expected = { - "$and": [ - {"status": "active"}, - {"category": {"$in": ["electronics", "books"]}}, - {"verified": True}, - ] - } - self.assertConversionEqual(expr, expected) - - def test_or(self): - expr = { - "$or": [ - {"$eq": ["$status", "active"]}, - {"$in": ["$category", ["electronics", "books"]]}, - ] - } - expected = { - "$or": [ - {"status": "active"}, - {"category": {"$in": ["electronics", "books"]}}, - ] - } - self.assertConversionEqual(expr, expected) - - def test_or_failure(self): - expr = { - "$or": [ - {"$eq": ["$status", "active"]}, - {"$in": ["$category", ["electronics", "books"]]}, - { - "$and": [ - {"verified": True}, - {"$gt": ["$price", "$min_price"]}, # Not optimizable - ] - }, - ] - } - self.assertNotOptimizable(expr) - - def test_mixed(self): - expr = { - "$and": [ - { - "$or": [ - {"$eq": ["$status", "active"]}, - {"$gt": ["$views", 1000]}, - ] - }, - {"$in": ["$category", ["electronics", "books"]]}, - {"$eq": ["$verified", True]}, - {"$lte": ["$price", 2000]}, - ] - } - expected = { - "$and": [ - {"$or": [{"status": "active"}, {"views": {"$gt": 1000}}]}, - {"category": {"$in": ["electronics", "books"]}}, - {"verified": True}, - {"price": {"$lte": 2000}}, - ] - } - self.assertConversionEqual(expr, expected) - - -class GtTests(ConversionTestCase): - def test_conversion(self): - self.assertConversionEqual({"$gt": ["$price", 100]}, {"price": {"$gt": 100}}) - - def test_no_conversion_non_simple_field(self): - self.assertNotOptimizable({"$gt": ["$price", "$min_price"]}) - - def test_no_conversion_dict_value(self): - self.assertNotOptimizable({"$gt": ["$price", {}]}) - - def _test_conversion_valid_type(self, _type): - self.assertConversionEqual({"$gt": ["$price", _type]}, {"price": {"$gt": _type}}) - - def test_conversion_various_types(self): - self._test_conversion_various_types(self._test_conversion_valid_type) - - -class GteTests(ConversionTestCase): - def test_conversion(self): - expr = {"$gte": ["$price", 100]} - expected = {"price": {"$gte": 100}} - self.assertConversionEqual(expr, expected) - - def test_no_conversion_non_simple_field(self): - expr = {"$gte": ["$price", "$min_price"]} - self.assertNotOptimizable(expr) - - def test_no_conversion_dict_value(self): - expr = {"$gte": ["$price", {}]} - self.assertNotOptimizable(expr) - - def _test_conversion_valid_type(self, _type): - expr = {"$gte": ["$price", _type]} - expected = {"price": {"$gte": _type}} - self.assertConversionEqual(expr, expected) - - def test_conversion_various_types(self): - self._test_conversion_various_types(self._test_conversion_valid_type) - - -class LtTests(ConversionTestCase): - def test_conversion(self): - self.assertConversionEqual({"$lt": ["$price", 100]}, {"price": {"$lt": 100}}) - - def test_no_conversion_non_simple_field(self): - self.assertNotOptimizable({"$lt": ["$price", "$min_price"]}) - - def test_no_conversion_dict_value(self): - self.assertNotOptimizable({"$lt": ["$price", {}]}) - - def _test_conversion_valid_type(self, _type): - self.assertConversionEqual({"$lt": ["$price", _type]}, {"price": {"$lt": _type}}) - - def test_conversion_various_types(self): - self._test_conversion_various_types(self._test_conversion_valid_type) - - -class LteTests(ConversionTestCase): - def test_conversion(self): - self.assertConversionEqual({"$lte": ["$price", 100]}, {"price": {"$lte": 100}}) - - def test_no_conversion_non_simple_field(self): - self.assertNotOptimizable({"$lte": ["$price", "$min_price"]}) - - def test_no_conversion_dict_value(self): - self.assertNotOptimizable({"$lte": ["$price", {}]}) - - def _test_conversion_valid_type(self, _type): - self.assertConversionEqual({"$lte": ["$price", _type]}, {"price": {"$lte": _type}}) - - def test_conversion_various_types(self): - self._test_conversion_various_types(self._test_conversion_valid_type) From c8a1c31224521c05d9b0436e066608b0c5d53ab4 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Fri, 26 Sep 2025 23:39:02 -0300 Subject: [PATCH 09/23] Refactor. --- .../fields/embedded_model.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py index e66d08fb2..95896a0f9 100644 --- a/django_mongodb_backend/fields/embedded_model.py +++ b/django_mongodb_backend/fields/embedded_model.py @@ -198,25 +198,25 @@ def get_transform(self, name): f"{suggestion}" ) - def as_mql_expr(self, compiler, connection): + def _get_target_path(self): previous = self columns = [] while isinstance(previous, EmbeddedModelTransform): columns.insert(0, previous.field.column) previous = previous.lhs - mql = previous.as_mql(compiler, connection) - for column in columns: - mql = {"$getField": {"input": mql, "field": column}} + return columns, previous + + def as_mql_expr(self, compiler, connection): + columns, parent_field = self._get_target_path() + mql = parent_field.as_mql(compiler, connection) + for key in columns: + mql = {"$getField": {"input": mql, "field": key}} return mql def as_mql_path(self, compiler, connection): - previous = self - key_transforms = [] - while isinstance(previous, EmbeddedModelTransform): - key_transforms.insert(0, previous.key_name) - previous = previous.lhs - mql = previous.as_mql(compiler, connection, as_path=True) - mql_path = ".".join(key_transforms) + columns, parent_field = self._get_target_path() + mql = parent_field.as_mql(compiler, connection, as_path=True) + mql_path = ".".join(columns) return f"{mql}.{mql_path}" @property From 1693ea977951b86512a9f80179b892e0587056e4 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sat, 27 Sep 2025 20:47:22 -0300 Subject: [PATCH 10/23] Simplify if. --- django_mongodb_backend/fields/json.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/django_mongodb_backend/fields/json.py b/django_mongodb_backend/fields/json.py index 1c8b4c5a3..52f705448 100644 --- a/django_mongodb_backend/fields/json.py +++ b/django_mongodb_backend/fields/json.py @@ -23,7 +23,7 @@ def build_json_mql_path(lhs, key_transforms, as_path=False): # Build the MQL path using the collected key transforms. - if as_path and lhs: + if as_path: return ".".join(chain([lhs], key_transforms)) result = lhs for key in key_transforms: From cf06a24c7bc7417b1be503e9fd598febcd3cf866 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Mon, 29 Sep 2025 19:01:52 -0300 Subject: [PATCH 11/23] Handle empty set or full set in range queries. --- django_mongodb_backend/base.py | 12 +++++++----- django_mongodb_backend/fields/json.py | 17 ++++++++++------- tests/lookup_/tests.py | 19 +++++++++++++++++-- tests/model_fields_/test_embedded_model.py | 12 +++++++----- .../test_embedded_model_array.py | 8 +++++--- 5 files changed, 46 insertions(+), 22 deletions(-) diff --git a/django_mongodb_backend/base.py b/django_mongodb_backend/base.py index 58a843cf6..594fb559e 100644 --- a/django_mongodb_backend/base.py +++ b/django_mongodb_backend/base.py @@ -2,7 +2,7 @@ import logging import os -from django.core.exceptions import ImproperlyConfigured +from django.core.exceptions import EmptyResultSet, FullResultSet, ImproperlyConfigured from django.db import DEFAULT_DB_ALIAS from django.db.backends.base.base import BaseDatabaseWrapper from django.db.backends.utils import debug_transaction @@ -143,14 +143,16 @@ def _isnull_operator_match(a, b): } def range_match(a, b): - ## TODO: MAKE A TEST TO TEST WHEN BOTH ENDS ARE NONE. WHAT SHALL I RETURN? conditions = [] - if b[0] is not None: + start, end = b + if start is not None: conditions.append({a: {"$gte": b[0]}}) - if b[1] is not None: + if end is not None: conditions.append({a: {"$lte": b[1]}}) + if start is not None and end is not None and start > end: + raise EmptyResultSet if not conditions: - return {"$literal": True} + raise FullResultSet return {"$and": conditions} # match, path, find? don't know which name use. diff --git a/django_mongodb_backend/fields/json.py b/django_mongodb_backend/fields/json.py index 52f705448..31b502faf 100644 --- a/django_mongodb_backend/fields/json.py +++ b/django_mongodb_backend/fields/json.py @@ -183,6 +183,15 @@ def key_transform_in_path(self, compiler, connection): def key_transform_is_null_expr(self, compiler, connection): + """ + Return MQL to check the nullability of a key. + + If `isnull=True`, the query matches objects where the key is missing or the + root column is null. If `isnull=False`, the query negates the result to + match objects where the key exists. + + Reference: https://code.djangoproject.com/ticket/32252 + """ previous = self.lhs while isinstance(previous, KeyTransform): previous = previous.lhs @@ -194,13 +203,7 @@ def key_transform_is_null_expr(self, compiler, connection): def key_transform_is_null_path(self, compiler, connection): """ - Return MQL to check the nullability of a key. - - If `isnull=True`, the query matches objects where the key is missing or the - root column is null. If `isnull=False`, the query negates the result to - match objects where the key exists. - - Reference: https://code.djangoproject.com/ticket/32252 + Return MQL to check the nullability of a key using the operator $exists. """ lhs_mql = process_lhs(self, compiler, connection, as_path=True) rhs_mql = process_rhs(self, compiler, connection, as_path=True) diff --git a/tests/lookup_/tests.py b/tests/lookup_/tests.py index d6e5da2a6..6166e1003 100644 --- a/tests/lookup_/tests.py +++ b/tests/lookup_/tests.py @@ -1,3 +1,4 @@ +from bson import SON from django.test import TestCase from django_mongodb_backend.test import MongoTestCaseMixin @@ -5,12 +6,12 @@ from .models import Book, Number -class NumericLookupTests(TestCase): +class NumericLookupTests(MongoTestCaseMixin, TestCase): @classmethod def setUpTestData(cls): cls.objs = Number.objects.bulk_create(Number(num=x) for x in range(5)) # Null values should be excluded in less than queries. - Number.objects.create() + cls.null_number = Number.objects.create() def test_lt(self): self.assertQuerySetEqual(Number.objects.filter(num__lt=3), self.objs[:3]) @@ -18,6 +19,20 @@ def test_lt(self): def test_lte(self): self.assertQuerySetEqual(Number.objects.filter(num__lte=3), self.objs[:4]) + def test_empty_range(self): + with self.assertNumQueries(0): + self.assertQuerySetEqual(Number.objects.filter(num__range=[3, 1]), []) + + def test_full_range(self): + with self.assertNumQueries(1) as ctx: + self.assertQuerySetEqual( + Number.objects.filter(num__range=[None, None]), [self.null_number, *self.objs] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, "lookup__number", [{"$addFields": {"num": "$num"}}, {"$sort": SON([("num", 1)])}] + ) + class RegexTests(MongoTestCaseMixin, TestCase): def test_mql(self): diff --git a/tests/model_fields_/test_embedded_model.py b/tests/model_fields_/test_embedded_model.py index 501620382..5a3bb578f 100644 --- a/tests/model_fields_/test_embedded_model.py +++ b/tests/model_fields_/test_embedded_model.py @@ -244,14 +244,16 @@ def test_nested(self): ) self.assertCountEqual(Book.objects.filter(author__address__city="NYC"), [obj]) - def test_annotate(self): + def test_filter_by_simple_annotate(self): obj = Book.objects.create( author=Author(name="Shakespeare", age=55, address=Address(city="NYC", state="NY")) ) - book_from_ny = ( - Book.objects.annotate(city=F("author__address__city")).filter(city="NYC").first() - ) - self.assertCountEqual(book_from_ny.city, obj.author.address.city) + with self.assertNumQueries(1) as ctx: + book_from_ny = ( + Book.objects.annotate(city=F("author__address__city")).filter(city="NYC").first() + ) + self.assertCountEqual(book_from_ny.city, obj.author.address.city) + self.assertIn("{'$match': {'author.address.city': 'NYC'}}", ctx.captured_queries[0]["sql"]) class ArrayFieldTests(TestCase): diff --git a/tests/model_fields_/test_embedded_model_array.py b/tests/model_fields_/test_embedded_model_array.py index f976e9b6d..c038f29ac 100644 --- a/tests/model_fields_/test_embedded_model_array.py +++ b/tests/model_fields_/test_embedded_model_array.py @@ -312,10 +312,12 @@ def test_nested_lookup(self): with self.assertRaisesMessage(ValueError, msg): Exhibit.objects.filter(sections__artifacts__name="") - def test_foreign_field_exact(self): + def test_foreign_field_exact_path(self): """Querying from a foreign key to an EmbeddedModelArrayField.""" - qs = Tour.objects.filter(exhibit__sections__number=1) - self.assertCountEqual(qs, [self.egypt_tour, self.wonders_tour]) + with self.assertNumQueries(1) as ctx: + qs = Tour.objects.filter(exhibit__sections__number=1) + self.assertCountEqual(qs, [self.egypt_tour, self.wonders_tour]) + self.assertNotIn("anyElementTrue", ctx.captured_queries[0]["sql"]) def test_foreign_field_exact_expr(self): """Querying from a foreign key to an EmbeddedModelArrayField.""" From 791f329e010b9c82f2c6bb57d79f82cee35b474f Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Mon, 29 Sep 2025 23:31:58 -0300 Subject: [PATCH 12/23] Add mql check in EMF and EMFA unit test. --- django_mongodb_backend/base.py | 10 +- django_mongodb_backend/test.py | 4 +- tests/model_fields_/test_embedded_model.py | 359 +++++++++++++++++- .../test_embedded_model_array.py | 225 ++++++++++- 4 files changed, 565 insertions(+), 33 deletions(-) diff --git a/django_mongodb_backend/base.py b/django_mongodb_backend/base.py index 594fb559e..c8c15ebc1 100644 --- a/django_mongodb_backend/base.py +++ b/django_mongodb_backend/base.py @@ -2,6 +2,7 @@ import logging import os +from bson import Decimal128 from django.core.exceptions import EmptyResultSet, FullResultSet, ImproperlyConfigured from django.db import DEFAULT_DB_ALIAS from django.db.backends.base.base import BaseDatabaseWrapper @@ -149,8 +150,13 @@ def range_match(a, b): conditions.append({a: {"$gte": b[0]}}) if end is not None: conditions.append({a: {"$lte": b[1]}}) - if start is not None and end is not None and start > end: - raise EmptyResultSet + if start is not None and end is not None: + if isinstance(start, Decimal128): + start = start.to_decimal() + if isinstance(end, Decimal128): + end = end.to_decimal() + if start > end: + raise EmptyResultSet if not conditions: raise FullResultSet return {"$and": conditions} diff --git a/django_mongodb_backend/test.py b/django_mongodb_backend/test.py index 561832a15..ee35b4e21 100644 --- a/django_mongodb_backend/test.py +++ b/django_mongodb_backend/test.py @@ -1,6 +1,6 @@ """Not a public API.""" -from bson import SON, ObjectId +from bson import SON, Decimal128, ObjectId class MongoTestCaseMixin: @@ -16,6 +16,6 @@ def assertAggregateQuery(self, query, expected_collection, expected_pipeline): self.assertEqual(operator, "aggregate") self.assertEqual(collection, expected_collection) self.assertEqual( - eval(pipeline[:-1], {"SON": SON, "ObjectId": ObjectId}, {}), # noqa: S307 + eval(pipeline[:-1], {"SON": SON, "ObjectId": ObjectId, "Decimal128": Decimal128}, {}), # noqa: S307 expected_pipeline, ) diff --git a/tests/model_fields_/test_embedded_model.py b/tests/model_fields_/test_embedded_model.py index 5a3bb578f..e541c3935 100644 --- a/tests/model_fields_/test_embedded_model.py +++ b/tests/model_fields_/test_embedded_model.py @@ -10,12 +10,14 @@ Max, OuterRef, Sum, + Value, ) from django.test import SimpleTestCase, TestCase from django.test.utils import isolate_apps from django_mongodb_backend.fields import EmbeddedModelField from django_mongodb_backend.models import EmbeddedModel +from django_mongodb_backend.test import MongoTestCaseMixin from .models import ( Address, @@ -127,7 +129,7 @@ def test_embedded_model_field_respects_db_column(self): self.assertEqual(query[0]["data"]["integer_"], 5) -class QueryingTests(TestCase): +class QueryingTests(MongoTestCaseMixin, TestCase): @classmethod def setUpTestData(cls): cls.objs = [ @@ -141,23 +143,354 @@ def setUpTestData(cls): for x in range(6) ] - def test_exact(self): - self.assertCountEqual(Holder.objects.filter(data__integer=3), [self.objs[3]]) + def test_exact_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual(Holder.objects.filter(data__integer=Value(4) - 1), [self.objs[3]]) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__holder", + [ + { + "$match": { + "$expr": { + "$eq": [ + {"$getField": {"input": "$data", "field": "integer"}}, + {"$subtract": [{"$literal": 4}, {"$literal": 1}]}, + ] + } + } + } + ], + ) + + def test_exact_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual(Holder.objects.filter(data__integer=3), [self.objs[3]]) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery(query, "model_fields__holder", [{"$match": {"data.integer": 3}}]) + + def test_lt_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Holder.objects.filter(data__integer__lt=Value(4) - 1), self.objs[:3] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__holder", + [ + { + "$match": { + "$expr": { + "$and": [ + { + "$lt": [ + {"$getField": {"input": "$data", "field": "integer"}}, + {"$subtract": [{"$literal": 4}, {"$literal": 1}]}, + ] + }, + { + "$not": { + "$or": [ + { + "$eq": [ + { + "$type": { + "$getField": { + "input": "$data", + "field": "integer", + } + } + }, + "missing", + ] + }, + { + "$eq": [ + { + "$getField": { + "input": "$data", + "field": "integer", + } + }, + None, + ] + }, + ] + } + }, + ] + } + } + } + ], + ) + + def test_lt_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual(Holder.objects.filter(data__integer__lt=3), self.objs[:3]) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__holder", + [ + { + "$match": { + "$and": [ + {"data.integer": {"$lt": 3}}, + { + "$and": [ + {"data.integer": {"$exists": True}}, + {"data.integer": {"$ne": None}}, + ] + }, + ] + } + } + ], + ) + + def test_lte_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Holder.objects.filter(data__integer__lte=Value(4) - 1), self.objs[:4] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__holder", + [ + { + "$match": { + "$expr": { + "$and": [ + { + "$lte": [ + {"$getField": {"input": "$data", "field": "integer"}}, + {"$subtract": [{"$literal": 4}, {"$literal": 1}]}, + ] + }, + { + "$not": { + "$or": [ + { + "$eq": [ + { + "$type": { + "$getField": { + "input": "$data", + "field": "integer", + } + } + }, + "missing", + ] + }, + { + "$eq": [ + { + "$getField": { + "input": "$data", + "field": "integer", + } + }, + None, + ] + }, + ] + } + }, + ] + } + } + } + ], + ) + + def test_lte_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual(Holder.objects.filter(data__integer__lte=3), self.objs[:4]) + query = ctx.captured_queries[0]["sql"] + + self.assertAggregateQuery( + query, + "model_fields__holder", + [ + { + "$match": { + "$and": [ + {"data.integer": {"$lte": 3}}, + { + "$and": [ + {"data.integer": {"$exists": True}}, + {"data.integer": {"$ne": None}}, + ] + }, + ] + } + } + ], + ) - def test_lt(self): - self.assertCountEqual(Holder.objects.filter(data__integer__lt=3), self.objs[:3]) + def test_gt_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Holder.objects.filter(data__integer__gt=Value(4) - 1), self.objs[4:] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__holder", + [ + { + "$match": { + "$expr": { + "$gt": [ + {"$getField": {"input": "$data", "field": "integer"}}, + {"$subtract": [{"$literal": 4}, {"$literal": 1}]}, + ] + } + } + } + ], + ) + + def test_gt_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual(Holder.objects.filter(data__integer__gt=3), self.objs[4:]) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, "model_fields__holder", [{"$match": {"data.integer": {"$gt": 3}}}] + ) - def test_lte(self): - self.assertCountEqual(Holder.objects.filter(data__integer__lte=3), self.objs[:4]) + def test_gte_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Holder.objects.filter(data__integer__gte=Value(4) - 1), self.objs[3:] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__holder", + [ + { + "$match": { + "$expr": { + "$gte": [ + {"$getField": {"input": "$data", "field": "integer"}}, + {"$subtract": [{"$literal": 4}, {"$literal": 1}]}, + ] + } + } + } + ], + ) - def test_gt(self): - self.assertCountEqual(Holder.objects.filter(data__integer__gt=3), self.objs[4:]) + def test_gte_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual(Holder.objects.filter(data__integer__gte=3), self.objs[3:]) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, "model_fields__holder", [{"$match": {"data.integer": {"$gte": 3}}}] + ) - def test_gte(self): - self.assertCountEqual(Holder.objects.filter(data__integer__gte=3), self.objs[3:]) + def test_range_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Holder.objects.filter(data__integer__range=(2, Value(5) - 1)), self.objs[2:5] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__holder", + [ + { + "$match": { + "$expr": { + "$and": [ + { + "$or": [ + { + "$or": [ + {"$eq": [{"$type": {"$literal": 2}}, "missing"]}, + {"$eq": [{"$literal": 2}, None]}, + ] + }, + { + "$gte": [ + { + "$getField": { + "input": "$data", + "field": "integer", + } + }, + {"$literal": 2}, + ] + }, + ] + }, + { + "$or": [ + { + "$or": [ + { + "$eq": [ + { + "$type": { + "$subtract": [ + {"$literal": 5}, + {"$literal": 1}, + ] + } + }, + "missing", + ] + }, + { + "$eq": [ + { + "$subtract": [ + {"$literal": 5}, + {"$literal": 1}, + ] + }, + None, + ] + }, + ] + }, + { + "$lte": [ + { + "$getField": { + "input": "$data", + "field": "integer", + } + }, + {"$subtract": [{"$literal": 5}, {"$literal": 1}]}, + ] + }, + ] + }, + ] + } + } + } + ], + ) - def test_range(self): - self.assertCountEqual(Holder.objects.filter(data__integer__range=(2, 4)), self.objs[2:5]) + def test_range_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Holder.objects.filter(data__integer__range=(2, 4)), self.objs[2:5] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__holder", + [{"$match": {"$and": [{"data.integer": {"$gte": 2}}, {"data.integer": {"$lte": 4}}]}}], + ) def test_exact_decimal(self): # EmbeddedModelField lookups call diff --git a/tests/model_fields_/test_embedded_model_array.py b/tests/model_fields_/test_embedded_model_array.py index c038f29ac..e161fa051 100644 --- a/tests/model_fields_/test_embedded_model_array.py +++ b/tests/model_fields_/test_embedded_model_array.py @@ -10,6 +10,7 @@ from django_mongodb_backend.fields import ArrayField, EmbeddedModelArrayField from django_mongodb_backend.models import EmbeddedModel +from django_mongodb_backend.test import MongoTestCaseMixin from .models import Artifact, Audit, Exhibit, Movie, Restoration, Review, Section, Tour @@ -84,7 +85,7 @@ def test_embedded_model_field_respects_db_column(self): self.assertEqual(query[0]["reviews"][0]["title_"], "Awesome") -class QueryingTests(TestCase): +class QueryingTests(MongoTestCaseMixin, TestCase): @classmethod def setUpTestData(cls): cls.egypt = Exhibit.objects.create( @@ -177,23 +178,203 @@ def setUpTestData(cls): cls.audit_2 = Audit.objects.create(section_number=2, reviewed=True) cls.audit_3 = Audit.objects.create(section_number=5, reviewed=False) - def test_exact(self): - self.assertCountEqual( - Exhibit.objects.filter(sections__number=1), [self.egypt, self.wonders] + def test_exact_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Exhibit.objects.filter(sections__number=Value(2) - 1), [self.egypt, self.wonders] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__exhibit", + [ + { + "$match": { + "$expr": { + "$anyElementTrue": { + "$ifNull": [ + { + "$map": { + "input": "$sections", + "as": "item", + "in": { + "$eq": [ + "$$item.number", + { + "$subtract": [ + {"$literal": 2}, + {"$literal": 1}, + ] + }, + ] + }, + } + }, + [], + ] + } + } + } + } + ], ) - def test_array_index(self): - self.assertCountEqual( - Exhibit.objects.filter(sections__0__number=1), - [self.egypt, self.wonders], + def test_exact_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Exhibit.objects.filter(sections__number=1), [self.egypt, self.wonders] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, "model_fields__exhibit", [{"$match": {"sections.number": 1}}] ) - def test_nested_array_index(self): - self.assertCountEqual( - Exhibit.objects.filter( - main_section__artifacts__restorations__0__restored_by="Zacarias" - ), - [self.lost_empires], + def test_array_index_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Exhibit.objects.filter(sections__0__number=Value(2) - 1), + [self.egypt, self.wonders], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__exhibit", + [ + { + "$match": { + "$expr": { + "$eq": [ + { + "$getField": { + "input": {"$arrayElemAt": ["$sections", 0]}, + "field": "number", + } + }, + {"$subtract": [{"$literal": 2}, {"$literal": 1}]}, + ] + } + } + } + ], + ) + + def test_array_index_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Exhibit.objects.filter(sections__0__number=1), + [self.egypt, self.wonders], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, "model_fields__exhibit", [{"$match": {"sections.0.number": 1}}] + ) + + def test_nested_array_index_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Exhibit.objects.filter( + main_section__artifacts__restorations__0__restored_by="Zacarias" + ), + [self.lost_empires], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__exhibit", + [ + { + "$match": { + "$expr": { + "$anyElementTrue": { + "$ifNull": [ + { + "$map": { + "input": { + "$getField": { + "input": "$main_section", + "field": "artifacts", + } + }, + "as": "item", + "in": { + "$eq": [ + { + "$getField": { + "input": { + "$arrayElemAt": [ + "$$item.restorations", + 0, + ] + }, + "field": "restored_by", + } + }, + "Zacarias", + ] + }, + } + }, + [], + ] + } + } + } + } + ], + ) + + def test_nested_array_index_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Exhibit.objects.filter( + main_section__artifacts__restorations__0__restored_by="Zacarias" + ), + [self.lost_empires], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__exhibit", + [ + { + "$match": { + "$expr": { + "$anyElementTrue": { + "$ifNull": [ + { + "$map": { + "input": { + "$getField": { + "input": "$main_section", + "field": "artifacts", + } + }, + "as": "item", + "in": { + "$eq": [ + { + "$getField": { + "input": { + "$arrayElemAt": [ + "$$item.restorations", + 0, + ] + }, + "field": "restored_by", + } + }, + "Zacarias", + ] + }, + } + }, + [], + ] + } + } + } + } + ], ) def test_array_slice(self): @@ -207,8 +388,20 @@ def test_filter_unsupported_lookups_in_json(self): kwargs = {f"main_section__artifacts__metadata__origin__{lookup}": ["Pergamon", "Egypt"]} with CaptureQueriesContext(connection) as captured_queries: self.assertCountEqual(Exhibit.objects.filter(**kwargs), []) - self.assertIn( - f"'main_section.artifacts.metadata.origin.{lookup}':", captured_queries[0]["sql"] + query = captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__exhibit", + [ + { + "$match": { + f"main_section.artifacts.metadata.origin.{lookup}": [ + "Pergamon", + "Egypt", + ] + } + } + ], ) def test_len(self): From bddd5cd5eb65865dd4a143aa81d84c5763e35716 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Tue, 30 Sep 2025 00:29:00 -0300 Subject: [PATCH 13/23] Rename method. --- django_mongodb_backend/expressions/builtins.py | 6 +----- django_mongodb_backend/fields/array.py | 3 ++- django_mongodb_backend/fields/embedded_model.py | 3 ++- django_mongodb_backend/fields/embedded_model_array.py | 3 ++- django_mongodb_backend/fields/json.py | 10 ++++++---- django_mongodb_backend/functions.py | 6 +----- django_mongodb_backend/lookups.py | 5 +++-- 7 files changed, 17 insertions(+), 19 deletions(-) diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index 1556dc9c6..62ee5a67d 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -30,11 +30,7 @@ def base_expression(self, compiler, connection, as_path=False, **extra): - if ( - as_path - and hasattr(self, "as_mql_path") - and getattr(self, "is_simple_expression", lambda: False)() - ): + if as_path and hasattr(self, "as_mql_path") and getattr(self, "can_use_path", False): return self.as_mql_path(compiler, connection, **extra) expr = self.as_mql_expr(compiler, connection, **extra) diff --git a/django_mongodb_backend/fields/array.py b/django_mongodb_backend/fields/array.py index 3ffb85f75..da83d7a66 100644 --- a/django_mongodb_backend/fields/array.py +++ b/django_mongodb_backend/fields/array.py @@ -381,7 +381,8 @@ def __init__(self, index, base_field, *args, **kwargs): self.index = index self.base_field = base_field - def is_simple_expression(self): + @property + def can_use_path(self): return self.is_simple_column @property diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py index 95896a0f9..f11478832 100644 --- a/django_mongodb_backend/fields/embedded_model.py +++ b/django_mongodb_backend/fields/embedded_model.py @@ -167,7 +167,8 @@ def __init__(self, field, *args, **kwargs): def get_lookup(self, name): return self.field.get_lookup(name) - def is_simple_expression(self): + @property + def can_use_path(self): return self.is_simple_column @cached_property diff --git a/django_mongodb_backend/fields/embedded_model_array.py b/django_mongodb_backend/fields/embedded_model_array.py index 228f35869..70a1d05f7 100644 --- a/django_mongodb_backend/fields/embedded_model_array.py +++ b/django_mongodb_backend/fields/embedded_model_array.py @@ -247,7 +247,8 @@ def __call__(self, this, *args, **kwargs): self._lhs = self._sub_transform(self._lhs, *args, **kwargs) return self - def is_simple_expression(self): + @property + def can_use_path(self): return self.is_simple_column @cached_property diff --git a/django_mongodb_backend/fields/json.py b/django_mongodb_backend/fields/json.py index 31b502faf..bb979875a 100644 --- a/django_mongodb_backend/fields/json.py +++ b/django_mongodb_backend/fields/json.py @@ -72,6 +72,7 @@ def _has_key_predicate(path, root_column=None, negated=False, as_path=False): return result +@property def has_key_check_simple_expression(self): rhs = [self.rhs] if not isinstance(self.rhs, (list, tuple)) else self.rhs return self.is_simple_column and all(key.isalnum() for key in rhs) @@ -147,7 +148,7 @@ def key_transform_in(self, compiler, connection, as_path=False): Return MQL to check if a JSON path exists and that its values are in the set of specified values (rhs). """ - if as_path and self.is_simple_expression(): + if as_path and self.can_use_path(): return builtin_lookup_path(self, compiler, connection) lhs_mql = process_lhs(self, compiler, connection) @@ -226,6 +227,7 @@ def key_transform_numeric_lookup_mixin_path(self, compiler, connection): return builtin_lookup_path(self, compiler, connection) +@property def keytransform_is_simple_column(self): previous = self while isinstance(previous, KeyTransform): @@ -242,11 +244,11 @@ def register_json_field(): HasKey.mongo_operator = None HasKeyLookup.as_mql_path = partialmethod(has_key_lookup, as_path=True) HasKeyLookup.as_mql_expr = partialmethod(has_key_lookup, as_path=False) - HasKeyLookup.is_simple_expression = has_key_check_simple_expression + HasKeyLookup.can_use_path = has_key_check_simple_expression HasKeys.mongo_operator = "$and" JSONExact.process_rhs = json_exact_process_rhs - KeyTransform.is_simple_column = property(keytransform_is_simple_column) - KeyTransform.is_simple_expression = keytransform_is_simple_column + KeyTransform.is_simple_column = keytransform_is_simple_column + KeyTransform.can_use_path = keytransform_is_simple_column KeyTransform.as_mql_path = partialmethod(key_transform, as_path=True) KeyTransform.as_mql_expr = partialmethod(key_transform, as_path=False) KeyTransformExact.as_mql_expr = key_transform_exact_expr diff --git a/django_mongodb_backend/functions.py b/django_mongodb_backend/functions.py index 56a027b7b..417a427d8 100644 --- a/django_mongodb_backend/functions.py +++ b/django_mongodb_backend/functions.py @@ -292,10 +292,6 @@ def trunc_time(self, compiler, connection): } -def is_simple_expression(self): # noqa: ARG001 - return False - - def register_functions(): Cast.as_mql_expr = cast Concat.as_mql_expr = concat @@ -323,4 +319,4 @@ def register_functions(): TruncDate.as_mql_expr = trunc_date TruncTime.as_mql_expr = trunc_time Upper.as_mql_expr = preserve_null("toUpper") - Func.is_simple_expression = is_simple_expression + Func.can_use_path = False diff --git a/django_mongodb_backend/lookups.py b/django_mongodb_backend/lookups.py index 83d0ea18f..e5154fad4 100644 --- a/django_mongodb_backend/lookups.py +++ b/django_mongodb_backend/lookups.py @@ -138,14 +138,15 @@ def uuid_text_mixin(self, compiler, connection): # noqa: ARG001 raise NotSupportedError("Pattern lookups on UUIDField are not supported.") -def is_simple_expression(self): +@property +def can_use_path(self): simple_column = getattr(self.lhs, "is_simple_column", False) constant_value = is_constant_value(self.rhs) return simple_column and constant_value def register_lookups(): - Lookup.is_simple_expression = is_simple_expression + Lookup.can_use_path = can_use_path BuiltinLookup.as_mql_path = builtin_lookup_path BuiltinLookup.as_mql_expr = builtin_lookup_expr FieldGetDbPrepValueIterableMixin.resolve_expression_parameter = ( From 0eae854f211bdc9522ea06eb5df3b91829c7f4ec Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Tue, 30 Sep 2025 00:37:47 -0300 Subject: [PATCH 14/23] Clean ups. --- django_mongodb_backend/expressions/builtins.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index 62ee5a67d..a5db3d70b 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -42,7 +42,7 @@ def case(self, compiler, connection): for case in self.cases: case_mql = {} try: - case_mql["case"] = case.as_mql(compiler, connection, as_path=False) + case_mql["case"] = case.as_mql(compiler, connection) except EmptyResultSet: continue except FullResultSet: @@ -92,16 +92,16 @@ def col_pairs(self, compiler, connection, as_path=False): return cols[0].as_mql(compiler, connection, as_path=as_path) -def combined_expression(self, compiler, connection, as_path=False): +def combined_expression(self, compiler, connection): expressions = [ - self.lhs.as_mql(compiler, connection, as_path=as_path), - self.rhs.as_mql(compiler, connection, as_path=as_path), + self.lhs.as_mql(compiler, connection), + self.rhs.as_mql(compiler, connection), ] return connection.ops.combine_expression(self.connector, expressions) def expression_wrapper_expr(self, compiler, connection): - return self.expression.as_mql(compiler, connection, as_path=False) + return self.expression.as_mql(compiler, connection) def negated_expression_expr(self, compiler, connection): @@ -183,7 +183,7 @@ def ref_is_simple_column(self): return self.source.is_simple_column -def star(self, compiler, connection, as_path=False): # noqa: ARG001 +def star(self, compiler, connection): # noqa: ARG001 return {"$literal": True} @@ -247,7 +247,7 @@ def register_expressions(): Ref.as_mql = ref Ref.is_simple_column = ref_is_simple_column ResolvedOuterRef.as_mql = ResolvedOuterRef.as_sql - Star.as_mql = star + Star.as_mql_expr = star Subquery.as_mql_expr = subquery When.as_mql = when Value.as_mql = value From 111b78c28cf21aa33a0f2369aaeda82a6900a16d Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Tue, 30 Sep 2025 00:48:22 -0300 Subject: [PATCH 15/23] Clean up. --- django_mongodb_backend/fields/array.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/django_mongodb_backend/fields/array.py b/django_mongodb_backend/fields/array.py index da83d7a66..14077b1f2 100644 --- a/django_mongodb_backend/fields/array.py +++ b/django_mongodb_backend/fields/array.py @@ -268,8 +268,6 @@ def as_mql_expr(self, compiler, connection): def as_mql_path(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection, as_path=True) value = process_rhs(self, compiler, connection, as_path=True) - if value is None: - return False return {lhs_mql: {"$all": value}} From a344004b88cb5fe5793cf926511338ea902348d4 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Fri, 3 Oct 2025 00:29:05 -0300 Subject: [PATCH 16/23] Handle array as path and update unit test. --- django_mongodb_backend/fields/array.py | 17 +- .../fields/embedded_model.py | 5 +- .../fields/embedded_model_array.py | 5 +- django_mongodb_backend/fields/json.py | 6 +- django_mongodb_backend/query_utils.py | 35 +- tests/model_fields_/test_arrayfield.py | 323 +++++++++++++++--- tests/model_fields_/test_embedded_model.py | 48 +-- .../test_embedded_model_array.py | 53 +-- 8 files changed, 370 insertions(+), 122 deletions(-) diff --git a/django_mongodb_backend/fields/array.py b/django_mongodb_backend/fields/array.py index 14077b1f2..05fc97ffc 100644 --- a/django_mongodb_backend/fields/array.py +++ b/django_mongodb_backend/fields/array.py @@ -4,10 +4,11 @@ from django.db.models import Field, Func, IntegerField, Transform, Value from django.db.models.fields.mixins import CheckFieldDefaultMixin from django.db.models.lookups import Exact, FieldGetDbPrepValueMixin, In, Lookup +from django.utils.functional import cached_property from django.utils.translation import gettext_lazy as _ from ..forms import SimpleArrayField -from ..query_utils import process_lhs, process_rhs +from ..query_utils import is_constant_value, process_lhs, process_rhs from ..utils import prefix_validation_error from ..validators import ArrayMaxLengthValidator, LengthValidator @@ -236,6 +237,20 @@ def as_mql_expr(self, compiler, connection): for expr in self.get_source_expressions() ] + def as_mql_path(self, compiler, connection): + return [ + expr.as_mql(compiler, connection, as_path=True) + for expr in self.get_source_expressions() + ] + + @cached_property + def can_use_path(self): + return all(is_constant_value(expr) for expr in self.get_source_expressions()) + + @property + def is_simple_column(self): + return False + class ArrayRHSMixin: def __init__(self, lhs, rhs): diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py index f11478832..8f4ea1fc8 100644 --- a/django_mongodb_backend/fields/embedded_model.py +++ b/django_mongodb_backend/fields/embedded_model.py @@ -8,6 +8,7 @@ from django.utils.functional import cached_property from .. import forms +from ..query_utils import valid_path_key_name class EmbeddedModelField(models.Field): @@ -174,8 +175,8 @@ def can_use_path(self): @cached_property def is_simple_column(self): previous = self - while isinstance(previous, KeyTransform): - if not previous.key_name.isalnum(): + while isinstance(previous, EmbeddedModelTransform): + if not valid_path_key_name(previous._field.column): return False previous = previous.lhs return previous.is_simple_column diff --git a/django_mongodb_backend/fields/embedded_model_array.py b/django_mongodb_backend/fields/embedded_model_array.py index 70a1d05f7..3238ccc3c 100644 --- a/django_mongodb_backend/fields/embedded_model_array.py +++ b/django_mongodb_backend/fields/embedded_model_array.py @@ -9,7 +9,7 @@ from .. import forms from ..lookups import builtin_lookup_path -from ..query_utils import process_lhs, process_rhs +from ..query_utils import process_lhs, process_rhs, valid_path_key_name from . import EmbeddedModelField from .array import ArrayField, ArrayLenTransform @@ -240,6 +240,7 @@ def __init__(self, field, *args, **kwargs): column_name = f"$item.{field.column}" column_target.db_column = column_name column_target.set_attributes_from_name(column_name) + self._field = field self._lhs = Col(None, column_target) self._sub_transform = None @@ -255,7 +256,7 @@ def can_use_path(self): def is_simple_column(self): previous = self while isinstance(previous, EmbeddedModelArrayFieldTransform): - if not previous.key_name.isalnum(): + if not valid_path_key_name(previous._field.column): return False previous = previous.lhs return previous.is_simple_column and self._lhs.is_simple_column diff --git a/django_mongodb_backend/fields/json.py b/django_mongodb_backend/fields/json.py index bb979875a..139f138b4 100644 --- a/django_mongodb_backend/fields/json.py +++ b/django_mongodb_backend/fields/json.py @@ -18,7 +18,7 @@ ) from ..lookups import builtin_lookup_expr, builtin_lookup_path -from ..query_utils import process_lhs, process_rhs +from ..query_utils import process_lhs, process_rhs, valid_path_key_name def build_json_mql_path(lhs, key_transforms, as_path=False): @@ -75,7 +75,7 @@ def _has_key_predicate(path, root_column=None, negated=False, as_path=False): @property def has_key_check_simple_expression(self): rhs = [self.rhs] if not isinstance(self.rhs, (list, tuple)) else self.rhs - return self.is_simple_column and all(key.isalnum() for key in rhs) + return self.is_simple_column and all(valid_path_key_name(key) for key in rhs) def has_key_lookup(self, compiler, connection, as_path=False): @@ -231,7 +231,7 @@ def key_transform_numeric_lookup_mixin_path(self, compiler, connection): def keytransform_is_simple_column(self): previous = self while isinstance(previous, KeyTransform): - if not previous.key_name.isalnum(): + if not valid_path_key_name(previous.key_name): return False previous = previous.lhs return previous.is_simple_column diff --git a/django_mongodb_backend/query_utils.py b/django_mongodb_backend/query_utils.py index 75c3c099f..43f361698 100644 --- a/django_mongodb_backend/query_utils.py +++ b/django_mongodb_backend/query_utils.py @@ -1,6 +1,8 @@ +import re + from django.core.exceptions import FullResultSet from django.db.models.aggregates import Aggregate -from django.db.models.expressions import CombinedExpression, Value +from django.db.models.expressions import CombinedExpression, Func, Value from django.db.models.sql.query import Query @@ -74,17 +76,24 @@ def is_constant_value(value): if hasattr(value, "get_source_expressions"): # Temporary: similar limitation as above, sub-expressions should be # resolved in the future - simple_sub_expressions = all(map(is_constant_value, value.get_source_expressions())) + constants_sub_expressions = all(map(is_constant_value, value.get_source_expressions())) else: - simple_sub_expressions = True - return ( - simple_sub_expressions - and isinstance(value, Value) - and not ( - isinstance(value, Query) - or value.contains_aggregate - or value.contains_over_clause - or value.contains_column_references - or value.contains_subquery - ) + constants_sub_expressions = True + constants_sub_expressions = constants_sub_expressions and not ( + isinstance(value, Query) + or value.contains_aggregate + or value.contains_over_clause + or value.contains_column_references + or value.contains_subquery + ) + return constants_sub_expressions and ( + isinstance(value, Value) + or + # Some closed functions cannot yet be converted to constant values. + # Allow Func with can_use_path as a temporary exception. + (isinstance(value, Func) and value.can_use_path) ) + + +def valid_path_key_name(key_name): + return bool(re.fullmatch(r"[A-Za-z0-9_]+", key_name)) diff --git a/tests/model_fields_/test_arrayfield.py b/tests/model_fields_/test_arrayfield.py index 06a918ebc..00fdf663e 100644 --- a/tests/model_fields_/test_arrayfield.py +++ b/tests/model_fields_/test_arrayfield.py @@ -21,6 +21,8 @@ from django.utils import timezone from django_mongodb_backend.fields import ArrayField +from django_mongodb_backend.fields.array import Array +from django_mongodb_backend.test import MongoTestCaseMixin from .models import ( ArrayEnumModel, @@ -216,7 +218,7 @@ def test_nested_nullable_base_field(self): self.assertEqual(instance.field_nested, [[None, None], [None, None]]) -class QueryingTests(TestCase): +class QueryingTests(MongoTestCaseMixin, TestCase): @classmethod def setUpTestData(cls): cls.objs = NullableIntegerArrayModel.objects.bulk_create( @@ -241,9 +243,34 @@ def test_empty_list(self): self.assertEqual(obj.field, []) self.assertEqual(obj.empty_array, []) - def test_exact(self): - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__exact=[1]), self.objs[:1] + def test_exact_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__exact=[Value(3) / 3]), self.objs[:1] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__nullableintegerarraymodel", + [ + { + "$match": { + "$expr": { + "$eq": ["$field", [{"$divide": [{"$literal": 3}, {"$literal": 3}]}]] + } + } + } + ], + ) + + def test_exact_path(self): + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__exact=[1]), self.objs[:1] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, "model_fields__nullableintegerarraymodel", [{"$match": {"field": [1]}}] ) def test_exact_null_only_array(self): @@ -261,23 +288,42 @@ def test_exact_null_only_nested_array(self): obj2 = NullableIntegerArrayModel.objects.create( field_nested=[[None, None], [None, None]], ) - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter( - field_nested__exact=[[None, None]], - ), - [obj1], + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter( + field_nested__exact=[[None, None]], + ), + [obj1], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__nullableintegerarraymodel", + [{"$match": {"field_nested": [[None, None]]}}], ) - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter( - field_nested__exact=[[None, None], [None, None]], - ), - [obj2], + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter( + field_nested__exact=[[None, None], [None, None]], + ), + [obj2], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__nullableintegerarraymodel", + [{"$match": {"field_nested": [[None, None], [None, None]]}}], ) def test_exact_with_expression(self): - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__exact=[Value(1)]), - self.objs[:1], + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__exact=[Value(1)]), + self.objs[:1], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, "model_fields__nullableintegerarraymodel", [{"$match": {"field": [1]}}] ) def test_exact_charfield(self): @@ -291,24 +337,140 @@ def test_exact_nested(self): ) def test_isnull(self): - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__isnull=True), self.objs[-1:] + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__isnull=True), self.objs[-1:] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__nullableintegerarraymodel", + [{"$match": {"$or": [{"field": {"$exists": False}}, {"field": None}]}}], ) - def test_gt(self): - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__gt=[0]), self.objs[:4] + def test_gt_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__gt=Array(Value(0) * 3)), + self.objs[:4], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__nullableintegerarraymodel", + [ + { + "$match": { + "$expr": { + "$gt": ["$field", [{"$multiply": [{"$literal": 0}, {"$literal": 3}]}]] + } + } + } + ], ) - def test_lt(self): - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__lt=[2]), self.objs[:1] + def test_gt_path(self): + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__gt=Array(0)), self.objs[:4] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, "model_fields__nullableintegerarraymodel", [{"$match": {"field": {"$gt": [0]}}}] ) - def test_in(self): - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__in=[[1], [2]]), - self.objs[:2], + def test_lt_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__lt=Array(Value(1) + 1)), + self.objs[:1], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__nullableintegerarraymodel", + [ + { + "$match": { + "$expr": { + "$and": [ + {"$lt": ["$field", [{"$add": [{"$literal": 1}, {"$literal": 1}]}]]}, + { + "$not": { + "$or": [ + {"$eq": [{"$type": "$field"}, "missing"]}, + {"$eq": ["$field", None]}, + ] + } + }, + ] + } + } + } + ], + ) + + def test_lt_path(self): + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__lt=[2]), self.objs[:1] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__nullableintegerarraymodel", + [ + { + "$match": { + "$and": [ + {"field": {"$lt": [2]}}, + {"$and": [{"field": {"$exists": True}}, {"field": {"$ne": None}}]}, + ] + } + } + ], + ) + + def test_in_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter( + field__in=Array(Array(Value(1) * 1), Array(2)) + ), + self.objs[:2], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__nullableintegerarraymodel", + [ + { + "$match": { + "$expr": { + "$in": ( + "$field", + [ + [{"$multiply": [{"$literal": 1}, {"$literal": 1}]}], + [{"$literal": 2}], + ], + ) + } + } + } + ], + ) + + def test_in_path(self): + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__in=[[1], [2]]), + self.objs[:2], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__nullableintegerarraymodel", + [{"$match": {"field": {"$in": ([1], [2])}}}], ) def test_in_subquery(self): @@ -352,10 +514,45 @@ def test_contained_by_including_F_object(self): self.objs[:3], ) - def test_contains(self): - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__contains=[2]), - self.objs[1:3], + def test_contains_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__contains=[Value(1) + 1]), + self.objs[1:3], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__nullableintegerarraymodel", + [ + { + "$match": { + "$expr": { + "$and": [ + {"$ne": ["$field", None]}, + {"$ne": [[{"$add": [{"$literal": 1}, {"$literal": 1}]}], None]}, + { + "$setIsSubset": [ + [{"$add": [{"$literal": 1}, {"$literal": 1}]}], + "$field", + ] + }, + ] + } + } + } + ], + ) + + def test_contains_path(self): + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__contains=[2]), + self.objs[1:3], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, "model_fields__nullableintegerarraymodel", [{"$match": {"field": {"$all": [2]}}}] ) def test_contains_subquery(self): @@ -395,7 +592,16 @@ def test_contains_including_expression(self): def test_icontains(self): instance = CharArrayModel.objects.create(field=["FoO"]) - self.assertSequenceEqual(CharArrayModel.objects.filter(field__icontains="foo"), [instance]) + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + CharArrayModel.objects.filter(field__icontains="foo"), [instance] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__chararraymodel", + [{"$match": {"field": {"$regex": "foo", "$options": "i"}}}], + ) def test_contains_charfield(self): self.assertSequenceEqual(CharArrayModel.objects.filter(field__contains=["text"]), []) @@ -455,10 +661,51 @@ def test_index_used_on_nested_data(self): NestedIntegerArrayModel.objects.filter(field__0=[1, 2]), [instance] ) - def test_overlap(self): - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__overlap=[1, 2]), - self.objs[0:3], + def test_overlap_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__overlap=[1, Value(1) + 1]), + self.objs[0:3], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__nullableintegerarraymodel", + [ + { + "$match": { + "$expr": { + "$and": [ + {"$ne": ["$field", None]}, + { + "$size": { + "$setIntersection": [ + [ + {"$literal": 1}, + {"$add": [{"$literal": 1}, {"$literal": 1}]}, + ], + "$field", + ] + } + }, + ] + } + } + } + ], + ) + + def test_overlap_path(self): + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__overlap=[1, 2]), + self.objs[0:3], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__nullableintegerarraymodel", + [{"$match": {"field": {"$in": [1, 2]}}}], ) def test_index_annotation(self): diff --git a/tests/model_fields_/test_embedded_model.py b/tests/model_fields_/test_embedded_model.py index e541c3935..4a8005365 100644 --- a/tests/model_fields_/test_embedded_model.py +++ b/tests/model_fields_/test_embedded_model.py @@ -155,7 +155,7 @@ def test_exact_expr(self): "$match": { "$expr": { "$eq": [ - {"$getField": {"input": "$data", "field": "integer"}}, + {"$getField": {"input": "$data", "field": "integer_"}}, {"$subtract": [{"$literal": 4}, {"$literal": 1}]}, ] } @@ -168,7 +168,7 @@ def test_exact_path(self): with self.assertNumQueries(1) as ctx: self.assertCountEqual(Holder.objects.filter(data__integer=3), [self.objs[3]]) query = ctx.captured_queries[0]["sql"] - self.assertAggregateQuery(query, "model_fields__holder", [{"$match": {"data.integer": 3}}]) + self.assertAggregateQuery(query, "model_fields__holder", [{"$match": {"data.integer_": 3}}]) def test_lt_expr(self): with self.assertNumQueries(1) as ctx: @@ -186,7 +186,7 @@ def test_lt_expr(self): "$and": [ { "$lt": [ - {"$getField": {"input": "$data", "field": "integer"}}, + {"$getField": {"input": "$data", "field": "integer_"}}, {"$subtract": [{"$literal": 4}, {"$literal": 1}]}, ] }, @@ -199,7 +199,7 @@ def test_lt_expr(self): "$type": { "$getField": { "input": "$data", - "field": "integer", + "field": "integer_", } } }, @@ -211,7 +211,7 @@ def test_lt_expr(self): { "$getField": { "input": "$data", - "field": "integer", + "field": "integer_", } }, None, @@ -238,11 +238,11 @@ def test_lt_path(self): { "$match": { "$and": [ - {"data.integer": {"$lt": 3}}, + {"data.integer_": {"$lt": 3}}, { "$and": [ - {"data.integer": {"$exists": True}}, - {"data.integer": {"$ne": None}}, + {"data.integer_": {"$exists": True}}, + {"data.integer_": {"$ne": None}}, ] }, ] @@ -267,7 +267,7 @@ def test_lte_expr(self): "$and": [ { "$lte": [ - {"$getField": {"input": "$data", "field": "integer"}}, + {"$getField": {"input": "$data", "field": "integer_"}}, {"$subtract": [{"$literal": 4}, {"$literal": 1}]}, ] }, @@ -280,7 +280,7 @@ def test_lte_expr(self): "$type": { "$getField": { "input": "$data", - "field": "integer", + "field": "integer_", } } }, @@ -292,7 +292,7 @@ def test_lte_expr(self): { "$getField": { "input": "$data", - "field": "integer", + "field": "integer_", } }, None, @@ -320,11 +320,11 @@ def test_lte_path(self): { "$match": { "$and": [ - {"data.integer": {"$lte": 3}}, + {"data.integer_": {"$lte": 3}}, { "$and": [ - {"data.integer": {"$exists": True}}, - {"data.integer": {"$ne": None}}, + {"data.integer_": {"$exists": True}}, + {"data.integer_": {"$ne": None}}, ] }, ] @@ -347,7 +347,7 @@ def test_gt_expr(self): "$match": { "$expr": { "$gt": [ - {"$getField": {"input": "$data", "field": "integer"}}, + {"$getField": {"input": "$data", "field": "integer_"}}, {"$subtract": [{"$literal": 4}, {"$literal": 1}]}, ] } @@ -361,7 +361,7 @@ def test_gt_path(self): self.assertCountEqual(Holder.objects.filter(data__integer__gt=3), self.objs[4:]) query = ctx.captured_queries[0]["sql"] self.assertAggregateQuery( - query, "model_fields__holder", [{"$match": {"data.integer": {"$gt": 3}}}] + query, "model_fields__holder", [{"$match": {"data.integer_": {"$gt": 3}}}] ) def test_gte_expr(self): @@ -378,7 +378,7 @@ def test_gte_expr(self): "$match": { "$expr": { "$gte": [ - {"$getField": {"input": "$data", "field": "integer"}}, + {"$getField": {"input": "$data", "field": "integer_"}}, {"$subtract": [{"$literal": 4}, {"$literal": 1}]}, ] } @@ -392,7 +392,7 @@ def test_gte_path(self): self.assertCountEqual(Holder.objects.filter(data__integer__gte=3), self.objs[3:]) query = ctx.captured_queries[0]["sql"] self.assertAggregateQuery( - query, "model_fields__holder", [{"$match": {"data.integer": {"$gte": 3}}}] + query, "model_fields__holder", [{"$match": {"data.integer_": {"$gte": 3}}}] ) def test_range_expr(self): @@ -422,7 +422,7 @@ def test_range_expr(self): { "$getField": { "input": "$data", - "field": "integer", + "field": "integer_", } }, {"$literal": 2}, @@ -465,7 +465,7 @@ def test_range_expr(self): { "$getField": { "input": "$data", - "field": "integer", + "field": "integer_", } }, {"$subtract": [{"$literal": 5}, {"$literal": 1}]}, @@ -489,7 +489,13 @@ def test_range_path(self): self.assertAggregateQuery( query, "model_fields__holder", - [{"$match": {"$and": [{"data.integer": {"$gte": 2}}, {"data.integer": {"$lte": 4}}]}}], + [ + { + "$match": { + "$and": [{"data.integer_": {"$gte": 2}}, {"data.integer_": {"$lte": 4}}] + } + } + ], ) def test_exact_decimal(self): diff --git a/tests/model_fields_/test_embedded_model_array.py b/tests/model_fields_/test_embedded_model_array.py index e161fa051..837fa2b94 100644 --- a/tests/model_fields_/test_embedded_model_array.py +++ b/tests/model_fields_/test_embedded_model_array.py @@ -5,6 +5,7 @@ from django.core.exceptions import FieldDoesNotExist from django.db import connection, models from django.db.models.expressions import Value +from django.db.models.functions import Concat from django.test import SimpleTestCase, TestCase from django.test.utils import CaptureQueriesContext, isolate_apps @@ -273,7 +274,9 @@ def test_nested_array_index_expr(self): with self.assertNumQueries(1) as ctx: self.assertCountEqual( Exhibit.objects.filter( - main_section__artifacts__restorations__0__restored_by="Zacarias" + main_section__artifacts__restorations__0__restored_by=Concat( + Value("Z"), Value("acarias") + ) ), [self.lost_empires], ) @@ -309,7 +312,12 @@ def test_nested_array_index_expr(self): "field": "restored_by", } }, - "Zacarias", + { + "$concat": [ + {"$ifNull": ["Z", ""]}, + {"$ifNull": ["acarias", ""]}, + ] + }, ] }, } @@ -335,46 +343,7 @@ def test_nested_array_index_path(self): self.assertAggregateQuery( query, "model_fields__exhibit", - [ - { - "$match": { - "$expr": { - "$anyElementTrue": { - "$ifNull": [ - { - "$map": { - "input": { - "$getField": { - "input": "$main_section", - "field": "artifacts", - } - }, - "as": "item", - "in": { - "$eq": [ - { - "$getField": { - "input": { - "$arrayElemAt": [ - "$$item.restorations", - 0, - ] - }, - "field": "restored_by", - } - }, - "Zacarias", - ] - }, - } - }, - [], - ] - } - } - } - } - ], + [{"$match": {"main_section.artifacts.restorations.0.restored_by": "Zacarias"}}], ) def test_array_slice(self): From b8baa4edfc6143466a444abcfacf88c1ea7a392b Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Fri, 3 Oct 2025 22:46:16 -0300 Subject: [PATCH 17/23] Clean ups. --- django_mongodb_backend/aggregates.py | 4 +-- django_mongodb_backend/compiler.py | 4 +-- .../expressions/builtins.py | 16 +++++------ django_mongodb_backend/functions.py | 27 +++++++------------ 4 files changed, 22 insertions(+), 29 deletions(-) diff --git a/django_mongodb_backend/aggregates.py b/django_mongodb_backend/aggregates.py index 2d1dd6afe..6959f8b5f 100644 --- a/django_mongodb_backend/aggregates.py +++ b/django_mongodb_backend/aggregates.py @@ -64,12 +64,12 @@ def count(self, compiler, connection, resolve_inner_expression=False, **extra_co return {"$add": [{"$size": lhs_mql}, exits_null]} -def stddev_variance(self, compiler, connection, **extra_context): # noqa: ARG001 +def stddev_variance(self, compiler, connection, **extra_context): if self.function.endswith("_SAMP"): operator = "stdDevSamp" elif self.function.endswith("_POP"): operator = "stdDevPop" - return aggregate(self, compiler, connection, operator=operator) + return aggregate(self, compiler, connection, operator=operator, **extra_context) def register_aggregates(): diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 66d715258..8d3a3f3ba 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -481,11 +481,11 @@ def build_query(self, columns=None): query.lookup_pipeline = self.get_lookup_pipeline() where = self.get_where() try: - expr = where.as_mql(self, self.connection, as_path=True) if where else {} + match = where.as_mql(self, self.connection, as_path=True) if where else {} except FullResultSet: query.match_mql = {} else: - query.match_mql = expr + query.match_mql = match if extra_fields: query.extra_fields = self.get_project_fields(extra_fields, force_expression=True) query.subqueries = self.subqueries diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index a5db3d70b..5ddb6869a 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -100,12 +100,12 @@ def combined_expression(self, compiler, connection): return connection.ops.combine_expression(self.connector, expressions) -def expression_wrapper_expr(self, compiler, connection): +def expression_wrapper(self, compiler, connection): return self.expression.as_mql(compiler, connection) -def negated_expression_expr(self, compiler, connection): - return {"$not": expression_wrapper_expr(self, compiler, connection)} +def negated_expression(self, compiler, connection): + return {"$not": expression_wrapper(self, compiler, connection)} def order_by(self, compiler, connection): @@ -201,8 +201,8 @@ def exists(self, compiler, connection, get_wrapping_pipeline=None): return connection.mongo_expr_operators["isnull"](lhs_mql, False) -def when(self, compiler, connection, as_path=False): - return self.condition.as_mql(compiler, connection, as_path=as_path) +def when(self, compiler, connection): + return self.condition.as_mql(compiler, connection) def value(self, compiler, connection, as_path=False): # noqa: ARG001 @@ -239,8 +239,8 @@ def register_expressions(): CombinedExpression.as_mql_expr = combined_expression Exists.as_mql_expr = exists ExpressionList.as_mql = process_lhs - ExpressionWrapper.as_mql_expr = expression_wrapper_expr - NegatedExpression.as_mql_expr = negated_expression_expr + ExpressionWrapper.as_mql_expr = expression_wrapper + NegatedExpression.as_mql_expr = negated_expression OrderBy.as_mql_expr = order_by Query.as_mql = query RawSQL.as_mql = raw_sql @@ -249,5 +249,5 @@ def register_expressions(): ResolvedOuterRef.as_mql = ResolvedOuterRef.as_sql Star.as_mql_expr = star Subquery.as_mql_expr = subquery - When.as_mql = when + When.as_mql_expr = when Value.as_mql = value diff --git a/django_mongodb_backend/functions.py b/django_mongodb_backend/functions.py index 417a427d8..72958b489 100644 --- a/django_mongodb_backend/functions.py +++ b/django_mongodb_backend/functions.py @@ -128,12 +128,12 @@ def func_expr(self, compiler, connection): def left(self, compiler, connection): - return self.get_substr().as_mql(compiler, connection, as_path=False) + return self.get_substr().as_mql(compiler, connection) def length(self, compiler, connection): # Check for null first since $strLenCP only accepts strings. - lhs_mql = process_lhs(self, compiler, connection, as_path=False) + lhs_mql = process_lhs(self, compiler, connection) return {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$strLenCP": lhs_mql}}} @@ -141,7 +141,7 @@ def log(self, compiler, connection): # This function is usually log(base, num) but on MongoDB it's log(num, base). clone = self.copy() clone.set_source_expressions(self.get_source_expressions()[::-1]) - return func(clone, compiler, connection, as_path=False) + return func(clone, compiler, connection) def now(self, compiler, connection): # noqa: ARG001 @@ -150,9 +150,7 @@ def now(self, compiler, connection): # noqa: ARG001 def null_if(self, compiler, connection): """Return None if expr1==expr2 else expr1.""" - expr1, expr2 = ( - expr.as_mql(compiler, connection, as_path=False) for expr in self.get_source_expressions() - ) + expr1, expr2 = (expr.as_mql(compiler, connection) for expr in self.get_source_expressions()) return {"$cond": {"if": {"$eq": [expr1, expr2]}, "then": None, "else": expr1}} @@ -160,7 +158,7 @@ def preserve_null(operator): # If the argument is null, the function should return null, not # $toLower/Upper's behavior of returning an empty string. def wrapped(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection, as_path=False) + lhs_mql = process_lhs(self, compiler, connection) return { "$cond": { "if": connection.mongo_expr_operators["isnull"](lhs_mql, True), @@ -173,23 +171,18 @@ def wrapped(self, compiler, connection): def replace(self, compiler, connection): - expression, text, replacement = process_lhs(self, compiler, connection, as_path=False) + expression, text, replacement = process_lhs(self, compiler, connection) return {"$replaceAll": {"input": expression, "find": text, "replacement": replacement}} def round_(self, compiler, connection): # Round needs its own function because it's a special case that inherits # from Transform but has two arguments. - return { - "$round": [ - expr.as_mql(compiler, connection, as_path=False) - for expr in self.get_source_expressions() - ] - } + return {"$round": [expr.as_mql(compiler, connection) for expr in self.get_source_expressions()]} def str_index(self, compiler, connection): - lhs = process_lhs(self, compiler, connection, as_path=False) + lhs = process_lhs(self, compiler, connection) # StrIndex should be 0-indexed (not found) but it's -1-indexed on MongoDB. return {"$add": [{"$indexOfCP": lhs}, 1]} @@ -253,7 +246,7 @@ def trunc_convert_value(self, value, expression, connection): def trunc_date(self, compiler, connection): # Cast to date rather than truncate to date. - lhs_mql = process_lhs(self, compiler, connection, as_path=False) + lhs_mql = process_lhs(self, compiler, connection) tzname = self.get_tzname() if tzname and tzname != "UTC": raise NotSupportedError(f"TruncDate with tzinfo ({tzname}) isn't supported on MongoDB.") @@ -276,7 +269,7 @@ def trunc_time(self, compiler, connection): tzname = self.get_tzname() if tzname and tzname != "UTC": raise NotSupportedError(f"TruncTime with tzinfo ({tzname}) isn't supported on MongoDB.") - lhs_mql = process_lhs(self, compiler, connection, as_path=False) + lhs_mql = process_lhs(self, compiler, connection) return { "$dateFromString": { "dateString": { From b0068d4da2a1a545e5991ff8280027550265582a Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Fri, 3 Oct 2025 23:41:00 -0300 Subject: [PATCH 18/23] More clean ups. --- django_mongodb_backend/fields/json.py | 7 ++---- django_mongodb_backend/functions.py | 35 ++++++++------------------- django_mongodb_backend/lookups.py | 28 ++++++++++----------- django_mongodb_backend/query.py | 4 +-- 4 files changed, 28 insertions(+), 46 deletions(-) diff --git a/django_mongodb_backend/fields/json.py b/django_mongodb_backend/fields/json.py index 139f138b4..047fc01c7 100644 --- a/django_mongodb_backend/fields/json.py +++ b/django_mongodb_backend/fields/json.py @@ -54,10 +54,7 @@ def data_contains(self, compiler, connection, as_path=False): # noqa: ARG001 def _has_key_predicate(path, root_column=None, negated=False, as_path=False): """Return MQL to check for the existence of `path`.""" if as_path: - # if not negated: return {path: {"$exists": not negated}} - # return {"$and": [{path: {"$exists": True}}, {path: {"$ne": None}}]} - # return {"$or": [{path: {"$exists": False}}, {path: None}]} result = { "$and": [ # The path must exist (i.e. not be "missing"). @@ -81,9 +78,9 @@ def has_key_check_simple_expression(self): def has_key_lookup(self, compiler, connection, as_path=False): """Return MQL to check for the existence of a key.""" rhs = self.rhs + lhs = process_lhs(self, compiler, connection, as_path=as_path) if not isinstance(rhs, (list, tuple)): rhs = [rhs] - lhs = process_lhs(self, compiler, connection, as_path=as_path) paths = [] # Transform any "raw" keys into KeyTransforms to allow consistent handling # in the code that follows. @@ -198,7 +195,7 @@ def key_transform_is_null_expr(self, compiler, connection): previous = previous.lhs root_column = previous.as_mql(compiler, connection) lhs_mql = process_lhs(self, compiler, connection, as_path=False) - rhs_mql = process_rhs(self, compiler, connection) + rhs_mql = process_rhs(self, compiler, connection, as_path=False) return _has_key_predicate(lhs_mql, root_column, negated=rhs_mql) diff --git a/django_mongodb_backend/functions.py b/django_mongodb_backend/functions.py index 72958b489..a01a28464 100644 --- a/django_mongodb_backend/functions.py +++ b/django_mongodb_backend/functions.py @@ -67,7 +67,7 @@ def cast(self, compiler, connection): output_type = connection.data_types[self.output_field.get_internal_type()] - lhs_mql = process_lhs(self, compiler, connection, as_path=False)[0] + lhs_mql = process_lhs(self, compiler, connection)[0] if max_length := self.output_field.max_length: lhs_mql = {"$substrCP": [lhs_mql, 0, max_length]} # Skip the conversion for "object" as it doesn't need to be transformed for @@ -81,22 +81,22 @@ def cast(self, compiler, connection): def concat(self, compiler, connection): - return self.get_source_expressions()[0].as_mql(compiler, connection, as_path=False) + return self.get_source_expressions()[0].as_mql(compiler, connection) def concat_pair(self, compiler, connection): # null on either side results in null for expression, wrap with coalesce. coalesced = self.coalesce() - return super(ConcatPair, coalesced).as_mql_expr(compiler, connection) + return super(ConcatPair, coalesced).as_mql(compiler, connection) def cot(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection, as_path=False) + lhs_mql = process_lhs(self, compiler, connection) return {"$divide": [1, {"$tan": lhs_mql}]} -def extract(self, compiler, connection, as_path=False): - lhs_mql = process_lhs(self, compiler, connection, as_path=as_path) +def extract(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection) operator = EXTRACT_OPERATORS.get(self.lookup_name) if operator is None: raise NotSupportedError(f"{self.__class__.__name__} is not supported.") @@ -105,22 +105,8 @@ def extract(self, compiler, connection, as_path=False): return {f"${operator}": lhs_mql} -def func(self, compiler, connection, as_path=False): - lhs_mql = process_lhs(self, compiler, connection, as_path=False) - if self.function is None: - raise NotSupportedError(f"{self} may need an as_mql() method.") - operator = MONGO_OPERATORS.get(self.__class__, self.function.lower()) - if as_path: - return {"$expr": {f"${operator}": lhs_mql}} - return {f"${operator}": lhs_mql} - - -def func_path(self, compiler, connection): # noqa: ARG001 - raise NotSupportedError(f"{self} may need an as_mql_path() method.") - - -def func_expr(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection, as_path=False) +def func(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection) if self.function is None: raise NotSupportedError(f"{self} may need an as_mql() method.") operator = MONGO_OPERATORS.get(self.__class__, self.function.lower()) @@ -291,8 +277,8 @@ def register_functions(): ConcatPair.as_mql_expr = concat_pair Cot.as_mql_expr = cot Extract.as_mql_expr = extract - Func.as_mql_path = func_path - Func.as_mql_expr = func_expr + Func.as_mql_expr = func + Func.can_use_path = False JSONArray.as_mql_expr = process_lhs Left.as_mql_expr = left Length.as_mql_expr = length @@ -312,4 +298,3 @@ def register_functions(): TruncDate.as_mql_expr = trunc_date TruncTime.as_mql_expr = trunc_time Upper.as_mql_expr = preserve_null("toUpper") - Func.can_use_path = False diff --git a/django_mongodb_backend/lookups.py b/django_mongodb_backend/lookups.py index e5154fad4..eaad34691 100644 --- a/django_mongodb_backend/lookups.py +++ b/django_mongodb_backend/lookups.py @@ -12,18 +12,18 @@ from .query_utils import is_constant_value, process_lhs, process_rhs -def builtin_lookup_path(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection, as_path=True) - value = process_rhs(self, compiler, connection, as_path=True) - return connection.mongo_match_operators[self.lookup_name](lhs_mql, value) - - def builtin_lookup_expr(self, compiler, connection): value = process_rhs(self, compiler, connection, as_path=False) lhs_mql = process_lhs(self, compiler, connection, as_path=False) return connection.mongo_expr_operators[self.lookup_name](lhs_mql, value) +def builtin_lookup_path(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + value = process_rhs(self, compiler, connection, as_path=True) + return connection.mongo_match_operators[self.lookup_name](lhs_mql, value) + + _field_resolve_expression_parameter = FieldGetDbPrepValueIterableMixin.resolve_expression_parameter @@ -85,18 +85,18 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr) ] -def is_null_path(self, compiler, connection): +def is_null_expr(self, compiler, connection): if not isinstance(self.rhs, bool): raise ValueError("The QuerySet value for an isnull lookup must be True or False.") - lhs_mql = process_lhs(self, compiler, connection, as_path=True) - return connection.mongo_match_operators["isnull"](lhs_mql, self.rhs) + lhs_mql = process_lhs(self, compiler, connection, as_path=False) + return connection.mongo_expr_operators["isnull"](lhs_mql, self.rhs) -def is_null_expr(self, compiler, connection): +def is_null_path(self, compiler, connection): if not isinstance(self.rhs, bool): raise ValueError("The QuerySet value for an isnull lookup must be True or False.") - lhs_mql = process_lhs(self, compiler, connection, as_path=False) - return connection.mongo_expr_operators["isnull"](lhs_mql, self.rhs) + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + return connection.mongo_match_operators["isnull"](lhs_mql, self.rhs) # from https://www.pcre.org/current/doc/html/pcre2pattern.html#SEC4 @@ -146,9 +146,8 @@ def can_use_path(self): def register_lookups(): - Lookup.can_use_path = can_use_path - BuiltinLookup.as_mql_path = builtin_lookup_path BuiltinLookup.as_mql_expr = builtin_lookup_expr + BuiltinLookup.as_mql_path = builtin_lookup_path FieldGetDbPrepValueIterableMixin.resolve_expression_parameter = ( field_resolve_expression_parameter ) @@ -157,6 +156,7 @@ def register_lookups(): In.get_subquery_wrapping_pipeline = get_subquery_wrapping_pipeline IsNull.as_mql_path = is_null_path IsNull.as_mql_expr = is_null_expr + Lookup.can_use_path = can_use_path PatternLookup.prep_lookup_value_mongo = pattern_lookup_prep_lookup_value # Patching the main method, it is not supported yet. UUIDTextMixin.as_mql = uuid_text_mixin diff --git a/django_mongodb_backend/query.py b/django_mongodb_backend/query.py index ef3312cf1..e0c08e6a4 100644 --- a/django_mongodb_backend/query.py +++ b/django_mongodb_backend/query.py @@ -336,12 +336,12 @@ def where_node(self, compiler, connection, as_path=False): return mql -def nothing_node(self, compiler, connection, as_path=None): # noqa: ARG001 +def nothing_node(self, compiler, connection): return self.as_sql(compiler, connection) def register_nodes(): ExtraWhere.as_mql = extra_where Join.as_mql = join - NothingNode.as_mql = nothing_node + NothingNode.as_mql_expr = nothing_node WhereNode.as_mql = where_node From eeda4c86cee4fde63cb0c59f6143de3104c8b169 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sat, 4 Oct 2025 00:09:59 -0300 Subject: [PATCH 19/23] Fix nothing node. --- django_mongodb_backend/lookups.py | 4 ++-- django_mongodb_backend/query.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/django_mongodb_backend/lookups.py b/django_mongodb_backend/lookups.py index eaad34691..56704fc63 100644 --- a/django_mongodb_backend/lookups.py +++ b/django_mongodb_backend/lookups.py @@ -151,11 +151,11 @@ def register_lookups(): FieldGetDbPrepValueIterableMixin.resolve_expression_parameter = ( field_resolve_expression_parameter ) - In.as_mql_path = RelatedIn.as_mql_path = wrap_in(builtin_lookup_path) In.as_mql_expr = RelatedIn.as_mql_expr = wrap_in(builtin_lookup_expr) + In.as_mql_path = RelatedIn.as_mql_path = wrap_in(builtin_lookup_path) In.get_subquery_wrapping_pipeline = get_subquery_wrapping_pipeline - IsNull.as_mql_path = is_null_path IsNull.as_mql_expr = is_null_expr + IsNull.as_mql_path = is_null_path Lookup.can_use_path = can_use_path PatternLookup.prep_lookup_value_mongo = pattern_lookup_prep_lookup_value # Patching the main method, it is not supported yet. diff --git a/django_mongodb_backend/query.py b/django_mongodb_backend/query.py index e0c08e6a4..ef3312cf1 100644 --- a/django_mongodb_backend/query.py +++ b/django_mongodb_backend/query.py @@ -336,12 +336,12 @@ def where_node(self, compiler, connection, as_path=False): return mql -def nothing_node(self, compiler, connection): +def nothing_node(self, compiler, connection, as_path=None): # noqa: ARG001 return self.as_sql(compiler, connection) def register_nodes(): ExtraWhere.as_mql = extra_where Join.as_mql = join - NothingNode.as_mql_expr = nothing_node + NothingNode.as_mql = nothing_node WhereNode.as_mql = where_node From e05d387866d490cbbfb7c7a734e0989ccd35f23c Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sat, 4 Oct 2025 01:14:44 -0300 Subject: [PATCH 20/23] Fix recursive call. --- django_mongodb_backend/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/django_mongodb_backend/functions.py b/django_mongodb_backend/functions.py index a01a28464..850e4f4e0 100644 --- a/django_mongodb_backend/functions.py +++ b/django_mongodb_backend/functions.py @@ -87,7 +87,7 @@ def concat(self, compiler, connection): def concat_pair(self, compiler, connection): # null on either side results in null for expression, wrap with coalesce. coalesced = self.coalesce() - return super(ConcatPair, coalesced).as_mql(compiler, connection) + return super(ConcatPair, coalesced).as_mql_expr(compiler, connection) def cot(self, compiler, connection): From c5e0e9ec80e2dd3d3304727106d3e2c83e8ec7b4 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sat, 4 Oct 2025 12:34:37 -0300 Subject: [PATCH 21/23] Remove unused function. --- django_mongodb_backend/fields/json.py | 21 +-------------------- 1 file changed, 1 insertion(+), 20 deletions(-) diff --git a/django_mongodb_backend/fields/json.py b/django_mongodb_backend/fields/json.py index 047fc01c7..43fc502e3 100644 --- a/django_mongodb_backend/fields/json.py +++ b/django_mongodb_backend/fields/json.py @@ -140,30 +140,11 @@ def key_transform_exact_path(self, compiler, connection): } -def key_transform_in(self, compiler, connection, as_path=False): +def key_transform_in_expr(self, compiler, connection): """ Return MQL to check if a JSON path exists and that its values are in the set of specified values (rhs). """ - if as_path and self.can_use_path(): - return builtin_lookup_path(self, compiler, connection) - - lhs_mql = process_lhs(self, compiler, connection) - # Traverse to the root column. - previous = self.lhs - while isinstance(previous, KeyTransform): - previous = previous.lhs - root_column = previous.as_mql(compiler, connection) - value = process_rhs(self, compiler, connection) - # Construct the expression to check if lhs_mql values are in rhs values. - expr = connection.mongo_expr_operators[self.lookup_name](lhs_mql, value) - expr = {"$and": [_has_key_predicate(lhs_mql, root_column), expr]} - if as_path: - return {"$expr": expr} - return expr - - -def key_transform_in_expr(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection) # Traverse to the root column. previous = self.lhs From b194d316c3270777952a1732fb091fc8b1e117df Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sat, 4 Oct 2025 13:34:36 -0300 Subject: [PATCH 22/23] revert unneded as_path=False in project fields. --- django_mongodb_backend/compiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 8d3a3f3ba..e5973e514 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -714,9 +714,9 @@ def get_project_fields(self, columns=None, ordering=None, force_expression=False value = ( False if empty_result_set_value is NotImplemented else empty_result_set_value ) - fields[collection][name] = Value(value).as_mql(self, self.connection, as_path=False) + fields[collection][name] = Value(value).as_mql(self, self.connection) except FullResultSet: - fields[collection][name] = Value(True).as_mql(self, self.connection, as_path=False) + fields[collection][name] = Value(True).as_mql(self, self.connection) # Annotations (stored in None) and the main collection's fields # should appear in the top-level of the fields dict. fields.update(fields.pop(None, {})) From 8c118e68756a433180284609e1488b6a57a0e920 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sat, 4 Oct 2025 16:26:02 -0300 Subject: [PATCH 23/23] Simplify handlers. --- django_mongodb_backend/fields/json.py | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/django_mongodb_backend/fields/json.py b/django_mongodb_backend/fields/json.py index 43fc502e3..9247fead2 100644 --- a/django_mongodb_backend/fields/json.py +++ b/django_mongodb_backend/fields/json.py @@ -126,10 +126,6 @@ def key_transform(self, compiler, connection, as_path=False): return build_json_mql_path(lhs_mql, key_transforms, as_path=as_path) -def key_transform_exact_expr(self, compiler, connection): - return builtin_lookup_expr(self, compiler, connection) - - def key_transform_exact_path(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection, as_path=True) return { @@ -157,10 +153,6 @@ def key_transform_in_expr(self, compiler, connection): return {"$and": [_has_key_predicate(lhs_mql, root_column), expr]} -def key_transform_in_path(self, compiler, connection): - return builtin_lookup_path(self, compiler, connection) - - def key_transform_is_null_expr(self, compiler, connection): """ Return MQL to check the nullability of a key. @@ -201,10 +193,6 @@ def key_transform_numeric_lookup_mixin_expr(self, compiler, connection): return {"$and": [expr, not_missing_or_null]} -def key_transform_numeric_lookup_mixin_path(self, compiler, connection): - return builtin_lookup_path(self, compiler, connection) - - @property def keytransform_is_simple_column(self): previous = self @@ -220,20 +208,17 @@ def register_json_field(): DataContains.as_mql = data_contains HasAnyKeys.mongo_operator = "$or" HasKey.mongo_operator = None - HasKeyLookup.as_mql_path = partialmethod(has_key_lookup, as_path=True) HasKeyLookup.as_mql_expr = partialmethod(has_key_lookup, as_path=False) + HasKeyLookup.as_mql_path = partialmethod(has_key_lookup, as_path=True) HasKeyLookup.can_use_path = has_key_check_simple_expression HasKeys.mongo_operator = "$and" JSONExact.process_rhs = json_exact_process_rhs - KeyTransform.is_simple_column = keytransform_is_simple_column - KeyTransform.can_use_path = keytransform_is_simple_column - KeyTransform.as_mql_path = partialmethod(key_transform, as_path=True) KeyTransform.as_mql_expr = partialmethod(key_transform, as_path=False) - KeyTransformExact.as_mql_expr = key_transform_exact_expr + KeyTransform.as_mql_path = partialmethod(key_transform, as_path=True) + KeyTransform.can_use_path = keytransform_is_simple_column + KeyTransform.is_simple_column = keytransform_is_simple_column KeyTransformExact.as_mql_path = key_transform_exact_path - KeyTransformIn.as_mql_path = key_transform_in_path KeyTransformIn.as_mql_expr = key_transform_in_expr - KeyTransformIsNull.as_mql_path = key_transform_is_null_path KeyTransformIsNull.as_mql_expr = key_transform_is_null_expr - KeyTransformNumericLookupMixin.as_mql_path = key_transform_numeric_lookup_mixin_path + KeyTransformIsNull.as_mql_path = key_transform_is_null_path KeyTransformNumericLookupMixin.as_mql_expr = key_transform_numeric_lookup_mixin_expr