Skip to content

Commit 2c48d11

Browse files
committed
Object-oriented approach solution
1 parent 59c8faf commit 2c48d11

File tree

14 files changed

+375
-312
lines changed

14 files changed

+375
-312
lines changed

django_mongodb_backend/aggregates.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,7 @@
88
MONGO_AGGREGATIONS = {Count: "sum"}
99

1010

11-
def aggregate(
12-
self,
13-
compiler,
14-
connection,
15-
operator=None,
16-
resolve_inner_expression=False,
17-
**extra_context, # noqa: ARG001
18-
):
11+
def aggregate(self, compiler, connection, operator=None, resolve_inner_expression=False):
1912
if self.filter:
2013
node = self.copy()
2114
node.filter = None
@@ -31,7 +24,7 @@ def aggregate(
3124
return {f"${operator}": lhs_mql}
3225

3326

34-
def count(self, compiler, connection, resolve_inner_expression=False, **extra_context): # noqa: ARG001
27+
def count(self, compiler, connection, resolve_inner_expression=False):
3528
"""
3629
When resolve_inner_expression=True, return the MQL that resolves as a
3730
value. This is used to count different elements, so the inner values are
@@ -64,16 +57,16 @@ def count(self, compiler, connection, resolve_inner_expression=False, **extra_co
6457
return {"$add": [{"$size": lhs_mql}, exits_null]}
6558

6659

67-
def stddev_variance(self, compiler, connection, **extra_context):
60+
def stddev_variance(self, compiler, connection):
6861
if self.function.endswith("_SAMP"):
6962
operator = "stdDevSamp"
7063
elif self.function.endswith("_POP"):
7164
operator = "stdDevPop"
72-
return aggregate(self, compiler, connection, operator=operator, **extra_context)
65+
return aggregate(self, compiler, connection, operator=operator)
7366

7467

7568
def register_aggregates():
76-
Aggregate.as_mql = aggregate
77-
Count.as_mql = count
78-
StdDev.as_mql = stddev_variance
79-
Variance.as_mql = stddev_variance
69+
Aggregate.as_mql_expr = aggregate
70+
Count.as_mql_expr = count
71+
StdDev.as_mql_expr = stddev_variance
72+
Variance.as_mql_expr = stddev_variance

django_mongodb_backend/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,7 @@ def get_project_fields(self, columns=None, ordering=None, force_expression=False
709709
# For brevity/simplicity, project {"field_name": 1}
710710
# instead of {"field_name": "$field_name"}.
711711
if isinstance(expr, Col) and name == expr.target.column and not force_expression
712-
else expr.as_mql(self, self.connection, as_path=False)
712+
else expr.as_mql(self, self.connection)
713713
)
714714
except EmptyResultSet:
715715
empty_result_set_value = getattr(expr, "empty_result_set_value", NotImplemented)

django_mongodb_backend/expressions/builtins.py

Lines changed: 37 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
Exists,
1515
ExpressionList,
1616
ExpressionWrapper,
17-
Func,
1817
NegatedExpression,
1918
OrderBy,
2019
RawSQL,
@@ -25,15 +24,12 @@
2524
Value,
2625
When,
2726
)
28-
from django.db.models.fields.json import KeyTransform
2927
from django.db.models.sql import Query
3028

31-
from django_mongodb_backend.fields.array import Array
29+
from ..query_utils import process_lhs
3230

33-
from ..query_utils import is_direct_value, process_lhs
3431

35-
36-
def case(self, compiler, connection, as_path=False):
32+
def case(self, compiler, connection):
3733
case_parts = []
3834
for case in self.cases:
3935
case_mql = {}
@@ -50,16 +46,12 @@ def case(self, compiler, connection, as_path=False):
5046
default_mql = self.default.as_mql(compiler, connection)
5147
if not case_parts:
5248
return default_mql
53-
expr = {
49+
return {
5450
"$switch": {
5551
"branches": case_parts,
5652
"default": default_mql,
5753
}
5854
}
59-
if as_path:
60-
return {"$expr": expr}
61-
62-
return expr
6355

6456

6557
def col(self, compiler, connection, as_path=False): # noqa: ARG001
@@ -100,12 +92,12 @@ def combined_expression(self, compiler, connection, as_path=False):
10092
return connection.ops.combine_expression(self.connector, expressions)
10193

10294

103-
def expression_wrapper(self, compiler, connection, as_path=False):
104-
return self.expression.as_mql(compiler, connection, as_path=as_path)
95+
def expression_wrapper_expr(self, compiler, connection):
96+
return self.expression.as_mql(compiler, connection, as_path=False)
10597

10698

107-
def negated_expression(self, compiler, connection, as_path=False):
108-
return {"$not": expression_wrapper(self, compiler, connection, as_path=as_path)}
99+
def negated_expression_expr(self, compiler, connection):
100+
return {"$not": expression_wrapper_expr(self, compiler, connection)}
109101

110102

111103
def order_by(self, compiler, connection):
@@ -178,32 +170,26 @@ def ref(self, compiler, connection, as_path=False): # noqa: ARG001
178170
return f"{prefix}{refs}"
179171

180172

181-
def star(self, compiler, connection, **extra): # noqa: ARG001
173+
@property
174+
def ref_is_simple_column(self):
175+
return isinstance(self.source, Col) and self.source.alias is not None
176+
177+
178+
def star(self, compiler, connection, as_path=False): # noqa: ARG001
182179
return {"$literal": True}
183180

184181

185-
def subquery(self, compiler, connection, get_wrapping_pipeline=None, as_path=False):
186-
expr = self.query.as_mql(
182+
def subquery(self, compiler, connection, get_wrapping_pipeline=None):
183+
return self.query.as_mql(
187184
compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline, as_path=False
188185
)
189-
if as_path:
190-
return {"$expr": expr}
191-
return expr
192186

193187

194-
def exists(self, compiler, connection, get_wrapping_pipeline=None, as_path=False):
188+
def exists(self, compiler, connection, get_wrapping_pipeline=None):
195189
try:
196-
lhs_mql = subquery(
197-
self,
198-
compiler,
199-
connection,
200-
get_wrapping_pipeline=get_wrapping_pipeline,
201-
as_path=as_path,
202-
)
190+
lhs_mql = subquery(self, compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline)
203191
except EmptyResultSet:
204192
return Value(False).as_mql(compiler, connection)
205-
if as_path:
206-
return {"$expr": connection.mongo_operators_match["isnull"](lhs_mql, False)}
207193
return connection.mongo_operators_expr["isnull"](lhs_mql, False)
208194

209195

@@ -235,54 +221,37 @@ def value(self, compiler, connection, as_path=False): # noqa: ARG001
235221
return value
236222

237223

238-
@staticmethod
239-
def _is_constant_value(value):
240-
if isinstance(value, list | Array):
241-
iterable = value.get_source_expressions() if isinstance(value, Array) else value
242-
return all(_is_constant_value(e) for e in iterable)
243-
if is_direct_value(value):
244-
return True
245-
return isinstance(value, Func | Value) and not (
246-
value.contains_aggregate
247-
or value.contains_over_clause
248-
or value.contains_column_references
249-
or value.contains_subquery
250-
)
251-
252-
253-
@staticmethod
254-
def _is_simple_column(lhs):
255-
while isinstance(lhs, KeyTransform):
256-
if "." in getattr(lhs, "key_name", ""):
257-
return False
258-
lhs = lhs.lhs
259-
col = lhs.source if isinstance(lhs, Ref) else lhs
260-
# Foreign columns from parent cannot be addressed as single match
261-
return isinstance(col, Col) and col.alias is not None
262-
224+
def base_expression(self, compiler, connection, as_path=False, **extra):
225+
if (
226+
as_path
227+
and hasattr(self, "as_mql_path")
228+
and getattr(self, "is_simple_expression", lambda: False)()
229+
):
230+
return self.as_mql_path(compiler, connection, **extra)
263231

264-
def _is_simple_expression(self):
265-
return self.is_simple_column(self.lhs) and self.is_constant_value(self.rhs)
232+
expr = self.as_mql_expr(compiler, connection, **extra)
233+
return {"$expr": expr} if as_path else expr
266234

267235

268236
def register_expressions():
269-
Case.as_mql = case
237+
BaseExpression.as_mql = base_expression
238+
BaseExpression.is_simple_column = False
239+
Case.as_mql_expr = case
270240
Col.as_mql = col
241+
Col.is_simple_column = True
271242
ColPairs.as_mql = col_pairs
272-
CombinedExpression.as_mql = combined_expression
273-
Exists.as_mql = exists
243+
CombinedExpression.as_mql_expr = combined_expression
244+
Exists.as_mql_expr = exists
274245
ExpressionList.as_mql = process_lhs
275-
ExpressionWrapper.as_mql = expression_wrapper
276-
NegatedExpression.as_mql = negated_expression
277-
OrderBy.as_mql = order_by
246+
ExpressionWrapper.as_mql_expr = expression_wrapper_expr
247+
NegatedExpression.as_mql_expr = negated_expression_expr
248+
OrderBy.as_mql_expr = order_by
278249
Query.as_mql = query
279250
RawSQL.as_mql = raw_sql
280251
Ref.as_mql = ref
252+
Ref.is_simple_column = ref_is_simple_column
281253
ResolvedOuterRef.as_mql = ResolvedOuterRef.as_sql
282254
Star.as_mql = star
283-
Subquery.as_mql = subquery
255+
Subquery.as_mql_expr = subquery
284256
When.as_mql = when
285257
Value.as_mql = value
286-
BaseExpression.is_simple_expression = _is_simple_expression
287-
BaseExpression.is_simple_column = _is_simple_column
288-
BaseExpression.is_constant_value = _is_constant_value

django_mongodb_backend/expressions/search.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -933,12 +933,15 @@ def __str__(self):
933933
def __repr__(self):
934934
return f"SearchText({self.lhs}, {self.rhs})"
935935

936-
def as_mql(self, compiler, connection, as_path=False):
937-
lhs_mql = process_lhs(self, compiler, connection, as_path=as_path)
938-
value = process_rhs(self, compiler, connection, as_path=as_path)
939-
if as_path:
940-
return {lhs_mql: {"$gte": value}}
941-
return {"$expr": {"$gte": [lhs_mql, value]}}
936+
def as_mql_expr(self, compiler, connection):
937+
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
938+
value = process_rhs(self, compiler, connection, as_path=False)
939+
return {"$gte": [lhs_mql, value]}
940+
941+
def as_mql_path(self, compiler, connection):
942+
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
943+
value = process_rhs(self, compiler, connection, as_path=True)
944+
return {lhs_mql: {"$gte": value}}
942945

943946

944947
CharField.register_lookup(SearchTextLookup)

0 commit comments

Comments
 (0)