diff --git a/django_mongodb_backend/aggregates.py b/django_mongodb_backend/aggregates.py index 31f4b29ba..2d1dd6afe 100644 --- a/django_mongodb_backend/aggregates.py +++ b/django_mongodb_backend/aggregates.py @@ -64,16 +64,16 @@ def count(self, compiler, connection, resolve_inner_expression=False, **extra_co return {"$add": [{"$size": lhs_mql}, exits_null]} -def stddev_variance(self, compiler, connection, **extra_context): +def stddev_variance(self, compiler, connection, **extra_context): # noqa: ARG001 if self.function.endswith("_SAMP"): operator = "stdDevSamp" elif self.function.endswith("_POP"): operator = "stdDevPop" - return aggregate(self, compiler, connection, operator=operator, **extra_context) + return aggregate(self, compiler, connection, operator=operator) def register_aggregates(): - Aggregate.as_mql = aggregate - Count.as_mql = count - StdDev.as_mql = stddev_variance - Variance.as_mql = stddev_variance + Aggregate.as_mql_expr = aggregate + Count.as_mql_expr = count + StdDev.as_mql_expr = stddev_variance + Variance.as_mql_expr = stddev_variance diff --git a/django_mongodb_backend/base.py b/django_mongodb_backend/base.py index f751c27fa..c8c15ebc1 100644 --- a/django_mongodb_backend/base.py +++ b/django_mongodb_backend/base.py @@ -2,7 +2,8 @@ import logging import os -from django.core.exceptions import ImproperlyConfigured +from bson import Decimal128 +from django.core.exceptions import EmptyResultSet, FullResultSet, ImproperlyConfigured from django.db import DEFAULT_DB_ALIAS from django.db.backends.base.base import BaseDatabaseWrapper from django.db.backends.utils import debug_transaction @@ -20,7 +21,7 @@ from .features import DatabaseFeatures from .introspection import DatabaseIntrospection from .operations import DatabaseOperations -from .query_utils import regex_match +from .query_utils import regex_expr, regex_match from .schema import DatabaseSchemaEditor from .utils import OperationDebugWrapper from .validation import DatabaseValidation @@ -108,7 +109,12 @@ def _isnull_operator(a, b): } return is_null if b else {"$not": is_null} - mongo_operators = { + def _isnull_operator_match(a, b): + if b: + return {"$or": [{a: {"$exists": False}}, {a: None}]} + return {"$and": [{a: {"$exists": True}}, {a: {"$ne": None}}]} + + mongo_expr_operators = { "exact": lambda a, b: {"$eq": [a, b]}, "gt": lambda a, b: {"$gt": [a, b]}, "gte": lambda a, b: {"$gte": [a, b]}, @@ -118,7 +124,7 @@ def _isnull_operator(a, b): "lte": lambda a, b: { "$and": [{"$lte": [a, b]}, DatabaseWrapper._isnull_operator(a, False)] }, - "in": lambda a, b: {"$in": [a, b]}, + "in": lambda a, b: {"$in": (a, b)}, "isnull": _isnull_operator, "range": lambda a, b: { "$and": [ @@ -126,11 +132,56 @@ def _isnull_operator(a, b): {"$or": [DatabaseWrapper._isnull_operator(b[1], True), {"$lte": [a, b[1]]}]}, ] }, - "iexact": lambda a, b: regex_match(a, ("^", b, {"$literal": "$"}), insensitive=True), - "startswith": lambda a, b: regex_match(a, ("^", b)), - "istartswith": lambda a, b: regex_match(a, ("^", b), insensitive=True), - "endswith": lambda a, b: regex_match(a, (b, {"$literal": "$"})), - "iendswith": lambda a, b: regex_match(a, (b, {"$literal": "$"}), insensitive=True), + "iexact": lambda a, b: regex_expr(a, ("^", b, {"$literal": "$"}), insensitive=True), + "startswith": lambda a, b: regex_expr(a, ("^", b)), + "istartswith": lambda a, b: regex_expr(a, ("^", b), insensitive=True), + "endswith": lambda a, b: regex_expr(a, (b, {"$literal": "$"})), + "iendswith": lambda a, b: regex_expr(a, (b, {"$literal": "$"}), insensitive=True), + "contains": lambda a, b: regex_expr(a, b), + "icontains": lambda a, b: regex_expr(a, b, insensitive=True), + "regex": lambda a, b: regex_expr(a, b), + "iregex": lambda a, b: regex_expr(a, b, insensitive=True), + } + + def range_match(a, b): + conditions = [] + start, end = b + if start is not None: + conditions.append({a: {"$gte": b[0]}}) + if end is not None: + conditions.append({a: {"$lte": b[1]}}) + if start is not None and end is not None: + if isinstance(start, Decimal128): + start = start.to_decimal() + if isinstance(end, Decimal128): + end = end.to_decimal() + if start > end: + raise EmptyResultSet + if not conditions: + raise FullResultSet + return {"$and": conditions} + + # match, path, find? don't know which name use. + mongo_match_operators = { + "exact": lambda a, b: {a: b}, + "gt": lambda a, b: {a: {"$gt": b}}, + "gte": lambda a, b: {a: {"$gte": b}}, + # MongoDB considers null less than zero. Exclude null values to match + # SQL behavior. + "lt": lambda a, b: { + "$and": [{a: {"$lt": b}}, DatabaseWrapper._isnull_operator_match(a, False)] + }, + "lte": lambda a, b: { + "$and": [{a: {"$lte": b}}, DatabaseWrapper._isnull_operator_match(a, False)] + }, + "in": lambda a, b: {a: {"$in": tuple(b)}}, + "isnull": _isnull_operator_match, + "range": range_match, + "iexact": lambda a, b: regex_match(a, f"^{b}$", insensitive=True), + "startswith": lambda a, b: regex_match(a, f"^{b}"), + "istartswith": lambda a, b: regex_match(a, f"^{b}", insensitive=True), + "endswith": lambda a, b: regex_match(a, f"{b}$"), + "iendswith": lambda a, b: regex_match(a, f"{b}$", insensitive=True), "contains": lambda a, b: regex_match(a, b), "icontains": lambda a, b: regex_match(a, b, insensitive=True), "regex": lambda a, b: regex_match(a, b), diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 628a91e84..66d715258 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -327,14 +327,14 @@ def pre_sql_setup(self, with_col_aliases=False): pipeline = self._build_aggregation_pipeline(ids, group) if self.having: having = self.having.replace_expressions(all_replacements).as_mql( - self, self.connection + self, self.connection, as_path=True ) # Add HAVING subqueries. for query in self.subqueries or (): pipeline.extend(query.get_pipeline()) # Remove the added subqueries. self.subqueries = [] - pipeline.append({"$match": {"$expr": having}}) + pipeline.append({"$match": having}) self.aggregation_pipeline = pipeline self.annotations = { target: expr.replace_expressions(all_replacements) @@ -481,11 +481,11 @@ def build_query(self, columns=None): query.lookup_pipeline = self.get_lookup_pipeline() where = self.get_where() try: - expr = where.as_mql(self, self.connection) if where else {} + expr = where.as_mql(self, self.connection, as_path=True) if where else {} except FullResultSet: query.match_mql = {} else: - query.match_mql = {"$expr": expr} + query.match_mql = expr if extra_fields: query.extra_fields = self.get_project_fields(extra_fields, force_expression=True) query.subqueries = self.subqueries @@ -714,9 +714,9 @@ def get_project_fields(self, columns=None, ordering=None, force_expression=False value = ( False if empty_result_set_value is NotImplemented else empty_result_set_value ) - fields[collection][name] = Value(value).as_mql(self, self.connection) + fields[collection][name] = Value(value).as_mql(self, self.connection, as_path=False) except FullResultSet: - fields[collection][name] = Value(True).as_mql(self, self.connection) + fields[collection][name] = Value(True).as_mql(self, self.connection, as_path=False) # Annotations (stored in None) and the main collection's fields # should appear in the top-level of the fields dict. fields.update(fields.pop(None, {})) diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index 0bc939350..62ee5a67d 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -6,6 +6,7 @@ from django.core.exceptions import EmptyResultSet, FullResultSet from django.db import NotSupportedError from django.db.models.expressions import ( + BaseExpression, Case, Col, ColPairs, @@ -28,12 +29,20 @@ from ..query_utils import process_lhs +def base_expression(self, compiler, connection, as_path=False, **extra): + if as_path and hasattr(self, "as_mql_path") and getattr(self, "can_use_path", False): + return self.as_mql_path(compiler, connection, **extra) + + expr = self.as_mql_expr(compiler, connection, **extra) + return {"$expr": expr} if as_path else expr + + def case(self, compiler, connection): case_parts = [] for case in self.cases: case_mql = {} try: - case_mql["case"] = case.as_mql(compiler, connection) + case_mql["case"] = case.as_mql(compiler, connection, as_path=False) except EmptyResultSet: continue except FullResultSet: @@ -76,34 +85,34 @@ def col(self, compiler, connection, as_path=False): # noqa: ARG001 return f"{prefix}{self.target.column}" -def col_pairs(self, compiler, connection): +def col_pairs(self, compiler, connection, as_path=False): cols = self.get_cols() if len(cols) > 1: raise NotSupportedError("ColPairs is not supported.") - return cols[0].as_mql(compiler, connection) + return cols[0].as_mql(compiler, connection, as_path=as_path) -def combined_expression(self, compiler, connection): +def combined_expression(self, compiler, connection, as_path=False): expressions = [ - self.lhs.as_mql(compiler, connection), - self.rhs.as_mql(compiler, connection), + self.lhs.as_mql(compiler, connection, as_path=as_path), + self.rhs.as_mql(compiler, connection, as_path=as_path), ] return connection.ops.combine_expression(self.connector, expressions) -def expression_wrapper(self, compiler, connection): - return self.expression.as_mql(compiler, connection) +def expression_wrapper_expr(self, compiler, connection): + return self.expression.as_mql(compiler, connection, as_path=False) -def negated_expression(self, compiler, connection): - return {"$not": expression_wrapper(self, compiler, connection)} +def negated_expression_expr(self, compiler, connection): + return {"$not": expression_wrapper_expr(self, compiler, connection)} def order_by(self, compiler, connection): return self.expression.as_mql(compiler, connection) -def query(self, compiler, connection, get_wrapping_pipeline=None): +def query(self, compiler, connection, get_wrapping_pipeline=None, as_path=False): subquery_compiler = self.get_compiler(connection=connection) subquery_compiler.pre_sql_setup(with_col_aliases=False) field_name, expr = subquery_compiler.columns[0] @@ -145,6 +154,8 @@ def query(self, compiler, connection, get_wrapping_pipeline=None): # Erase project_fields since the required value is projected above. subquery.project_fields = None compiler.subqueries.append(subquery) + if as_path: + return f"{table_output}.{field_name}" return f"${table_output}.{field_name}" @@ -152,7 +163,7 @@ def raw_sql(self, compiler, connection): # noqa: ARG001 raise NotSupportedError("RawSQL is not supported on MongoDB.") -def ref(self, compiler, connection): # noqa: ARG001 +def ref(self, compiler, connection, as_path=False): # noqa: ARG001 prefix = ( f"{self.source.alias}." if isinstance(self.source, Col) and self.source.alias != compiler.collection_name @@ -162,15 +173,24 @@ def ref(self, compiler, connection): # noqa: ARG001 refs, _ = compiler.columns[self.ordinal - 1] else: refs = self.refs - return f"${prefix}{refs}" + if not as_path: + prefix = f"${prefix}" + return f"{prefix}{refs}" + +@property +def ref_is_simple_column(self): + return self.source.is_simple_column -def star(self, compiler, connection): # noqa: ARG001 + +def star(self, compiler, connection, as_path=False): # noqa: ARG001 return {"$literal": True} def subquery(self, compiler, connection, get_wrapping_pipeline=None): - return self.query.as_mql(compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline) + return self.query.as_mql( + compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline, as_path=False + ) def exists(self, compiler, connection, get_wrapping_pipeline=None): @@ -178,16 +198,16 @@ def exists(self, compiler, connection, get_wrapping_pipeline=None): lhs_mql = subquery(self, compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline) except EmptyResultSet: return Value(False).as_mql(compiler, connection) - return connection.mongo_operators["isnull"](lhs_mql, False) + return connection.mongo_expr_operators["isnull"](lhs_mql, False) -def when(self, compiler, connection): - return self.condition.as_mql(compiler, connection) +def when(self, compiler, connection, as_path=False): + return self.condition.as_mql(compiler, connection, as_path=as_path) -def value(self, compiler, connection): # noqa: ARG001 +def value(self, compiler, connection, as_path=False): # noqa: ARG001 value = self.value - if isinstance(value, (list, int)): + if isinstance(value, (list, int)) and not as_path: # Wrap lists & numbers in $literal to prevent ambiguity when Value # appears in $project. return {"$literal": value} @@ -210,20 +230,24 @@ def value(self, compiler, connection): # noqa: ARG001 def register_expressions(): - Case.as_mql = case + BaseExpression.as_mql = base_expression + BaseExpression.is_simple_column = False + Case.as_mql_expr = case Col.as_mql = col + Col.is_simple_column = True ColPairs.as_mql = col_pairs - CombinedExpression.as_mql = combined_expression - Exists.as_mql = exists + CombinedExpression.as_mql_expr = combined_expression + Exists.as_mql_expr = exists ExpressionList.as_mql = process_lhs - ExpressionWrapper.as_mql = expression_wrapper - NegatedExpression.as_mql = negated_expression - OrderBy.as_mql = order_by + ExpressionWrapper.as_mql_expr = expression_wrapper_expr + NegatedExpression.as_mql_expr = negated_expression_expr + OrderBy.as_mql_expr = order_by Query.as_mql = query RawSQL.as_mql = raw_sql Ref.as_mql = ref + Ref.is_simple_column = ref_is_simple_column ResolvedOuterRef.as_mql = ResolvedOuterRef.as_sql Star.as_mql = star - Subquery.as_mql = subquery + Subquery.as_mql_expr = subquery When.as_mql = when Value.as_mql = value diff --git a/django_mongodb_backend/expressions/search.py b/django_mongodb_backend/expressions/search.py index 3783f5943..aba5e0cfa 100644 --- a/django_mongodb_backend/expressions/search.py +++ b/django_mongodb_backend/expressions/search.py @@ -933,11 +933,16 @@ def __str__(self): def __repr__(self): return f"SearchText({self.lhs}, {self.rhs})" - def as_mql(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) - value = process_rhs(self, compiler, connection) + def as_mql_expr(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_path=False) + value = process_rhs(self, compiler, connection, as_path=False) return {"$gte": [lhs_mql, value]} + def as_mql_path(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + value = process_rhs(self, compiler, connection, as_path=True) + return {lhs_mql: {"$gte": value}} + CharField.register_lookup(SearchTextLookup) TextField.register_lookup(SearchTextLookup) diff --git a/django_mongodb_backend/features.py b/django_mongodb_backend/features.py index 7e29c5003..18a048bf6 100644 --- a/django_mongodb_backend/features.py +++ b/django_mongodb_backend/features.py @@ -90,9 +90,6 @@ class DatabaseFeatures(GISFeatures, BaseDatabaseFeatures): "auth_tests.test_views.LoginTest.test_login_session_without_hash_session_key", # GenericRelation.value_to_string() assumes integer pk. "contenttypes_tests.test_fields.GenericRelationTests.test_value_to_string", - # icontains doesn't work on ArrayField: - # Unsupported conversion from array to string in $convert - "model_fields_.test_arrayfield.QueryingTests.test_icontains", # ArrayField's contained_by lookup crashes with Exists: "both operands " # of $setIsSubset must be arrays. Second argument is of type: null" # https://jira.mongodb.org/browse/SERVER-99186 diff --git a/django_mongodb_backend/fields/array.py b/django_mongodb_backend/fields/array.py index da64ee2e8..da83d7a66 100644 --- a/django_mongodb_backend/fields/array.py +++ b/django_mongodb_backend/fields/array.py @@ -230,8 +230,11 @@ def formfield(self, **kwargs): class Array(Func): - def as_mql(self, compiler, connection): - return [expr.as_mql(compiler, connection) for expr in self.get_source_expressions()] + def as_mql_expr(self, compiler, connection): + return [ + expr.as_mql(compiler, connection, as_path=False) + for expr in self.get_source_expressions() + ] class ArrayRHSMixin: @@ -251,9 +254,9 @@ def __init__(self, lhs, rhs): class ArrayContains(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup): lookup_name = "contains" - def as_mql(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) - value = process_rhs(self, compiler, connection) + def as_mql_expr(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_path=False) + value = process_rhs(self, compiler, connection, as_path=False) return { "$and": [ {"$ne": [lhs_mql, None]}, @@ -262,14 +265,21 @@ def as_mql(self, compiler, connection): ] } + def as_mql_path(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + value = process_rhs(self, compiler, connection, as_path=True) + if value is None: + return False + return {lhs_mql: {"$all": value}} + @ArrayField.register_lookup class ArrayContainedBy(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup): lookup_name = "contained_by" - def as_mql(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) - value = process_rhs(self, compiler, connection) + def as_mql_expr(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_path=False) + value = process_rhs(self, compiler, connection, as_path=False) return { "$and": [ {"$ne": [lhs_mql, None]}, @@ -323,21 +333,29 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr) }, ] - def as_mql(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) - value = process_rhs(self, compiler, connection) + def as_mql_expr(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_path=False) + value = process_rhs(self, compiler, connection, as_path=False) return { - "$and": [{"$ne": [lhs_mql, None]}, {"$size": {"$setIntersection": [value, lhs_mql]}}] + "$and": [ + {"$ne": [lhs_mql, None]}, + {"$size": {"$setIntersection": [value, lhs_mql]}}, + ] } + def as_mql_path(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + value = process_rhs(self, compiler, connection, as_path=True) + return {lhs_mql: {"$in": value}} + @ArrayField.register_lookup class ArrayLenTransform(Transform): lookup_name = "len" output_field = IntegerField() - def as_mql(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) + def as_mql_expr(self, compiler, connection, as_path=False): + lhs_mql = process_lhs(self, compiler, connection, as_path=False) return {"$cond": {"if": {"$isArray": lhs_mql}, "then": {"$size": lhs_mql}, "else": None}} @@ -363,10 +381,22 @@ def __init__(self, index, base_field, *args, **kwargs): self.index = index self.base_field = base_field - def as_mql(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) + @property + def can_use_path(self): + return self.is_simple_column + + @property + def is_simple_column(self): + return self.lhs.is_simple_column + + def as_mql_expr(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_path=False) return {"$arrayElemAt": [lhs_mql, self.index]} + def as_mql_path(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + return f"{lhs_mql}.{self.index}" + @property def output_field(self): return self.base_field @@ -387,7 +417,7 @@ def __init__(self, start, end, *args, **kwargs): self.start = start self.end = end - def as_mql(self, compiler, connection): + def as_mql_expr(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection) return {"$slice": [lhs_mql, self.start, self.end]} diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py index 951632363..af769468c 100644 --- a/django_mongodb_backend/fields/embedded_model.py +++ b/django_mongodb_backend/fields/embedded_model.py @@ -5,6 +5,7 @@ from django.db import models from django.db.models.fields.related import lazy_related_operation from django.db.models.lookups import Transform +from django.utils.functional import cached_property from .. import forms @@ -165,6 +166,19 @@ def __init__(self, key_name, ref_field, *args, **kwargs): def get_lookup(self, name): return self.ref_field.get_lookup(name) + @property + def can_use_path(self): + return self.is_simple_column + + @cached_property + def is_simple_column(self): + previous = self + while isinstance(previous, KeyTransform): + if not previous.key_name.isalnum(): + return False + previous = previous.lhs + return previous.is_simple_column + def get_transform(self, name): """ Validate that `name` is either a field of an embedded model or a @@ -184,21 +198,27 @@ def get_transform(self, name): f"{suggestion}" ) - def as_mql(self, compiler, connection, as_path=False): + def _get_target_path(self): previous = self key_transforms = [] while isinstance(previous, KeyTransform): key_transforms.insert(0, previous.key_name) previous = previous.lhs - if as_path: - mql = previous.as_mql(compiler, connection, as_path=True) - mql_path = ".".join(key_transforms) - return f"{mql}.{mql_path}" - mql = previous.as_mql(compiler, connection) + return key_transforms, previous + + def as_mql_expr(self, compiler, connection): + key_transforms, parent_field = self._get_target_path() + mql = parent_field.as_mql(compiler, connection) for key in key_transforms: mql = {"$getField": {"input": mql, "field": key}} return mql + def as_mql_path(self, compiler, connection): + key_transforms, parent_field = self._get_target_path() + mql = parent_field.as_mql(compiler, connection, as_path=True) + mql_path = ".".join(key_transforms) + return f"{mql}.{mql_path}" + @property def output_field(self): return self.ref_field diff --git a/django_mongodb_backend/fields/embedded_model_array.py b/django_mongodb_backend/fields/embedded_model_array.py index d04b99db1..86b118f24 100644 --- a/django_mongodb_backend/fields/embedded_model_array.py +++ b/django_mongodb_backend/fields/embedded_model_array.py @@ -5,8 +5,10 @@ from django.db.models.expressions import Col from django.db.models.fields.related import lazy_related_operation from django.db.models.lookups import Lookup, Transform +from django.utils.functional import cached_property from .. import forms +from ..lookups import builtin_lookup_path from ..query_utils import process_lhs, process_rhs from . import EmbeddedModelField from .array import ArrayField, ArrayLenTransform @@ -75,7 +77,7 @@ def _get_lookup(self, lookup_name): return lookup class EmbeddedModelArrayFieldLookups(Lookup): - def as_mql(self, compiler, connection): + def as_mql(self, compiler, connection, as_path=False): raise ValueError( "Lookups aren't supported on EmbeddedModelArrayField. " "Try querying one of its embedded fields instead." @@ -114,7 +116,7 @@ def get_lookup(self, name): class EmbeddedModelArrayFieldBuiltinLookup(Lookup): - def process_rhs(self, compiler, connection): + def process_rhs(self, compiler, connection, as_path=False): value = self.rhs if not self.get_db_prep_lookup_value_is_iterable: value = [value] @@ -128,18 +130,21 @@ def process_rhs(self, compiler, connection): for v in value ] - def as_mql(self, compiler, connection): + def as_mql_expr(self, compiler, connection): # Querying a subfield within the array elements (via nested # KeyTransform). Replicate MongoDB's implicit ANY-match by mapping over # the array and applying $in on the subfield. lhs_mql = process_lhs(self, compiler, connection) inner_lhs_mql = lhs_mql["$ifNull"][0]["$map"]["in"] values = process_rhs(self, compiler, connection) - lhs_mql["$ifNull"][0]["$map"]["in"] = connection.mongo_operators[self.lookup_name]( + lhs_mql["$ifNull"][0]["$map"]["in"] = connection.mongo_expr_operators[self.lookup_name]( inner_lhs_mql, values ) return {"$anyElementTrue": lhs_mql} + def as_mql_path(self, compiler, connection): + return builtin_lookup_path(self, compiler, connection) + @_EmbeddedModelArrayOutputField.register_lookup class EmbeddedModelArrayFieldIn(EmbeddedModelArrayFieldBuiltinLookup, lookups.In): @@ -243,6 +248,19 @@ def __call__(self, this, *args, **kwargs): self._lhs = self._sub_transform(self._lhs, *args, **kwargs) return self + @property + def can_use_path(self): + return self.is_simple_column + + @cached_property + def is_simple_column(self): + previous = self + while isinstance(previous, KeyTransform): + if not previous.key_name.isalnum(): + return False + previous = previous.lhs + return previous.is_simple_column and self._lhs.is_simple_column + def get_lookup(self, name): return self.output_field.get_lookup(name) @@ -275,7 +293,7 @@ def get_transform(self, name): f"{suggestion}" ) - def as_mql(self, compiler, connection): + def as_mql_expr(self, compiler, connection): inner_lhs_mql = self._lhs.as_mql(compiler, connection) lhs_mql = process_lhs(self, compiler, connection) return { @@ -291,6 +309,11 @@ def as_mql(self, compiler, connection): ] } + def as_mql_path(self, compiler, connection): + inner_lhs_mql = self._lhs.as_mql(compiler, connection, as_path=True).removeprefix("$item.") + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + return f"{lhs_mql}.{inner_lhs_mql}" + @property def output_field(self): return _EmbeddedModelArrayOutputField(self._lhs.output_field) diff --git a/django_mongodb_backend/fields/json.py b/django_mongodb_backend/fields/json.py index 1a7ecb615..bb979875a 100644 --- a/django_mongodb_backend/fields/json.py +++ b/django_mongodb_backend/fields/json.py @@ -1,3 +1,6 @@ +from functools import partialmethod +from itertools import chain + from django.db import NotSupportedError from django.db.models.fields.json import ( ContainedBy, @@ -8,17 +11,20 @@ HasKeys, JSONExact, KeyTransform, + KeyTransformExact, KeyTransformIn, KeyTransformIsNull, KeyTransformNumericLookupMixin, ) -from ..lookups import builtin_lookup +from ..lookups import builtin_lookup_expr, builtin_lookup_path from ..query_utils import process_lhs, process_rhs -def build_json_mql_path(lhs, key_transforms): +def build_json_mql_path(lhs, key_transforms, as_path=False): # Build the MQL path using the collected key transforms. + if as_path: + return ".".join(chain([lhs], key_transforms)) result = lhs for key in key_transforms: get_field = {"$getField": {"input": result, "field": key}} @@ -37,16 +43,21 @@ def build_json_mql_path(lhs, key_transforms): return result -def contained_by(self, compiler, connection): # noqa: ARG001 +def contained_by(self, compiler, connection, as_path=False): # noqa: ARG001 raise NotSupportedError("contained_by lookup is not supported on this database backend.") -def data_contains(self, compiler, connection): # noqa: ARG001 +def data_contains(self, compiler, connection, as_path=False): # noqa: ARG001 raise NotSupportedError("contains lookup is not supported on this database backend.") -def _has_key_predicate(path, root_column, negated=False): +def _has_key_predicate(path, root_column=None, negated=False, as_path=False): """Return MQL to check for the existence of `path`.""" + if as_path: + # if not negated: + return {path: {"$exists": not negated}} + # return {"$and": [{path: {"$exists": True}}, {path: {"$ne": None}}]} + # return {"$or": [{path: {"$exists": False}}, {path: None}]} result = { "$and": [ # The path must exist (i.e. not be "missing"). @@ -61,21 +72,27 @@ def _has_key_predicate(path, root_column, negated=False): return result -def has_key_lookup(self, compiler, connection): +@property +def has_key_check_simple_expression(self): + rhs = [self.rhs] if not isinstance(self.rhs, (list, tuple)) else self.rhs + return self.is_simple_column and all(key.isalnum() for key in rhs) + + +def has_key_lookup(self, compiler, connection, as_path=False): """Return MQL to check for the existence of a key.""" rhs = self.rhs - lhs = process_lhs(self, compiler, connection) if not isinstance(rhs, (list, tuple)): rhs = [rhs] + lhs = process_lhs(self, compiler, connection, as_path=as_path) paths = [] # Transform any "raw" keys into KeyTransforms to allow consistent handling # in the code that follows. for key in rhs: rhs_json_path = key if isinstance(key, KeyTransform) else KeyTransform(key, self.lhs) - paths.append(rhs_json_path.as_mql(compiler, connection)) + paths.append(rhs_json_path.as_mql(compiler, connection, as_path=as_path)) keys = [] for path in paths: - keys.append(_has_key_predicate(path, lhs)) + keys.append(_has_key_predicate(path, lhs, as_path=as_path)) if self.mongo_operator is None: return keys[0] return {self.mongo_operator: keys} @@ -93,7 +110,7 @@ def json_exact_process_rhs(self, compiler, connection): ) -def key_transform(self, compiler, connection): +def key_transform(self, compiler, connection, as_path=False): """ Return MQL for this KeyTransform (JSON path). @@ -108,15 +125,48 @@ def key_transform(self, compiler, connection): while isinstance(previous, KeyTransform): key_transforms.insert(0, previous.key_name) previous = previous.lhs - lhs_mql = previous.as_mql(compiler, connection) - return build_json_mql_path(lhs_mql, key_transforms) + lhs_mql = previous.as_mql(compiler, connection, as_path=as_path) + return build_json_mql_path(lhs_mql, key_transforms, as_path=as_path) + + +def key_transform_exact_expr(self, compiler, connection): + return builtin_lookup_expr(self, compiler, connection) + + +def key_transform_exact_path(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + return { + "$and": [ + builtin_lookup_path(self, compiler, connection), + _has_key_predicate(lhs_mql, None, as_path=True), + ] + } -def key_transform_in(self, compiler, connection): +def key_transform_in(self, compiler, connection, as_path=False): """ Return MQL to check if a JSON path exists and that its values are in the set of specified values (rhs). """ + if as_path and self.can_use_path(): + return builtin_lookup_path(self, compiler, connection) + + lhs_mql = process_lhs(self, compiler, connection) + # Traverse to the root column. + previous = self.lhs + while isinstance(previous, KeyTransform): + previous = previous.lhs + root_column = previous.as_mql(compiler, connection) + value = process_rhs(self, compiler, connection) + # Construct the expression to check if lhs_mql values are in rhs values. + expr = connection.mongo_expr_operators[self.lookup_name](lhs_mql, value) + expr = {"$and": [_has_key_predicate(lhs_mql, root_column), expr]} + if as_path: + return {"$expr": expr} + return expr + + +def key_transform_in_expr(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection) # Traverse to the root column. previous = self.lhs @@ -125,11 +175,15 @@ def key_transform_in(self, compiler, connection): root_column = previous.as_mql(compiler, connection) value = process_rhs(self, compiler, connection) # Construct the expression to check if lhs_mql values are in rhs values. - expr = connection.mongo_operators[self.lookup_name](lhs_mql, value) + expr = connection.mongo_expr_operators[self.lookup_name](lhs_mql, value) return {"$and": [_has_key_predicate(lhs_mql, root_column), expr]} -def key_transform_is_null(self, compiler, connection): +def key_transform_in_path(self, compiler, connection): + return builtin_lookup_path(self, compiler, connection) + + +def key_transform_is_null_expr(self, compiler, connection): """ Return MQL to check the nullability of a key. @@ -139,37 +193,69 @@ def key_transform_is_null(self, compiler, connection): Reference: https://code.djangoproject.com/ticket/32252 """ - lhs_mql = process_lhs(self, compiler, connection) - rhs_mql = process_rhs(self, compiler, connection) - # Get the root column. previous = self.lhs while isinstance(previous, KeyTransform): previous = previous.lhs root_column = previous.as_mql(compiler, connection) + lhs_mql = process_lhs(self, compiler, connection, as_path=False) + rhs_mql = process_rhs(self, compiler, connection) return _has_key_predicate(lhs_mql, root_column, negated=rhs_mql) -def key_transform_numeric_lookup_mixin(self, compiler, connection): +def key_transform_is_null_path(self, compiler, connection): + """ + Return MQL to check the nullability of a key using the operator $exists. + """ + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + rhs_mql = process_rhs(self, compiler, connection, as_path=True) + return _has_key_predicate(lhs_mql, None, negated=rhs_mql, as_path=True) + + +def key_transform_numeric_lookup_mixin_expr(self, compiler, connection): """ Return MQL to check if the field exists (i.e., is not "missing" or "null") and that the field matches the given numeric lookup expression. """ - expr = builtin_lookup(self, compiler, connection) - lhs = process_lhs(self, compiler, connection) + lhs = process_lhs(self, compiler, connection, as_path=False) + expr = builtin_lookup_expr(self, compiler, connection) # Check if the type of lhs is not "missing" or "null". not_missing_or_null = {"$not": {"$in": [{"$type": lhs}, ["missing", "null"]]}} return {"$and": [expr, not_missing_or_null]} +def key_transform_numeric_lookup_mixin_path(self, compiler, connection): + return builtin_lookup_path(self, compiler, connection) + + +@property +def keytransform_is_simple_column(self): + previous = self + while isinstance(previous, KeyTransform): + if not previous.key_name.isalnum(): + return False + previous = previous.lhs + return previous.is_simple_column + + def register_json_field(): ContainedBy.as_mql = contained_by DataContains.as_mql = data_contains HasAnyKeys.mongo_operator = "$or" HasKey.mongo_operator = None - HasKeyLookup.as_mql = has_key_lookup + HasKeyLookup.as_mql_path = partialmethod(has_key_lookup, as_path=True) + HasKeyLookup.as_mql_expr = partialmethod(has_key_lookup, as_path=False) + HasKeyLookup.can_use_path = has_key_check_simple_expression HasKeys.mongo_operator = "$and" JSONExact.process_rhs = json_exact_process_rhs - KeyTransform.as_mql = key_transform - KeyTransformIn.as_mql = key_transform_in - KeyTransformIsNull.as_mql = key_transform_is_null - KeyTransformNumericLookupMixin.as_mql = key_transform_numeric_lookup_mixin + KeyTransform.is_simple_column = keytransform_is_simple_column + KeyTransform.can_use_path = keytransform_is_simple_column + KeyTransform.as_mql_path = partialmethod(key_transform, as_path=True) + KeyTransform.as_mql_expr = partialmethod(key_transform, as_path=False) + KeyTransformExact.as_mql_expr = key_transform_exact_expr + KeyTransformExact.as_mql_path = key_transform_exact_path + KeyTransformIn.as_mql_path = key_transform_in_path + KeyTransformIn.as_mql_expr = key_transform_in_expr + KeyTransformIsNull.as_mql_path = key_transform_is_null_path + KeyTransformIsNull.as_mql_expr = key_transform_is_null_expr + KeyTransformNumericLookupMixin.as_mql_path = key_transform_numeric_lookup_mixin_path + KeyTransformNumericLookupMixin.as_mql_expr = key_transform_numeric_lookup_mixin_expr diff --git a/django_mongodb_backend/fields/polymorphic_embedded_model_array.py b/django_mongodb_backend/fields/polymorphic_embedded_model_array.py index e625d2ea8..3e84bbb14 100644 --- a/django_mongodb_backend/fields/polymorphic_embedded_model_array.py +++ b/django_mongodb_backend/fields/polymorphic_embedded_model_array.py @@ -70,7 +70,7 @@ def _get_lookup(self, lookup_name): return lookup class EmbeddedModelArrayFieldLookups(Lookup): - def as_mql(self, compiler, connection): + def as_mql(self, compiler, connection, as_path=False): raise ValueError( "Lookups aren't supported on PolymorphicEmbeddedModelArrayField. " "Try querying one of its embedded fields instead." diff --git a/django_mongodb_backend/functions.py b/django_mongodb_backend/functions.py index c45800a0a..417a427d8 100644 --- a/django_mongodb_backend/functions.py +++ b/django_mongodb_backend/functions.py @@ -67,7 +67,7 @@ def cast(self, compiler, connection): output_type = connection.data_types[self.output_field.get_internal_type()] - lhs_mql = process_lhs(self, compiler, connection)[0] + lhs_mql = process_lhs(self, compiler, connection, as_path=False)[0] if max_length := self.output_field.max_length: lhs_mql = {"$substrCP": [lhs_mql, 0, max_length]} # Skip the conversion for "object" as it doesn't need to be transformed for @@ -81,22 +81,22 @@ def cast(self, compiler, connection): def concat(self, compiler, connection): - return self.get_source_expressions()[0].as_mql(compiler, connection) + return self.get_source_expressions()[0].as_mql(compiler, connection, as_path=False) def concat_pair(self, compiler, connection): # null on either side results in null for expression, wrap with coalesce. coalesced = self.coalesce() - return super(ConcatPair, coalesced).as_mql(compiler, connection) + return super(ConcatPair, coalesced).as_mql_expr(compiler, connection) def cot(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) + lhs_mql = process_lhs(self, compiler, connection, as_path=False) return {"$divide": [1, {"$tan": lhs_mql}]} -def extract(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) +def extract(self, compiler, connection, as_path=False): + lhs_mql = process_lhs(self, compiler, connection, as_path=as_path) operator = EXTRACT_OPERATORS.get(self.lookup_name) if operator is None: raise NotSupportedError(f"{self.__class__.__name__} is not supported.") @@ -105,8 +105,22 @@ def extract(self, compiler, connection): return {f"${operator}": lhs_mql} -def func(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) +def func(self, compiler, connection, as_path=False): + lhs_mql = process_lhs(self, compiler, connection, as_path=False) + if self.function is None: + raise NotSupportedError(f"{self} may need an as_mql() method.") + operator = MONGO_OPERATORS.get(self.__class__, self.function.lower()) + if as_path: + return {"$expr": {f"${operator}": lhs_mql}} + return {f"${operator}": lhs_mql} + + +def func_path(self, compiler, connection): # noqa: ARG001 + raise NotSupportedError(f"{self} may need an as_mql_path() method.") + + +def func_expr(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_path=False) if self.function is None: raise NotSupportedError(f"{self} may need an as_mql() method.") operator = MONGO_OPERATORS.get(self.__class__, self.function.lower()) @@ -114,12 +128,12 @@ def func(self, compiler, connection): def left(self, compiler, connection): - return self.get_substr().as_mql(compiler, connection) + return self.get_substr().as_mql(compiler, connection, as_path=False) def length(self, compiler, connection): # Check for null first since $strLenCP only accepts strings. - lhs_mql = process_lhs(self, compiler, connection) + lhs_mql = process_lhs(self, compiler, connection, as_path=False) return {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$strLenCP": lhs_mql}}} @@ -127,7 +141,7 @@ def log(self, compiler, connection): # This function is usually log(base, num) but on MongoDB it's log(num, base). clone = self.copy() clone.set_source_expressions(self.get_source_expressions()[::-1]) - return func(clone, compiler, connection) + return func(clone, compiler, connection, as_path=False) def now(self, compiler, connection): # noqa: ARG001 @@ -136,7 +150,9 @@ def now(self, compiler, connection): # noqa: ARG001 def null_if(self, compiler, connection): """Return None if expr1==expr2 else expr1.""" - expr1, expr2 = (expr.as_mql(compiler, connection) for expr in self.get_source_expressions()) + expr1, expr2 = ( + expr.as_mql(compiler, connection, as_path=False) for expr in self.get_source_expressions() + ) return {"$cond": {"if": {"$eq": [expr1, expr2]}, "then": None, "else": expr1}} @@ -144,10 +160,10 @@ def preserve_null(operator): # If the argument is null, the function should return null, not # $toLower/Upper's behavior of returning an empty string. def wrapped(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) + lhs_mql = process_lhs(self, compiler, connection, as_path=False) return { "$cond": { - "if": connection.mongo_operators["isnull"](lhs_mql, True), + "if": connection.mongo_expr_operators["isnull"](lhs_mql, True), "then": None, "else": {f"${operator}": lhs_mql}, } @@ -157,18 +173,23 @@ def wrapped(self, compiler, connection): def replace(self, compiler, connection): - expression, text, replacement = process_lhs(self, compiler, connection) + expression, text, replacement = process_lhs(self, compiler, connection, as_path=False) return {"$replaceAll": {"input": expression, "find": text, "replacement": replacement}} def round_(self, compiler, connection): # Round needs its own function because it's a special case that inherits # from Transform but has two arguments. - return {"$round": [expr.as_mql(compiler, connection) for expr in self.get_source_expressions()]} + return { + "$round": [ + expr.as_mql(compiler, connection, as_path=False) + for expr in self.get_source_expressions() + ] + } def str_index(self, compiler, connection): - lhs = process_lhs(self, compiler, connection) + lhs = process_lhs(self, compiler, connection, as_path=False) # StrIndex should be 0-indexed (not found) but it's -1-indexed on MongoDB. return {"$add": [{"$indexOfCP": lhs}, 1]} @@ -232,7 +253,7 @@ def trunc_convert_value(self, value, expression, connection): def trunc_date(self, compiler, connection): # Cast to date rather than truncate to date. - lhs_mql = process_lhs(self, compiler, connection) + lhs_mql = process_lhs(self, compiler, connection, as_path=False) tzname = self.get_tzname() if tzname and tzname != "UTC": raise NotSupportedError(f"TruncDate with tzinfo ({tzname}) isn't supported on MongoDB.") @@ -255,7 +276,7 @@ def trunc_time(self, compiler, connection): tzname = self.get_tzname() if tzname and tzname != "UTC": raise NotSupportedError(f"TruncTime with tzinfo ({tzname}) isn't supported on MongoDB.") - lhs_mql = process_lhs(self, compiler, connection) + lhs_mql = process_lhs(self, compiler, connection, as_path=False) return { "$dateFromString": { "dateString": { @@ -272,28 +293,30 @@ def trunc_time(self, compiler, connection): def register_functions(): - Cast.as_mql = cast - Concat.as_mql = concat - ConcatPair.as_mql = concat_pair - Cot.as_mql = cot - Extract.as_mql = extract - Func.as_mql = func - JSONArray.as_mql = process_lhs - Left.as_mql = left - Length.as_mql = length - Log.as_mql = log - Lower.as_mql = preserve_null("toLower") - LTrim.as_mql = trim("ltrim") - Now.as_mql = now - NullIf.as_mql = null_if - Replace.as_mql = replace - Round.as_mql = round_ - RTrim.as_mql = trim("rtrim") - StrIndex.as_mql = str_index - Substr.as_mql = substr - Trim.as_mql = trim("trim") - TruncBase.as_mql = trunc + Cast.as_mql_expr = cast + Concat.as_mql_expr = concat + ConcatPair.as_mql_expr = concat_pair + Cot.as_mql_expr = cot + Extract.as_mql_expr = extract + Func.as_mql_path = func_path + Func.as_mql_expr = func_expr + JSONArray.as_mql_expr = process_lhs + Left.as_mql_expr = left + Length.as_mql_expr = length + Log.as_mql_expr = log + Lower.as_mql_expr = preserve_null("toLower") + LTrim.as_mql_expr = trim("ltrim") + Now.as_mql_expr = now + NullIf.as_mql_expr = null_if + Replace.as_mql_expr = replace + Round.as_mql_expr = round_ + RTrim.as_mql_expr = trim("rtrim") + StrIndex.as_mql_expr = str_index + Substr.as_mql_expr = substr + Trim.as_mql_expr = trim("trim") + TruncBase.as_mql_expr = trunc TruncBase.convert_value = trunc_convert_value - TruncDate.as_mql = trunc_date - TruncTime.as_mql = trunc_time - Upper.as_mql = preserve_null("toUpper") + TruncDate.as_mql_expr = trunc_date + TruncTime.as_mql_expr = trunc_time + Upper.as_mql_expr = preserve_null("toUpper") + Func.can_use_path = False diff --git a/django_mongodb_backend/gis/lookups.py b/django_mongodb_backend/gis/lookups.py index 29c2e1e96..b81eeea9c 100644 --- a/django_mongodb_backend/gis/lookups.py +++ b/django_mongodb_backend/gis/lookups.py @@ -2,7 +2,7 @@ from django.db import NotSupportedError -def gis_lookup(self, compiler, connection): # noqa: ARG001 +def gis_lookup(self, compiler, connection, as_path=False): # noqa: ARG001 raise NotSupportedError(f"MongoDB does not support the {self.lookup_name} lookup.") diff --git a/django_mongodb_backend/indexes.py b/django_mongodb_backend/indexes.py index 99c1ef5f3..f61b3524b 100644 --- a/django_mongodb_backend/indexes.py +++ b/django_mongodb_backend/indexes.py @@ -34,7 +34,7 @@ def _get_condition_mql(self, model, schema_editor): def builtin_lookup_idx(self, compiler, connection): lhs_mql = self.lhs.target.column - value = process_rhs(self, compiler, connection) + value = process_rhs(self, compiler, connection, as_path=True) try: operator = MONGO_INDEX_OPERATORS[self.lookup_name] except KeyError: diff --git a/django_mongodb_backend/lookups.py b/django_mongodb_backend/lookups.py index 8dda2bab3..e5154fad4 100644 --- a/django_mongodb_backend/lookups.py +++ b/django_mongodb_backend/lookups.py @@ -4,17 +4,24 @@ BuiltinLookup, FieldGetDbPrepValueIterableMixin, IsNull, + Lookup, PatternLookup, UUIDTextMixin, ) -from .query_utils import process_lhs, process_rhs +from .query_utils import is_constant_value, process_lhs, process_rhs -def builtin_lookup(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) - value = process_rhs(self, compiler, connection) - return connection.mongo_operators[self.lookup_name](lhs_mql, value) +def builtin_lookup_path(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + value = process_rhs(self, compiler, connection, as_path=True) + return connection.mongo_match_operators[self.lookup_name](lhs_mql, value) + + +def builtin_lookup_expr(self, compiler, connection): + value = process_rhs(self, compiler, connection, as_path=False) + lhs_mql = process_lhs(self, compiler, connection, as_path=False) + return connection.mongo_expr_operators[self.lookup_name](lhs_mql, value) _field_resolve_expression_parameter = FieldGetDbPrepValueIterableMixin.resolve_expression_parameter @@ -33,14 +40,17 @@ def field_resolve_expression_parameter(self, compiler, connection, sql, param): return sql, sql_params -def in_(self, compiler, connection): - db_rhs = getattr(self.rhs, "_db", None) - if db_rhs is not None and db_rhs != connection.alias: - raise ValueError( - "Subqueries aren't allowed across different databases. Force " - "the inner query to be evaluated using `list(inner_query)`." - ) - return builtin_lookup(self, compiler, connection) +def wrap_in(function): + def inner(self, compiler, connection): + db_rhs = getattr(self.rhs, "_db", None) + if db_rhs is not None and db_rhs != connection.alias: + raise ValueError( + "Subqueries aren't allowed across different databases. Force " + "the inner query to be evaluated using `list(inner_query)`." + ) + return function(self, compiler, connection) + + return inner def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr): # noqa: ARG001 @@ -75,11 +85,18 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr) ] -def is_null(self, compiler, connection): +def is_null_path(self, compiler, connection): if not isinstance(self.rhs, bool): raise ValueError("The QuerySet value for an isnull lookup must be True or False.") - lhs_mql = process_lhs(self, compiler, connection) - return connection.mongo_operators["isnull"](lhs_mql, self.rhs) + lhs_mql = process_lhs(self, compiler, connection, as_path=True) + return connection.mongo_match_operators["isnull"](lhs_mql, self.rhs) + + +def is_null_expr(self, compiler, connection): + if not isinstance(self.rhs, bool): + raise ValueError("The QuerySet value for an isnull lookup must be True or False.") + lhs_mql = process_lhs(self, compiler, connection, as_path=False) + return connection.mongo_expr_operators["isnull"](lhs_mql, self.rhs) # from https://www.pcre.org/current/doc/html/pcre2pattern.html#SEC4 @@ -121,13 +138,25 @@ def uuid_text_mixin(self, compiler, connection): # noqa: ARG001 raise NotSupportedError("Pattern lookups on UUIDField are not supported.") +@property +def can_use_path(self): + simple_column = getattr(self.lhs, "is_simple_column", False) + constant_value = is_constant_value(self.rhs) + return simple_column and constant_value + + def register_lookups(): - BuiltinLookup.as_mql = builtin_lookup + Lookup.can_use_path = can_use_path + BuiltinLookup.as_mql_path = builtin_lookup_path + BuiltinLookup.as_mql_expr = builtin_lookup_expr FieldGetDbPrepValueIterableMixin.resolve_expression_parameter = ( field_resolve_expression_parameter ) - In.as_mql = RelatedIn.as_mql = in_ + In.as_mql_path = RelatedIn.as_mql_path = wrap_in(builtin_lookup_path) + In.as_mql_expr = RelatedIn.as_mql_expr = wrap_in(builtin_lookup_expr) In.get_subquery_wrapping_pipeline = get_subquery_wrapping_pipeline - IsNull.as_mql = is_null + IsNull.as_mql_path = is_null_path + IsNull.as_mql_expr = is_null_expr PatternLookup.prep_lookup_value_mongo = pattern_lookup_prep_lookup_value + # Patching the main method, it is not supported yet. UUIDTextMixin.as_mql = uuid_text_mixin diff --git a/django_mongodb_backend/query.py b/django_mongodb_backend/query.py index c86b8721b..ef3312cf1 100644 --- a/django_mongodb_backend/query.py +++ b/django_mongodb_backend/query.py @@ -11,8 +11,6 @@ from django.db.models.sql.where import AND, OR, XOR, ExtraWhere, NothingNode, WhereNode from pymongo.errors import BulkWriteError, DuplicateKeyError, PyMongoError -from .query_conversion.query_optimizer import convert_expr_to_match - def wrap_database_errors(func): @wraps(func) @@ -89,7 +87,7 @@ def get_pipeline(self): for query in self.subqueries or (): pipeline.extend(query.get_pipeline()) if self.match_mql: - pipeline.extend(convert_expr_to_match(self.match_mql)) + pipeline.append({"$match": self.match_mql}) if self.aggregation_pipeline: pipeline.extend(self.aggregation_pipeline) if self.project_fields: @@ -168,6 +166,7 @@ def _get_reroot_replacements(expression): target.remote_field = col.target.remote_field column_target = Col(compiler.collection_name, target) if parent_pos is not None: + column_target.is_simple_column = False target_col = f"${parent_template}{parent_pos}" column_target.target.db_column = target_col column_target.target.set_attributes_from_name(target_col) @@ -194,7 +193,7 @@ def _get_reroot_replacements(expression): if extra: replacements = _get_reroot_replacements(extra) extra_conditions.append( - extra.replace_expressions(replacements).as_mql(compiler, connection) + extra.replace_expressions(replacements).as_mql(compiler, connection, as_path=True) ) # pushed_filter_expression is a Where expression from the outer WHERE # clause that involves fields from the joined (right-hand) table and @@ -208,9 +207,26 @@ def _get_reroot_replacements(expression): rerooted_replacement = _get_reroot_replacements(pushed_filter_expression) extra_conditions.append( pushed_filter_expression.replace_expressions(rerooted_replacement).as_mql( - compiler, connection + compiler, connection, as_path=True ) ) + + # Match the conditions: + # self.table_name.field1 = parent_table.field1 + # AND + # self.table_name.field2 = parent_table.field2 + # AND + # ... + condition = { + "$expr": { + "$and": [ + {"$eq": [f"$${parent_template}{i}", field]} for i, field in enumerate(rhs_fields) + ] + } + } + if extra_conditions: + condition = {"$and": [condition, *extra_conditions]} + lookup_pipeline = [ { "$lookup": { @@ -222,25 +238,7 @@ def _get_reroot_replacements(expression): f"{parent_template}{i}": parent_field for i, parent_field in enumerate(lhs_fields) }, - "pipeline": [ - { - # Match the conditions: - # self.table_name.field1 = parent_table.field1 - # AND - # self.table_name.field2 = parent_table.field2 - # AND - # ... - "$match": { - "$expr": { - "$and": [ - {"$eq": [f"$${parent_template}{i}", field]} - for i, field in enumerate(rhs_fields) - ] - + extra_conditions - } - } - } - ], + "pipeline": [{"$match": condition}], # Rename the output as table_alias. "as": self.table_alias, } @@ -274,7 +272,7 @@ def _get_reroot_replacements(expression): return lookup_pipeline -def where_node(self, compiler, connection): +def where_node(self, compiler, connection, as_path=False): if self.connector == AND: full_needed, empty_needed = len(self.children), 1 else: @@ -297,14 +295,16 @@ def where_node(self, compiler, connection): if len(self.children) > 2: rhs_sum = Mod(rhs_sum, 2) rhs = Exact(1, rhs_sum) - return self.__class__([lhs, rhs], AND, self.negated).as_mql(compiler, connection) + return self.__class__([lhs, rhs], AND, self.negated).as_mql( + compiler, connection, as_path=as_path + ) else: operator = "$or" children_mql = [] for child in self.children: try: - mql = child.as_mql(compiler, connection) + mql = child.as_mql(compiler, connection, as_path=as_path) except EmptyResultSet: empty_needed -= 1 except FullResultSet: @@ -331,13 +331,17 @@ def where_node(self, compiler, connection): raise FullResultSet if self.negated and mql: - mql = {"$not": mql} + mql = {"$nor": [mql]} if as_path else {"$not": [mql]} return mql +def nothing_node(self, compiler, connection, as_path=None): # noqa: ARG001 + return self.as_sql(compiler, connection) + + def register_nodes(): ExtraWhere.as_mql = extra_where Join.as_mql = join - NothingNode.as_mql = NothingNode.as_sql + NothingNode.as_mql = nothing_node WhereNode.as_mql = where_node diff --git a/django_mongodb_backend/query_conversion/__init__.py b/django_mongodb_backend/query_conversion/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/django_mongodb_backend/query_conversion/expression_converters.py b/django_mongodb_backend/query_conversion/expression_converters.py deleted file mode 100644 index dc9b7df5b..000000000 --- a/django_mongodb_backend/query_conversion/expression_converters.py +++ /dev/null @@ -1,172 +0,0 @@ -class BaseConverter: - """Base class for $expr to $match converters.""" - - @classmethod - def convert(cls, expr): - raise NotImplementedError("Subclasses must implement this method.") - - @classmethod - def is_simple_value(cls, value): - """Is the value is a simple type (not a dict)?""" - if value is None: - return True - if isinstance(value, str) and value.startswith("$"): - return False - if isinstance(value, (list, tuple, set)): - return all(cls.is_simple_value(v) for v in value) - # TODO: Support `$getField` conversion. - return not isinstance(value, dict) - - -class BinaryConverter(BaseConverter): - """ - Base class for converting binary operations. - - For example: - "$expr": { - {"$gt": ["$price", 100]} - } - is converted to: - {"price": {"$gt": 100}} - """ - - operator: str - - @classmethod - def convert(cls, args): - if isinstance(args, list) and len(args) == 2: - field_expr, value = args - # Check if first argument is a simple field reference. - if ( - isinstance(field_expr, str) - and field_expr.startswith("$") - and cls.is_simple_value(value) - ): - field_name = field_expr[1:] # Remove the $ prefix. - if cls.operator == "$eq": - return {field_name: value} - return {field_name: {cls.operator: value}} - return None - - -class EqConverter(BinaryConverter): - """ - Convert $eq operation to a $match query. - - For example: - "$expr": { - {"$eq": ["$status", "active"]} - } - is converted to: - {"status": "active"} - """ - - operator = "$eq" - - -class GtConverter(BinaryConverter): - operator = "$gt" - - -class GteConverter(BinaryConverter): - operator = "$gte" - - -class LtConverter(BinaryConverter): - operator = "$lt" - - -class LteConverter(BinaryConverter): - operator = "$lte" - - -class InConverter(BaseConverter): - """ - Convert $in operation to a $match query. - - For example: - "$expr": { - {"$in": ["$category", ["electronics", "books"]]} - } - is converted to: - {"category": {"$in": ["electronics", "books"]}} - """ - - @classmethod - def convert(cls, in_args): - if isinstance(in_args, list) and len(in_args) == 2: - field_expr, values = in_args - # Check if first argument is a simple field reference. - if isinstance(field_expr, str) and field_expr.startswith("$"): - field_name = field_expr[1:] # Remove the $ prefix. - if isinstance(values, (list, tuple, set)) and all( - cls.is_simple_value(v) for v in values - ): - return {field_name: {"$in": values}} - return None - - -class LogicalConverter(BaseConverter): - """ - Base class for converting logical operations to a $match query. - - For example: - "$expr": { - "$or": [ - {"$eq": ["$status", "active"]}, - {"$in": ["$category", ["electronics", "books"]]}, - ] - } - is converted to: - "$or": [ - {"status": "active"}, - {"category": {"$in": ["electronics", "books"]}}, - ] - """ - - @classmethod - def convert(cls, combined_conditions): - if isinstance(combined_conditions, list): - optimized_conditions = [] - for condition in combined_conditions: - if isinstance(condition, dict) and len(condition) == 1: - if optimized_condition := convert_expression(condition): - optimized_conditions.append(optimized_condition) - else: - # Any failure should stop optimization. - return None - if optimized_conditions: - return {cls._logical_op: optimized_conditions} - return None - - -class OrConverter(LogicalConverter): - _logical_op = "$or" - - -class AndConverter(LogicalConverter): - _logical_op = "$and" - - -OPTIMIZABLE_OPS = { - "$eq": EqConverter, - "$in": InConverter, - "$and": AndConverter, - "$or": OrConverter, - "$gt": GtConverter, - "$gte": GteConverter, - "$lt": LtConverter, - "$lte": LteConverter, -} - - -def convert_expression(expr): - """ - Optimize MQL by converting an $expr condition to $match. Return the $match - MQL, or None if not optimizable. - """ - if isinstance(expr, dict) and len(expr) == 1: - op = next(iter(expr.keys())) - if op in OPTIMIZABLE_OPS: - return OPTIMIZABLE_OPS[op].convert(expr[op]) - return None diff --git a/django_mongodb_backend/query_conversion/query_optimizer.py b/django_mongodb_backend/query_conversion/query_optimizer.py deleted file mode 100644 index 368c89504..000000000 --- a/django_mongodb_backend/query_conversion/query_optimizer.py +++ /dev/null @@ -1,73 +0,0 @@ -from .expression_converters import convert_expression - - -def convert_expr_to_match(query): - """ - Optimize an MQL query by converting conditions into a list of $match - stages. - """ - if "$expr" not in query: - return [query] - if query["$expr"] == {}: - return [{"$match": {}}] - return _process_expression(query["$expr"]) - - -def _process_expression(expr): - """Process an expression and extract optimizable conditions.""" - match_conditions = [] - remaining_conditions = [] - if isinstance(expr, dict): - has_and = "$and" in expr - has_or = "$or" in expr - # Do a top-level check for $and or $or because these should inform. - # If they fail, they should failover to a remaining conditions list. - # There's probably a better way to do this. - if has_and: - and_match_conditions = _process_logical_conditions("$and", expr["$and"]) - match_conditions.extend(and_match_conditions) - if has_or: - or_match_conditions = _process_logical_conditions("$or", expr["$or"]) - match_conditions.extend(or_match_conditions) - if not has_and and not has_or: - # Process single condition. - if optimized := convert_expression(expr): - match_conditions.append({"$match": optimized}) - else: - remaining_conditions.append({"$match": {"$expr": expr}}) - else: - # Can't optimize. - remaining_conditions.append({"$expr": expr}) - return match_conditions + remaining_conditions - - -def _process_logical_conditions(logical_op, logical_conditions): - """Process conditions within a logical array.""" - optimized_conditions = [] - match_conditions = [] - remaining_conditions = [] - for condition in logical_conditions: - _remaining_conditions = [] - if isinstance(condition, dict): - if optimized := convert_expression(condition): - optimized_conditions.append(optimized) - else: - _remaining_conditions.append(condition) - else: - _remaining_conditions.append(condition) - if _remaining_conditions: - # Any expressions that can't be optimized must remain in a $expr - # that preserves the logical operator. - if len(_remaining_conditions) > 1: - remaining_conditions.append({"$expr": {logical_op: _remaining_conditions}}) - else: - remaining_conditions.append({"$expr": _remaining_conditions[0]}) - if optimized_conditions: - optimized_conditions.extend(remaining_conditions) - if len(optimized_conditions) > 1: - match_conditions.append({"$match": {logical_op: optimized_conditions}}) - else: - match_conditions.append({"$match": optimized_conditions[0]}) - else: - match_conditions.append({"$match": {logical_op: remaining_conditions}}) - return match_conditions diff --git a/django_mongodb_backend/query_utils.py b/django_mongodb_backend/query_utils.py index 4b744241e..75c3c099f 100644 --- a/django_mongodb_backend/query_utils.py +++ b/django_mongodb_backend/query_utils.py @@ -1,13 +1,14 @@ from django.core.exceptions import FullResultSet from django.db.models.aggregates import Aggregate -from django.db.models.expressions import Value +from django.db.models.expressions import CombinedExpression, Value +from django.db.models.sql.query import Query def is_direct_value(node): return not hasattr(node, "as_sql") -def process_lhs(node, compiler, connection): +def process_lhs(node, compiler, connection, as_path=False): if not hasattr(node, "lhs"): # node is a Func or Expression, possibly with multiple source expressions. result = [] @@ -15,27 +16,30 @@ def process_lhs(node, compiler, connection): if expr is None: continue try: - result.append(expr.as_mql(compiler, connection)) + result.append(expr.as_mql(compiler, connection, as_path=as_path)) except FullResultSet: - result.append(Value(True).as_mql(compiler, connection)) + result.append(Value(True).as_mql(compiler, connection, as_path=as_path)) if isinstance(node, Aggregate): return result[0] return result # node is a Transform with just one source expression, aliased as "lhs". if is_direct_value(node.lhs): return node - return node.lhs.as_mql(compiler, connection) + return node.lhs.as_mql(compiler, connection, as_path=as_path) -def process_rhs(node, compiler, connection): +def process_rhs(node, compiler, connection, as_path=False): rhs = node.rhs if hasattr(rhs, "as_mql"): if getattr(rhs, "subquery", False) and hasattr(node, "get_subquery_wrapping_pipeline"): value = rhs.as_mql( - compiler, connection, get_wrapping_pipeline=node.get_subquery_wrapping_pipeline + compiler, + connection, + get_wrapping_pipeline=node.get_subquery_wrapping_pipeline, + as_path=as_path, ) else: - value = rhs.as_mql(compiler, connection) + value = rhs.as_mql(compiler, connection, as_path=as_path) else: _, value = node.process_rhs(compiler, connection) lookup_name = node.lookup_name @@ -47,7 +51,40 @@ def process_rhs(node, compiler, connection): return value -def regex_match(field, regex_vals, insensitive=False): +def regex_expr(field, regex_vals, insensitive=False): regex = {"$concat": regex_vals} if isinstance(regex_vals, tuple) else regex_vals options = "i" if insensitive else "" return {"$regexMatch": {"input": field, "regex": regex, "options": options}} + + +def regex_match(field, regex, insensitive=False): + options = "i" if insensitive else "" + return {field: {"$regex": regex, "$options": options}} + + +def is_constant_value(value): + if isinstance(value, CombinedExpression): + # Temporary: treat all CombinedExpressions as non-constant until + # constant cases are handled + return False + if isinstance(value, list): + return all(map(is_constant_value, value)) + if is_direct_value(value): + return True + if hasattr(value, "get_source_expressions"): + # Temporary: similar limitation as above, sub-expressions should be + # resolved in the future + simple_sub_expressions = all(map(is_constant_value, value.get_source_expressions())) + else: + simple_sub_expressions = True + return ( + simple_sub_expressions + and isinstance(value, Value) + and not ( + isinstance(value, Query) + or value.contains_aggregate + or value.contains_over_clause + or value.contains_column_references + or value.contains_subquery + ) + ) diff --git a/django_mongodb_backend/test.py b/django_mongodb_backend/test.py index 561832a15..ee35b4e21 100644 --- a/django_mongodb_backend/test.py +++ b/django_mongodb_backend/test.py @@ -1,6 +1,6 @@ """Not a public API.""" -from bson import SON, ObjectId +from bson import SON, Decimal128, ObjectId class MongoTestCaseMixin: @@ -16,6 +16,6 @@ def assertAggregateQuery(self, query, expected_collection, expected_pipeline): self.assertEqual(operator, "aggregate") self.assertEqual(collection, expected_collection) self.assertEqual( - eval(pipeline[:-1], {"SON": SON, "ObjectId": ObjectId}, {}), # noqa: S307 + eval(pipeline[:-1], {"SON": SON, "ObjectId": ObjectId, "Decimal128": Decimal128}, {}), # noqa: S307 expected_pipeline, ) diff --git a/tests/expression_converter_/__init__.py b/tests/expression_converter_/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/expression_converter_/test_match_conversion.py b/tests/expression_converter_/test_match_conversion.py deleted file mode 100644 index e78e5c0cc..000000000 --- a/tests/expression_converter_/test_match_conversion.py +++ /dev/null @@ -1,215 +0,0 @@ -from django.test import SimpleTestCase - -from django_mongodb_backend.query_conversion.query_optimizer import convert_expr_to_match - - -class ConvertExprToMatchTests(SimpleTestCase): - def assertOptimizerEqual(self, input, expected): - result = convert_expr_to_match(input) - self.assertEqual(result, expected) - - def test_multiple_optimizable_conditions(self): - expr = { - "$expr": { - "$and": [ - {"$eq": ["$status", "active"]}, - {"$in": ["$category", ["electronics", "books"]]}, - {"$eq": ["$verified", True]}, - {"$gte": ["$price", 50]}, - ] - } - } - expected = [ - { - "$match": { - "$and": [ - {"status": "active"}, - {"category": {"$in": ["electronics", "books"]}}, - {"verified": True}, - {"price": {"$gte": 50}}, - ] - } - } - ] - self.assertOptimizerEqual(expr, expected) - - def test_mixed_optimizable_and_non_optimizable_conditions(self): - expr = { - "$expr": { - "$and": [ - {"$eq": ["$status", "active"]}, - {"$gt": ["$price", "$min_price"]}, # Not optimizable - {"$in": ["$category", ["electronics"]]}, - ] - } - } - expected = [ - { - "$match": { - "$and": [ - {"status": "active"}, - {"category": {"$in": ["electronics"]}}, - {"$expr": {"$gt": ["$price", "$min_price"]}}, - ], - } - } - ] - self.assertOptimizerEqual(expr, expected) - - def test_non_optimizable_condition(self): - expr = {"$expr": {"$gt": ["$price", "$min_price"]}} - expected = [ - { - "$match": { - "$expr": {"$gt": ["$price", "$min_price"]}, - } - } - ] - self.assertOptimizerEqual(expr, expected) - - def test_nested_logical_conditions(self): - expr = { - "$expr": { - "$or": [ - {"$eq": ["$status", "active"]}, - {"$in": ["$category", ["electronics", "books"]]}, - {"$and": [{"$eq": ["$verified", True]}, {"$lte": ["$price", 50]}]}, - ] - } - } - expected = [ - { - "$match": { - "$or": [ - {"status": "active"}, - {"category": {"$in": ["electronics", "books"]}}, - {"$and": [{"verified": True}, {"price": {"$lte": 50}}]}, - ] - } - } - ] - self.assertOptimizerEqual(expr, expected) - - def test_complex_nested_with_non_optimizable_parts(self): - expr = { - "$expr": { - "$and": [ - { - "$or": [ - {"$eq": ["$status", "active"]}, - {"$gt": ["$views", 1000]}, - ] - }, - {"$in": ["$category", ["electronics", "books"]]}, - {"$eq": ["$verified", True]}, - {"$gt": ["$price", "$min_price"]}, # Not optimizable - ] - } - } - expected = [ - { - "$match": { - "$and": [ - { - "$or": [ - {"status": "active"}, - {"views": {"$gt": 1000}}, - ] - }, - {"category": {"$in": ["electronics", "books"]}}, - {"verified": True}, - {"$expr": {"$gt": ["$price", "$min_price"]}}, - ] - } - } - ] - self.assertOptimizerEqual(expr, expected) - - def test_london_in_case(self): - expr = {"$expr": {"$in": ["$author_city", ["London"]]}} - expected = [{"$match": {"author_city": {"$in": ["London"]}}}] - self.assertOptimizerEqual(expr, expected) - - def test_deeply_nested_logical_operators(self): - expr = { - "$expr": { - "$and": [ - { - "$or": [ - {"$eq": ["$type", "premium"]}, - { - "$and": [ - {"$eq": ["$type", "standard"]}, - {"$in": ["$region", ["US", "CA"]]}, - ] - }, - ] - }, - {"$eq": ["$active", True]}, - ] - } - } - expected = [ - { - "$match": { - "$and": [ - { - "$or": [ - {"type": "premium"}, - { - "$and": [ - {"type": "standard"}, - {"region": {"$in": ["US", "CA"]}}, - ] - }, - ] - }, - {"active": True}, - ] - } - } - ] - self.assertOptimizerEqual(expr, expected) - - def test_deeply_nested_logical_operator_with_variable(self): - expr = { - "$expr": { - "$and": [ - { - "$or": [ - {"$eq": ["$type", "premium"]}, - { - "$and": [ - {"$eq": ["$type", "$$standard"]}, # Not optimizable - {"$in": ["$region", ["US", "CA"]]}, - ] - }, - ] - }, - {"$eq": ["$active", True]}, - ] - } - } - expected = [ - { - "$match": { - "$and": [ - {"active": True}, - { - "$expr": { - "$or": [ - {"$eq": ["$type", "premium"]}, - { - "$and": [ - {"$eq": ["$type", "$$standard"]}, - {"$in": ["$region", ["US", "CA"]]}, - ] - }, - ] - } - }, - ] - } - } - ] - self.assertOptimizerEqual(expr, expected) diff --git a/tests/expression_converter_/test_op_expressions.py b/tests/expression_converter_/test_op_expressions.py deleted file mode 100644 index ce4caf2d4..000000000 --- a/tests/expression_converter_/test_op_expressions.py +++ /dev/null @@ -1,233 +0,0 @@ -import datetime -from uuid import UUID - -from bson import Decimal128 -from django.test import SimpleTestCase - -from django_mongodb_backend.query_conversion.expression_converters import convert_expression - - -class ConversionTestCase(SimpleTestCase): - CONVERTIBLE_TYPES = { - "int": 42, - "float": 3.14, - "decimal128": Decimal128("3.14"), - "boolean": True, - "NoneType": None, - "string": "string", - "datetime": datetime.datetime.now(datetime.timezone.utc), - "duration": datetime.timedelta(days=5, hours=3), - "uuid": UUID("12345678123456781234567812345678"), - } - - def assertConversionEqual(self, input, expected): - result = convert_expression(input) - self.assertEqual(result, expected) - - def assertNotOptimizable(self, input): - result = convert_expression(input) - self.assertIsNone(result) - - def _test_conversion_various_types(self, conversion_test): - for _type, val in self.CONVERTIBLE_TYPES.items(): - with self.subTest(_type=_type, val=val): - conversion_test(val) - - -class ExpressionTests(ConversionTestCase): - def test_non_dict(self): - self.assertNotOptimizable(["$status", "active"]) - - def test_empty_dict(self): - self.assertNotOptimizable({}) - - -class EqTests(ConversionTestCase): - def test_conversion(self): - self.assertConversionEqual({"$eq": ["$status", "active"]}, {"status": "active"}) - - def test_no_conversion_non_string_field(self): - self.assertNotOptimizable({"$eq": [123, "active"]}) - - def test_no_conversion_dict_value(self): - self.assertNotOptimizable({"$eq": ["$status", {"$gt": 5}]}) - - def _test_conversion_valid_type(self, _type): - self.assertConversionEqual({"$eq": ["$age", _type]}, {"age": _type}) - - def _test_conversion_valid_array_type(self, _type): - self.assertConversionEqual({"$eq": ["$age", _type]}, {"age": _type}) - - def test_conversion_various_types(self): - self._test_conversion_various_types(self._test_conversion_valid_type) - - def test_conversion_various_array_types(self): - self._test_conversion_various_types(self._test_conversion_valid_array_type) - - -class InTests(ConversionTestCase): - def test_conversion(self): - expr = {"$in": ["$category", ["electronics", "books", "clothing"]]} - expected = {"category": {"$in": ["electronics", "books", "clothing"]}} - self.assertConversionEqual(expr, expected) - - def test_no_conversion_non_string_field(self): - self.assertNotOptimizable({"$in": [123, ["electronics", "books"]]}) - - def test_no_conversion_dict_value(self): - self.assertNotOptimizable({"$in": ["$status", [{"bad": "val"}]]}) - - def _test_conversion_valid_type(self, _type): - self.assertConversionEqual({"$in": ["$age", [_type]]}, {"age": {"$in": [_type]}}) - - def test_conversion_various_types(self): - for _type, val in self.CONVERTIBLE_TYPES.items(): - with self.subTest(_type=_type, val=val): - self._test_conversion_valid_type(val) - - -class LogicalTests(ConversionTestCase): - def test_and(self): - expr = { - "$and": [ - {"$eq": ["$status", "active"]}, - {"$in": ["$category", ["electronics", "books"]]}, - {"$eq": ["$verified", True]}, - ] - } - expected = { - "$and": [ - {"status": "active"}, - {"category": {"$in": ["electronics", "books"]}}, - {"verified": True}, - ] - } - self.assertConversionEqual(expr, expected) - - def test_or(self): - expr = { - "$or": [ - {"$eq": ["$status", "active"]}, - {"$in": ["$category", ["electronics", "books"]]}, - ] - } - expected = { - "$or": [ - {"status": "active"}, - {"category": {"$in": ["electronics", "books"]}}, - ] - } - self.assertConversionEqual(expr, expected) - - def test_or_failure(self): - expr = { - "$or": [ - {"$eq": ["$status", "active"]}, - {"$in": ["$category", ["electronics", "books"]]}, - { - "$and": [ - {"verified": True}, - {"$gt": ["$price", "$min_price"]}, # Not optimizable - ] - }, - ] - } - self.assertNotOptimizable(expr) - - def test_mixed(self): - expr = { - "$and": [ - { - "$or": [ - {"$eq": ["$status", "active"]}, - {"$gt": ["$views", 1000]}, - ] - }, - {"$in": ["$category", ["electronics", "books"]]}, - {"$eq": ["$verified", True]}, - {"$lte": ["$price", 2000]}, - ] - } - expected = { - "$and": [ - {"$or": [{"status": "active"}, {"views": {"$gt": 1000}}]}, - {"category": {"$in": ["electronics", "books"]}}, - {"verified": True}, - {"price": {"$lte": 2000}}, - ] - } - self.assertConversionEqual(expr, expected) - - -class GtTests(ConversionTestCase): - def test_conversion(self): - self.assertConversionEqual({"$gt": ["$price", 100]}, {"price": {"$gt": 100}}) - - def test_no_conversion_non_simple_field(self): - self.assertNotOptimizable({"$gt": ["$price", "$min_price"]}) - - def test_no_conversion_dict_value(self): - self.assertNotOptimizable({"$gt": ["$price", {}]}) - - def _test_conversion_valid_type(self, _type): - self.assertConversionEqual({"$gt": ["$price", _type]}, {"price": {"$gt": _type}}) - - def test_conversion_various_types(self): - self._test_conversion_various_types(self._test_conversion_valid_type) - - -class GteTests(ConversionTestCase): - def test_conversion(self): - expr = {"$gte": ["$price", 100]} - expected = {"price": {"$gte": 100}} - self.assertConversionEqual(expr, expected) - - def test_no_conversion_non_simple_field(self): - expr = {"$gte": ["$price", "$min_price"]} - self.assertNotOptimizable(expr) - - def test_no_conversion_dict_value(self): - expr = {"$gte": ["$price", {}]} - self.assertNotOptimizable(expr) - - def _test_conversion_valid_type(self, _type): - expr = {"$gte": ["$price", _type]} - expected = {"price": {"$gte": _type}} - self.assertConversionEqual(expr, expected) - - def test_conversion_various_types(self): - self._test_conversion_various_types(self._test_conversion_valid_type) - - -class LtTests(ConversionTestCase): - def test_conversion(self): - self.assertConversionEqual({"$lt": ["$price", 100]}, {"price": {"$lt": 100}}) - - def test_no_conversion_non_simple_field(self): - self.assertNotOptimizable({"$lt": ["$price", "$min_price"]}) - - def test_no_conversion_dict_value(self): - self.assertNotOptimizable({"$lt": ["$price", {}]}) - - def _test_conversion_valid_type(self, _type): - self.assertConversionEqual({"$lt": ["$price", _type]}, {"price": {"$lt": _type}}) - - def test_conversion_various_types(self): - self._test_conversion_various_types(self._test_conversion_valid_type) - - -class LteTests(ConversionTestCase): - def test_conversion(self): - self.assertConversionEqual({"$lte": ["$price", 100]}, {"price": {"$lte": 100}}) - - def test_no_conversion_non_simple_field(self): - self.assertNotOptimizable({"$lte": ["$price", "$min_price"]}) - - def test_no_conversion_dict_value(self): - self.assertNotOptimizable({"$lte": ["$price", {}]}) - - def _test_conversion_valid_type(self, _type): - self.assertConversionEqual({"$lte": ["$price", _type]}, {"price": {"$lte": _type}}) - - def test_conversion_various_types(self): - self._test_conversion_various_types(self._test_conversion_valid_type) diff --git a/tests/lookup_/tests.py b/tests/lookup_/tests.py index 6fce89942..6166e1003 100644 --- a/tests/lookup_/tests.py +++ b/tests/lookup_/tests.py @@ -1,3 +1,4 @@ +from bson import SON from django.test import TestCase from django_mongodb_backend.test import MongoTestCaseMixin @@ -5,12 +6,12 @@ from .models import Book, Number -class NumericLookupTests(TestCase): +class NumericLookupTests(MongoTestCaseMixin, TestCase): @classmethod def setUpTestData(cls): cls.objs = Number.objects.bulk_create(Number(num=x) for x in range(5)) # Null values should be excluded in less than queries. - Number.objects.create() + cls.null_number = Number.objects.create() def test_lt(self): self.assertQuerySetEqual(Number.objects.filter(num__lt=3), self.objs[:3]) @@ -18,6 +19,20 @@ def test_lt(self): def test_lte(self): self.assertQuerySetEqual(Number.objects.filter(num__lte=3), self.objs[:4]) + def test_empty_range(self): + with self.assertNumQueries(0): + self.assertQuerySetEqual(Number.objects.filter(num__range=[3, 1]), []) + + def test_full_range(self): + with self.assertNumQueries(1) as ctx: + self.assertQuerySetEqual( + Number.objects.filter(num__range=[None, None]), [self.null_number, *self.objs] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, "lookup__number", [{"$addFields": {"num": "$num"}}, {"$sort": SON([("num", 1)])}] + ) + class RegexTests(MongoTestCaseMixin, TestCase): def test_mql(self): @@ -29,15 +44,7 @@ def test_mql(self): self.assertAggregateQuery( query, "lookup__book", - [ - { - "$match": { - "$expr": { - "$regexMatch": {"input": "$title", "regex": "Moby Dick", "options": ""} - } - } - } - ], + [{"$match": {"title": {"$regex": "Moby Dick", "$options": ""}}}], ) diff --git a/tests/model_fields_/test_embedded_model.py b/tests/model_fields_/test_embedded_model.py index a94090ecf..42d1d7a30 100644 --- a/tests/model_fields_/test_embedded_model.py +++ b/tests/model_fields_/test_embedded_model.py @@ -10,12 +10,14 @@ Max, OuterRef, Sum, + Value, ) from django.test import SimpleTestCase, TestCase from django.test.utils import isolate_apps from django_mongodb_backend.fields import EmbeddedModelField from django_mongodb_backend.models import EmbeddedModel +from django_mongodb_backend.test import MongoTestCaseMixin from .models import ( Address, @@ -108,7 +110,7 @@ def test_pre_save(self): self.assertGreater(obj.data.auto_now, auto_now_two) -class QueryingTests(TestCase): +class QueryingTests(MongoTestCaseMixin, TestCase): @classmethod def setUpTestData(cls): cls.objs = [ @@ -122,23 +124,354 @@ def setUpTestData(cls): for x in range(6) ] - def test_exact(self): - self.assertCountEqual(Holder.objects.filter(data__integer=3), [self.objs[3]]) + def test_exact_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual(Holder.objects.filter(data__integer=Value(4) - 1), [self.objs[3]]) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__holder", + [ + { + "$match": { + "$expr": { + "$eq": [ + {"$getField": {"input": "$data", "field": "integer"}}, + {"$subtract": [{"$literal": 4}, {"$literal": 1}]}, + ] + } + } + } + ], + ) - def test_lt(self): - self.assertCountEqual(Holder.objects.filter(data__integer__lt=3), self.objs[:3]) + def test_exact_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual(Holder.objects.filter(data__integer=3), [self.objs[3]]) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery(query, "model_fields__holder", [{"$match": {"data.integer": 3}}]) - def test_lte(self): - self.assertCountEqual(Holder.objects.filter(data__integer__lte=3), self.objs[:4]) + def test_lt_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Holder.objects.filter(data__integer__lt=Value(4) - 1), self.objs[:3] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__holder", + [ + { + "$match": { + "$expr": { + "$and": [ + { + "$lt": [ + {"$getField": {"input": "$data", "field": "integer"}}, + {"$subtract": [{"$literal": 4}, {"$literal": 1}]}, + ] + }, + { + "$not": { + "$or": [ + { + "$eq": [ + { + "$type": { + "$getField": { + "input": "$data", + "field": "integer", + } + } + }, + "missing", + ] + }, + { + "$eq": [ + { + "$getField": { + "input": "$data", + "field": "integer", + } + }, + None, + ] + }, + ] + } + }, + ] + } + } + } + ], + ) - def test_gt(self): - self.assertCountEqual(Holder.objects.filter(data__integer__gt=3), self.objs[4:]) + def test_lt_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual(Holder.objects.filter(data__integer__lt=3), self.objs[:3]) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__holder", + [ + { + "$match": { + "$and": [ + {"data.integer": {"$lt": 3}}, + { + "$and": [ + {"data.integer": {"$exists": True}}, + {"data.integer": {"$ne": None}}, + ] + }, + ] + } + } + ], + ) - def test_gte(self): - self.assertCountEqual(Holder.objects.filter(data__integer__gte=3), self.objs[3:]) + def test_lte_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Holder.objects.filter(data__integer__lte=Value(4) - 1), self.objs[:4] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__holder", + [ + { + "$match": { + "$expr": { + "$and": [ + { + "$lte": [ + {"$getField": {"input": "$data", "field": "integer"}}, + {"$subtract": [{"$literal": 4}, {"$literal": 1}]}, + ] + }, + { + "$not": { + "$or": [ + { + "$eq": [ + { + "$type": { + "$getField": { + "input": "$data", + "field": "integer", + } + } + }, + "missing", + ] + }, + { + "$eq": [ + { + "$getField": { + "input": "$data", + "field": "integer", + } + }, + None, + ] + }, + ] + } + }, + ] + } + } + } + ], + ) + + def test_lte_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual(Holder.objects.filter(data__integer__lte=3), self.objs[:4]) + query = ctx.captured_queries[0]["sql"] - def test_range(self): - self.assertCountEqual(Holder.objects.filter(data__integer__range=(2, 4)), self.objs[2:5]) + self.assertAggregateQuery( + query, + "model_fields__holder", + [ + { + "$match": { + "$and": [ + {"data.integer": {"$lte": 3}}, + { + "$and": [ + {"data.integer": {"$exists": True}}, + {"data.integer": {"$ne": None}}, + ] + }, + ] + } + } + ], + ) + + def test_gt_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Holder.objects.filter(data__integer__gt=Value(4) - 1), self.objs[4:] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__holder", + [ + { + "$match": { + "$expr": { + "$gt": [ + {"$getField": {"input": "$data", "field": "integer"}}, + {"$subtract": [{"$literal": 4}, {"$literal": 1}]}, + ] + } + } + } + ], + ) + + def test_gt_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual(Holder.objects.filter(data__integer__gt=3), self.objs[4:]) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, "model_fields__holder", [{"$match": {"data.integer": {"$gt": 3}}}] + ) + + def test_gte_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Holder.objects.filter(data__integer__gte=Value(4) - 1), self.objs[3:] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__holder", + [ + { + "$match": { + "$expr": { + "$gte": [ + {"$getField": {"input": "$data", "field": "integer"}}, + {"$subtract": [{"$literal": 4}, {"$literal": 1}]}, + ] + } + } + } + ], + ) + + def test_gte_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual(Holder.objects.filter(data__integer__gte=3), self.objs[3:]) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, "model_fields__holder", [{"$match": {"data.integer": {"$gte": 3}}}] + ) + + def test_range_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Holder.objects.filter(data__integer__range=(2, Value(5) - 1)), self.objs[2:5] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__holder", + [ + { + "$match": { + "$expr": { + "$and": [ + { + "$or": [ + { + "$or": [ + {"$eq": [{"$type": {"$literal": 2}}, "missing"]}, + {"$eq": [{"$literal": 2}, None]}, + ] + }, + { + "$gte": [ + { + "$getField": { + "input": "$data", + "field": "integer", + } + }, + {"$literal": 2}, + ] + }, + ] + }, + { + "$or": [ + { + "$or": [ + { + "$eq": [ + { + "$type": { + "$subtract": [ + {"$literal": 5}, + {"$literal": 1}, + ] + } + }, + "missing", + ] + }, + { + "$eq": [ + { + "$subtract": [ + {"$literal": 5}, + {"$literal": 1}, + ] + }, + None, + ] + }, + ] + }, + { + "$lte": [ + { + "$getField": { + "input": "$data", + "field": "integer", + } + }, + {"$subtract": [{"$literal": 5}, {"$literal": 1}]}, + ] + }, + ] + }, + ] + } + } + } + ], + ) + + def test_range_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Holder.objects.filter(data__integer__range=(2, 4)), self.objs[2:5] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__holder", + [{"$match": {"$and": [{"data.integer": {"$gte": 2}}, {"data.integer": {"$lte": 4}}]}}], + ) def test_exact_decimal(self): # EmbeddedModelField lookups call @@ -225,6 +558,17 @@ def test_nested(self): ) self.assertCountEqual(Book.objects.filter(author__address__city="NYC"), [obj]) + def test_filter_by_simple_annotate(self): + obj = Book.objects.create( + author=Author(name="Shakespeare", age=55, address=Address(city="NYC", state="NY")) + ) + with self.assertNumQueries(1) as ctx: + book_from_ny = ( + Book.objects.annotate(city=F("author__address__city")).filter(city="NYC").first() + ) + self.assertCountEqual(book_from_ny.city, obj.author.address.city) + self.assertIn("{'$match': {'author.address.city': 'NYC'}}", ctx.captured_queries[0]["sql"]) + class ArrayFieldTests(TestCase): @classmethod diff --git a/tests/model_fields_/test_embedded_model_array.py b/tests/model_fields_/test_embedded_model_array.py index 363344f3a..83b425e9f 100644 --- a/tests/model_fields_/test_embedded_model_array.py +++ b/tests/model_fields_/test_embedded_model_array.py @@ -10,6 +10,7 @@ from django_mongodb_backend.fields import ArrayField, EmbeddedModelArrayField from django_mongodb_backend.models import EmbeddedModel +from django_mongodb_backend.test import MongoTestCaseMixin from .models import Artifact, Audit, Exhibit, Movie, Restoration, Review, Section, Tour @@ -63,7 +64,7 @@ def test_save_load_null(self): self.assertIsNone(movie.reviews) -class QueryingTests(TestCase): +class QueryingTests(MongoTestCaseMixin, TestCase): @classmethod def setUpTestData(cls): cls.egypt = Exhibit.objects.create( @@ -156,23 +157,203 @@ def setUpTestData(cls): cls.audit_2 = Audit.objects.create(section_number=2, reviewed=True) cls.audit_3 = Audit.objects.create(section_number=5, reviewed=False) - def test_exact(self): - self.assertCountEqual( - Exhibit.objects.filter(sections__number=1), [self.egypt, self.wonders] + def test_exact_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Exhibit.objects.filter(sections__number=Value(2) - 1), [self.egypt, self.wonders] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__exhibit", + [ + { + "$match": { + "$expr": { + "$anyElementTrue": { + "$ifNull": [ + { + "$map": { + "input": "$sections", + "as": "item", + "in": { + "$eq": [ + "$$item.number", + { + "$subtract": [ + {"$literal": 2}, + {"$literal": 1}, + ] + }, + ] + }, + } + }, + [], + ] + } + } + } + } + ], ) - def test_array_index(self): - self.assertCountEqual( - Exhibit.objects.filter(sections__0__number=1), - [self.egypt, self.wonders], + def test_exact_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Exhibit.objects.filter(sections__number=1), [self.egypt, self.wonders] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, "model_fields__exhibit", [{"$match": {"sections.number": 1}}] ) - def test_nested_array_index(self): - self.assertCountEqual( - Exhibit.objects.filter( - main_section__artifacts__restorations__0__restored_by="Zacarias" - ), - [self.lost_empires], + def test_array_index_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Exhibit.objects.filter(sections__0__number=Value(2) - 1), + [self.egypt, self.wonders], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__exhibit", + [ + { + "$match": { + "$expr": { + "$eq": [ + { + "$getField": { + "input": {"$arrayElemAt": ["$sections", 0]}, + "field": "number", + } + }, + {"$subtract": [{"$literal": 2}, {"$literal": 1}]}, + ] + } + } + } + ], + ) + + def test_array_index_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Exhibit.objects.filter(sections__0__number=1), + [self.egypt, self.wonders], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, "model_fields__exhibit", [{"$match": {"sections.0.number": 1}}] + ) + + def test_nested_array_index_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Exhibit.objects.filter( + main_section__artifacts__restorations__0__restored_by="Zacarias" + ), + [self.lost_empires], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__exhibit", + [ + { + "$match": { + "$expr": { + "$anyElementTrue": { + "$ifNull": [ + { + "$map": { + "input": { + "$getField": { + "input": "$main_section", + "field": "artifacts", + } + }, + "as": "item", + "in": { + "$eq": [ + { + "$getField": { + "input": { + "$arrayElemAt": [ + "$$item.restorations", + 0, + ] + }, + "field": "restored_by", + } + }, + "Zacarias", + ] + }, + } + }, + [], + ] + } + } + } + } + ], + ) + + def test_nested_array_index_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Exhibit.objects.filter( + main_section__artifacts__restorations__0__restored_by="Zacarias" + ), + [self.lost_empires], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__exhibit", + [ + { + "$match": { + "$expr": { + "$anyElementTrue": { + "$ifNull": [ + { + "$map": { + "input": { + "$getField": { + "input": "$main_section", + "field": "artifacts", + } + }, + "as": "item", + "in": { + "$eq": [ + { + "$getField": { + "input": { + "$arrayElemAt": [ + "$$item.restorations", + 0, + ] + }, + "field": "restored_by", + } + }, + "Zacarias", + ] + }, + } + }, + [], + ] + } + } + } + } + ], ) def test_array_slice(self): @@ -186,7 +367,21 @@ def test_filter_unsupported_lookups_in_json(self): kwargs = {f"main_section__artifacts__metadata__origin__{lookup}": ["Pergamon", "Egypt"]} with CaptureQueriesContext(connection) as captured_queries: self.assertCountEqual(Exhibit.objects.filter(**kwargs), []) - self.assertIn(f"'field': '{lookup}'", captured_queries[0]["sql"]) + query = captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__exhibit", + [ + { + "$match": { + f"main_section.artifacts.metadata.origin.{lookup}": [ + "Pergamon", + "Egypt", + ] + } + } + ], + ) def test_len(self): self.assertCountEqual(Exhibit.objects.filter(sections__len=10), []) @@ -284,10 +479,19 @@ def test_nested_lookup(self): with self.assertRaisesMessage(ValueError, msg): Exhibit.objects.filter(sections__artifacts__name="") - def test_foreign_field_exact(self): + def test_foreign_field_exact_path(self): + """Querying from a foreign key to an EmbeddedModelArrayField.""" + with self.assertNumQueries(1) as ctx: + qs = Tour.objects.filter(exhibit__sections__number=1) + self.assertCountEqual(qs, [self.egypt_tour, self.wonders_tour]) + self.assertNotIn("anyElementTrue", ctx.captured_queries[0]["sql"]) + + def test_foreign_field_exact_expr(self): """Querying from a foreign key to an EmbeddedModelArrayField.""" - qs = Tour.objects.filter(exhibit__sections__number=1) - self.assertCountEqual(qs, [self.egypt_tour, self.wonders_tour]) + with self.assertNumQueries(1) as ctx: + qs = Tour.objects.filter(exhibit__sections__number=Value(2) - Value(1)) + self.assertCountEqual(qs, [self.egypt_tour, self.wonders_tour]) + self.assertIn("anyElementTrue", ctx.captured_queries[0]["sql"]) def test_foreign_field_with_slice(self): qs = Tour.objects.filter(exhibit__sections__0_2__number__in=[1, 2]) diff --git a/tests/queries_/test_mql.py b/tests/queries_/test_mql.py index ffd1e2e32..e8837bf8a 100644 --- a/tests/queries_/test_mql.py +++ b/tests/queries_/test_mql.py @@ -11,9 +11,7 @@ class MQLTests(MongoTestCaseMixin, TestCase): def test_all(self): with self.assertNumQueries(1) as ctx: list(Author.objects.all()) - self.assertAggregateQuery( - ctx.captured_queries[0]["sql"], "queries__author", [{"$match": {}}] - ) + self.assertAggregateQuery(ctx.captured_queries[0]["sql"], "queries__author", []) def test_join(self): with self.assertNumQueries(1) as ctx: @@ -29,12 +27,14 @@ def test_join(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$_id"]}, - {"$eq": ["$name", "Bob"]}, - ] - } + "$and": [ + { + "$expr": { + "$and": [{"$eq": ["$$parent__field__0", "$_id"]}] + } + }, + {"name": "Bob"}, + ] } } ], @@ -62,12 +62,14 @@ def test_filter_on_local_and_related_fields(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$_id"]}, - {"$eq": ["$name", "John"]}, - ] - } + "$and": [ + { + "$expr": { + "$and": [{"$eq": ["$$parent__field__0", "$_id"]}] + } + }, + {"name": "John"}, + ] } } ], @@ -123,22 +125,19 @@ def test_filter_on_self_join_fields(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$_id"]}, - { - "$and": [ - { - "$eq": [ - "$group_id", - ObjectId("6891ff7822e475eddc20f159"), - ] - }, - {"$eq": ["$name", "parent"]}, - ] - }, - ] - } + "$and": [ + { + "$expr": { + "$and": [{"$eq": ["$$parent__field__0", "$_id"]}] + } + }, + { + "$and": [ + {"group_id": ObjectId("6891ff7822e475eddc20f159")}, + {"name": "parent"}, + ] + }, + ] } } ], @@ -171,17 +170,16 @@ def test_filter_on_reverse_foreignkey_relation(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$order_id"]}, - { - "$eq": [ - "$status", - ObjectId("6891ff7822e475eddc20f159"), + "$and": [ + { + "$expr": { + "$and": [ + {"$eq": ["$$parent__field__0", "$order_id"]} ] - }, - ] - } + } + }, + {"status": ObjectId("6891ff7822e475eddc20f159")}, + ] } } ], @@ -215,17 +213,16 @@ def test_filter_on_local_and_nested_join_fields(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$order_id"]}, - { - "$eq": [ - "$status", - ObjectId("6891ff7822e475eddc20f159"), + "$and": [ + { + "$expr": { + "$and": [ + {"$eq": ["$$parent__field__0", "$order_id"]} ] - }, - ] - } + } + }, + {"status": ObjectId("6891ff7822e475eddc20f159")}, + ] } } ], @@ -240,12 +237,14 @@ def test_filter_on_local_and_nested_join_fields(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$_id"]}, - {"$eq": ["$name", "My Order"]}, - ] - } + "$and": [ + { + "$expr": { + "$and": [{"$eq": ["$$parent__field__0", "$_id"]}] + } + }, + {"name": "My Order"}, + ] } } ], @@ -276,6 +275,7 @@ def test_negated_related_filter_is_not_pushable(self): [ { "$lookup": { + "as": "queries__author", "from": "queries__author", "let": {"parent__field__0": "$author_id"}, "pipeline": [ @@ -285,11 +285,10 @@ def test_negated_related_filter_is_not_pushable(self): } } ], - "as": "queries__author", } }, {"$unwind": "$queries__author"}, - {"$match": {"$expr": {"$not": {"$eq": ["$queries__author.name", "John"]}}}}, + {"$match": {"$nor": [{"queries__author.name": "John"}]}}, ], ) @@ -341,21 +340,25 @@ def test_push_equality_between_parent_and_child_fields(self): [ { "$lookup": { + "as": "queries__orderitem", "from": "queries__orderitem", "let": {"parent__field__0": "$_id", "parent__field__1": "$_id"}, "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$order_id"]}, - {"$eq": ["$status", "$$parent__field__1"]}, - ] - } + "$and": [ + { + "$expr": { + "$and": [ + {"$eq": ["$$parent__field__0", "$order_id"]} + ] + } + }, + {"$expr": {"$eq": ["$status", "$$parent__field__1"]}}, + ] } } ], - "as": "queries__orderitem", } }, {"$unwind": "$queries__orderitem"}, @@ -398,12 +401,14 @@ def test_simple_related_filter_is_pushed(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$_id"]}, - {"$eq": ["$name", "Alice"]}, - ] - } + "$and": [ + { + "$expr": { + "$and": [{"$eq": ["$$parent__field__0", "$_id"]}] + } + }, + {"name": "Alice"}, + ] } } ], @@ -416,6 +421,7 @@ def test_simple_related_filter_is_pushed(self): ) def test_subquery_join_is_pushed(self): + # TODO; isn't fully OPTIMIZED with self.assertNumQueries(1) as ctx: list(Library.objects.filter(~models.Q(readers__name="Alice"))) @@ -436,12 +442,21 @@ def test_subquery_join_is_pushed(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$_id"]}, - {"$eq": ["$name", "Alice"]}, - ] - } + "$and": [ + { + "$expr": { + "$and": [ + { + "$eq": [ + "$$parent__field__0", + "$_id", + ] + } + ] + } + }, + {"name": "Alice"}, + ] } } ], @@ -480,21 +495,28 @@ def test_subquery_join_is_pushed(self): }, { "$match": { - "$expr": { - "$not": { - "$eq": [ - { - "$not": { - "$or": [ - {"$eq": [{"$type": "$__subquery0.a"}, "missing"]}, - {"$eq": ["$__subquery0.a", None]}, - ] - } - }, - True, - ] + "$nor": [ + { + "$expr": { + "$eq": [ + { + "$not": { + "$or": [ + { + "$eq": [ + {"$type": "$__subquery0.a"}, + "missing", + ] + }, + {"$eq": ["$__subquery0.a", None]}, + ] + } + }, + True, + ] + } } - } + ] } }, ], @@ -531,12 +553,14 @@ def test_filter_on_local_and_related_fields(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$_id"]}, - {"$eq": ["$name", "Alice"]}, - ] - } + "$and": [ + { + "$expr": { + "$and": [{"$eq": ["$$parent__field__0", "$_id"]}] + } + }, + {"name": "Alice"}, + ] } } ],