diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 1c727039d..b7a57bec0 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -10,14 +10,16 @@ from django.db.models.expressions import Case, Col, OrderBy, Ref, Value, When from django.db.models.functions.comparison import Coalesce from django.db.models.functions.math import Power -from django.db.models.lookups import IsNull +from django.db.models.lookups import IsNull, Lookup from django.db.models.sql import compiler from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, MULTI, SINGLE from django.db.models.sql.datastructures import BaseTable +from django.db.models.sql.where import AND from django.utils.functional import cached_property from pymongo import ASCENDING, DESCENDING from .query import MongoQuery, wrap_database_errors +from .query_utils import is_direct_value class SQLCompiler(compiler.SQLCompiler): @@ -550,10 +552,22 @@ def get_combinator_queries(self): def get_lookup_pipeline(self): result = [] + where = self.get_where() + promote_filters = defaultdict(list) + for expr in where.children if where and where.connector == AND else (): + if ( + isinstance(expr, Lookup) + and isinstance(expr.lhs, Col) + and (is_direct_value(expr.rhs) or isinstance(expr.rhs, Value)) + ): + promote_filters[expr.lhs.alias].append(expr) + for alias in tuple(self.query.alias_map): if not self.query.alias_refcount[alias] or self.collection_name == alias: continue - result += self.query.alias_map[alias].as_mql(self, self.connection) + result += self.query.alias_map[alias].as_mql( + self, self.connection, promote_filters[alias] + ) return result def _get_aggregate_expressions(self, expr): diff --git a/django_mongodb_backend/query.py b/django_mongodb_backend/query.py index d59bc1631..ef1f1a040 100644 --- a/django_mongodb_backend/query.py +++ b/django_mongodb_backend/query.py @@ -129,25 +129,12 @@ def extra_where(self, compiler, connection): # noqa: ARG001 raise NotSupportedError("QuerySet.extra() is not supported on MongoDB.") -def join(self, compiler, connection): - lookup_pipeline = [] - lhs_fields = [] - rhs_fields = [] - # Add a join condition for each pair of joining fields. - parent_template = "parent__field__" - for lhs, rhs in self.join_fields: - lhs, rhs = connection.ops.prepare_join_on_clause( - self.parent_alias, lhs, compiler.collection_name, rhs - ) - lhs_fields.append(lhs.as_mql(compiler, connection)) - # In the lookup stage, the reference to this column doesn't include - # the collection name. - rhs_fields.append(rhs.as_mql(compiler, connection)) - # Handle any join conditions besides matching field pairs. - extra = self.join_field.get_extra_restriction(self.table_alias, self.parent_alias) - if extra: +def join(self, compiler, connection, pushed_expressions=None): + def _get_reroot_replacements(expressions): + if not expressions: + return None columns = [] - for expr in extra.leaves(): + for expr in expressions: # Determine whether the column needs to be transformed or rerouted # as part of the subquery. for hand_side in ["lhs", "rhs"]: @@ -165,7 +152,9 @@ def join(self, compiler, connection): # based on their rerouted positions in the join pipeline. replacements = {} for col, parent_pos in columns: - column_target = Col(compiler.collection_name, expr.output_field.__class__()) + target = col.target.clone() + target.remote_field = col.target.remote_field + column_target = Col(compiler.collection_name, target) if parent_pos is not None: target_col = f"${parent_template}{parent_pos}" column_target.target.db_column = target_col @@ -173,10 +162,37 @@ def join(self, compiler, connection): else: column_target.target = col.target replacements[col] = column_target - # Apply the transformed expressions in the extra condition. + return replacements + + lookup_pipeline = [] + lhs_fields = [] + rhs_fields = [] + # Add a join condition for each pair of joining fields. + parent_template = "parent__field__" + for lhs, rhs in self.join_fields: + lhs, rhs = connection.ops.prepare_join_on_clause( + self.parent_alias, lhs, compiler.collection_name, rhs + ) + lhs_fields.append(lhs.as_mql(compiler, connection)) + # In the lookup stage, the reference to this column doesn't include + # the collection name. + rhs_fields.append(rhs.as_mql(compiler, connection)) + # Handle any join conditions besides matching field pairs. + extra = self.join_field.get_extra_restriction(self.table_alias, self.parent_alias) + + if extra: + replacements = _get_reroot_replacements(extra.leaves()) extra_condition = [extra.replace_expressions(replacements).as_mql(compiler, connection)] else: extra_condition = [] + if self.join_type == INNER: + rerooted_replacement = _get_reroot_replacements(pushed_expressions) + resolved_pushed_expressions = [ + expr.replace_expressions(rerooted_replacement).as_mql(compiler, connection) + for expr in pushed_expressions + ] + else: + resolved_pushed_expressions = [] lookup_pipeline = [ { @@ -204,6 +220,7 @@ def join(self, compiler, connection): for i, field in enumerate(rhs_fields) ] + extra_condition + + resolved_pushed_expressions } } } diff --git a/tests/queries_/test_mql.py b/tests/queries_/test_mql.py index d61e5839d..d12ea9602 100644 --- a/tests/queries_/test_mql.py +++ b/tests/queries_/test_mql.py @@ -20,7 +20,8 @@ def test_join(self): "{'$lookup': {'from': 'queries__author', " "'let': {'parent__field__0': '$author_id'}, " "'pipeline': [{'$match': {'$expr': " - "{'$and': [{'$eq': ['$$parent__field__0', '$_id']}]}}}], 'as': 'queries__author'}}, " + "{'$and': [{'$eq': ['$$parent__field__0', '$_id']}, " + "{'$eq': ['$name', 'Bob']}]}}}], 'as': 'queries__author'}}, " "{'$unwind': '$queries__author'}, " "{'$match': {'$expr': {'$eq': ['$queries__author.name', 'Bob']}}}])", )