Skip to content

Commit d11378a

Browse files
committed
WIP
1 parent 141f1cf commit d11378a

File tree

8 files changed

+123
-84
lines changed

8 files changed

+123
-84
lines changed

django_mongodb_backend/expressions/builtins.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ def col(self, compiler, connection, as_path=False): # noqa: ARG001
7777
return f"{prefix}{self.target.column}"
7878

7979

80+
@property
81+
def col_is_simple_column(self):
82+
return self.alias is not None
83+
84+
8085
def col_pairs(self, compiler, connection, as_path=False):
8186
cols = self.get_cols()
8287
if len(cols) > 1:
@@ -170,6 +175,11 @@ def ref(self, compiler, connection, as_path=False): # noqa: ARG001
170175
return f"{prefix}{refs}"
171176

172177

178+
@property
179+
def ref_is_simple_column(self):
180+
return isinstance(self.source, Col) and self.source.alias is not None
181+
182+
173183
def star(self, compiler, connection, as_path=False): # noqa: ARG001
174184
return {"$literal": True}
175185

@@ -229,8 +239,11 @@ def base_expression(self, compiler, connection, as_path=False, **extra):
229239

230240

231241
def register_expressions():
242+
BaseExpression.as_mql = base_expression
243+
BaseExpression.is_simple_column = False
232244
Case.as_mql_expr = case
233245
Col.as_mql = col
246+
Col.is_simple_column = col_is_simple_column
234247
ColPairs.as_mql = col_pairs
235248
CombinedExpression.as_mql_expr = combined_expression
236249
Exists.as_mql_expr = exists
@@ -241,9 +254,9 @@ def register_expressions():
241254
Query.as_mql = query
242255
RawSQL.as_mql = raw_sql
243256
Ref.as_mql = ref
257+
Ref.is_simple_column = ref_is_simple_column
244258
ResolvedOuterRef.as_mql = ResolvedOuterRef.as_sql
245259
Star.as_mql = star
246260
Subquery.as_mql_expr = subquery
247261
When.as_mql = when
248262
Value.as_mql = value
249-
BaseExpression.as_mql = base_expression

django_mongodb_backend/expressions/search.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -933,12 +933,15 @@ def __str__(self):
933933
def __repr__(self):
934934
return f"SearchText({self.lhs}, {self.rhs})"
935935

936-
def as_mql(self, compiler, connection, as_path=False):
937-
lhs_mql = process_lhs(self, compiler, connection, as_path=as_path)
938-
value = process_rhs(self, compiler, connection, as_path=as_path)
939-
if as_path:
940-
return {lhs_mql: {"$gte": value}}
941-
return {"$expr": {"$gte": [lhs_mql, value]}}
936+
def as_mql_expr(self, compiler, connection):
937+
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
938+
value = process_rhs(self, compiler, connection, as_path=False)
939+
return {"$gte": [lhs_mql, value]}
940+
941+
def as_mql_path(self, compiler, connection):
942+
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
943+
value = process_rhs(self, compiler, connection, as_path=True)
944+
return {lhs_mql: {"$gte": value}}
942945

943946

944947
CharField.register_lookup(SearchTextLookup)

django_mongodb_backend/fields/array.py

Lines changed: 34 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -230,9 +230,9 @@ def formfield(self, **kwargs):
230230

231231

232232
class Array(Func):
233-
def as_mql(self, compiler, connection, as_path=False):
233+
def as_mql_expr(self, compiler, connection):
234234
return [
235-
expr.as_mql(compiler, connection, as_path=as_path)
235+
expr.as_mql(compiler, connection, as_path=False)
236236
for expr in self.get_source_expressions()
237237
]
238238

@@ -254,24 +254,16 @@ def __init__(self, lhs, rhs):
254254
class ArrayContains(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup):
255255
lookup_name = "contains"
256256

257-
def as_mql(self, compiler, connection, as_path=False):
258-
if as_path and self.is_simple_expression():
259-
lhs_mql = process_lhs(self, compiler, connection, as_path=as_path)
260-
value = process_rhs(self, compiler, connection, as_path=as_path)
261-
if value is None:
262-
return False
263-
return {lhs_mql: {"$all": value}}
257+
def as_mql_path(self, compiler, connection):
258+
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
259+
value = process_rhs(self, compiler, connection, as_path=True)
260+
if value is None:
261+
return False
262+
return {lhs_mql: {"$all": value}}
263+
264+
def as_mql_expr(self, compiler, connection):
264265
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
265266
value = process_rhs(self, compiler, connection, as_path=False)
266-
expr = {
267-
"$and": [
268-
{"$ne": [lhs_mql, None]},
269-
{"$ne": [value, None]},
270-
{"$setIsSubset": [value, lhs_mql]},
271-
]
272-
}
273-
if as_path:
274-
return {"$expr": expr}
275267
return {
276268
"$and": [
277269
{"$ne": [lhs_mql, None]},
@@ -285,19 +277,16 @@ def as_mql(self, compiler, connection, as_path=False):
285277
class ArrayContainedBy(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup):
286278
lookup_name = "contained_by"
287279

288-
def as_mql(self, compiler, connection, as_path=False):
280+
def as_mql_expr(self, compiler, connection):
289281
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
290282
value = process_rhs(self, compiler, connection, as_path=False)
291-
expr = {
283+
return {
292284
"$and": [
293285
{"$ne": [lhs_mql, None]},
294286
{"$ne": [value, None]},
295287
{"$setIsSubset": [lhs_mql, value]},
296288
]
297289
}
298-
if as_path:
299-
return {"$expr": expr}
300-
return expr
301290

302291

303292
@ArrayField.register_lookup
@@ -344,36 +333,30 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr)
344333
},
345334
]
346335

347-
def as_mql(self, compiler, connection, as_path=False):
348-
if as_path and self.is_simple_expression():
349-
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
350-
value = process_rhs(self, compiler, connection, as_path=True)
351-
return {lhs_mql: {"$in": value}}
336+
def as_mql_path(self, compiler, connection):
337+
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
338+
value = process_rhs(self, compiler, connection, as_path=True)
339+
return {lhs_mql: {"$in": value}}
352340

341+
def as_mql_expr(self, compiler, connection):
353342
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
354343
value = process_rhs(self, compiler, connection, as_path=False)
355-
expr = {
344+
return {
356345
"$and": [
357346
{"$ne": [lhs_mql, None]},
358347
{"$size": {"$setIntersection": [value, lhs_mql]}},
359348
]
360349
}
361-
if as_path:
362-
return {"$expr": expr}
363-
return expr
364350

365351

366352
@ArrayField.register_lookup
367353
class ArrayLenTransform(Transform):
368354
lookup_name = "len"
369355
output_field = IntegerField()
370356

371-
def as_mql(self, compiler, connection, as_path=False):
357+
def as_mql_expr(self, compiler, connection, as_path=False):
372358
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
373-
expr = {"$cond": {"if": {"$isArray": lhs_mql}, "then": {"$size": lhs_mql}, "else": None}}
374-
if as_path:
375-
return {"$expr": expr}
376-
return expr
359+
return {"$cond": {"if": {"$isArray": lhs_mql}, "then": {"$size": lhs_mql}, "else": None}}
377360

378361

379362
@ArrayField.register_lookup
@@ -398,15 +381,20 @@ def __init__(self, index, base_field, *args, **kwargs):
398381
self.index = index
399382
self.base_field = base_field
400383

401-
def as_mql(self, compiler, connection, as_path=False):
402-
if as_path and self.is_simple_column(self.lhs):
403-
lhs_mql = process_lhs(self, compiler, connection, as_path=as_path)
404-
return f"{lhs_mql}.{self.index}"
384+
def is_simple_expression(self):
385+
return self.is_simple_column
386+
387+
@property
388+
def is_simple_column(self):
389+
return self.lhs.is_simple_column
390+
391+
def as_mql_path(self, compiler, connection):
392+
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
393+
return f"{lhs_mql}.{self.index}"
394+
395+
def as_mql_expr(self, compiler, connection):
405396
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
406-
expr = {"$arrayElemAt": [lhs_mql, self.index]}
407-
if as_path:
408-
return {"$expr": expr}
409-
return expr
397+
return {"$arrayElemAt": [lhs_mql, self.index]}
410398

411399
@property
412400
def output_field(self):
@@ -428,7 +416,7 @@ def __init__(self, start, end, *args, **kwargs):
428416
self.start = start
429417
self.end = end
430418

431-
def as_mql(self, compiler, connection, as_path=False):
419+
def as_mql_expr(self, compiler, connection):
432420
lhs_mql = process_lhs(self, compiler, connection)
433421
return {"$slice": [lhs_mql, self.start, self.end]}
434422

django_mongodb_backend/fields/embedded_model.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from django.db import models
66
from django.db.models.fields.related import lazy_related_operation
77
from django.db.models.lookups import Transform
8+
from django.utils.functional import cached_property
89

910
from .. import forms
1011

@@ -165,6 +166,18 @@ def __init__(self, key_name, ref_field, *args, **kwargs):
165166
def get_lookup(self, name):
166167
return self.ref_field.get_lookup(name)
167168

169+
def is_simple_expression(self):
170+
return self.is_simple_column
171+
172+
@cached_property
173+
def is_simple_column(self):
174+
previous = self
175+
while isinstance(previous, KeyTransform):
176+
if not previous.key_name.isalnum():
177+
return False
178+
previous = previous.lhs
179+
return previous.is_simple_column
180+
168181
def get_transform(self, name):
169182
"""
170183
Validate that `name` is either a field of an embedded model or a
@@ -184,16 +197,22 @@ def get_transform(self, name):
184197
f"{suggestion}"
185198
)
186199

187-
def as_mql(self, compiler, connection, as_path=False):
200+
def as_mql_path(self, compiler, connection):
201+
previous = self
202+
key_transforms = []
203+
while isinstance(previous, KeyTransform):
204+
key_transforms.insert(0, previous.key_name)
205+
previous = previous.lhs
206+
mql = previous.as_mql(compiler, connection, as_path=True)
207+
mql_path = ".".join(key_transforms)
208+
return f"{mql}.{mql_path}"
209+
210+
def as_mql_expr(self, compiler, connection):
188211
previous = self
189212
key_transforms = []
190213
while isinstance(previous, KeyTransform):
191214
key_transforms.insert(0, previous.key_name)
192215
previous = previous.lhs
193-
if as_path:
194-
mql = previous.as_mql(compiler, connection, as_path=True)
195-
mql_path = ".".join(key_transforms)
196-
return f"{mql}.{mql_path}"
197216
mql = previous.as_mql(compiler, connection)
198217
for key in key_transforms:
199218
mql = {"$getField": {"input": mql, "field": key}}

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from django.db.models.expressions import Col
66
from django.db.models.fields.related import lazy_related_operation
77
from django.db.models.lookups import Lookup, Transform
8+
from django.utils.functional import cached_property
89

910
from .. import forms
1011
from ..query_utils import process_lhs, process_rhs
@@ -128,7 +129,7 @@ def process_rhs(self, compiler, connection, as_path=False):
128129
for v in value
129130
]
130131

131-
def as_mql(self, compiler, connection, as_path=False):
132+
def as_mql_expr(self, compiler, connection):
132133
# Querying a subfield within the array elements (via nested
133134
# KeyTransform). Replicate MongoDB's implicit ANY-match by mapping over
134135
# the array and applying $in on the subfield.
@@ -138,7 +139,7 @@ def as_mql(self, compiler, connection, as_path=False):
138139
lhs_mql["$ifNull"][0]["$map"]["in"] = connection.mongo_operators_expr[self.lookup_name](
139140
inner_lhs_mql, values
140141
)
141-
return {"$expr": {"$anyElementTrue": lhs_mql}}
142+
return {"$anyElementTrue": lhs_mql}
142143

143144

144145
@_EmbeddedModelArrayOutputField.register_lookup
@@ -236,13 +237,25 @@ def __init__(self, key_name, array_field, *args, **kwargs):
236237
column_name = f"$item.{key_name}"
237238
column_target.db_column = column_name
238239
column_target.set_attributes_from_name(column_name)
239-
self._lhs = Col(None, column_target)
240+
self._lhs = Col(False, column_target)
240241
self._sub_transform = None
241242

242243
def __call__(self, this, *args, **kwargs):
243244
self._lhs = self._sub_transform(self._lhs, *args, **kwargs)
244245
return self
245246

247+
def is_simple_expression(self):
248+
return self.is_simple_column
249+
250+
@cached_property
251+
def is_simple_column(self):
252+
previous = self
253+
while isinstance(previous, KeyTransform):
254+
if not previous.key_name.isalnum():
255+
return False
256+
previous = previous.lhs
257+
return previous.is_simple_column and self._lhs.is_simple_column
258+
246259
def get_lookup(self, name):
247260
return self.output_field.get_lookup(name)
248261

@@ -275,11 +288,12 @@ def get_transform(self, name):
275288
f"{suggestion}"
276289
)
277290

278-
def as_mql(self, compiler, connection, as_path=False):
279-
if as_path:
280-
inner_lhs_mql = self._lhs.as_mql(compiler, connection, as_path=True)
281-
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
282-
return f"{inner_lhs_mql}.{lhs_mql}"
291+
def as_mql_path(self, compiler, connection):
292+
inner_lhs_mql = self._lhs.as_mql(compiler, connection, as_path=True).removeprefix("$item.")
293+
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
294+
return f"{lhs_mql}.{inner_lhs_mql}"
295+
296+
def as_mql_expr(self, compiler, connection):
283297
inner_lhs_mql = self._lhs.as_mql(compiler, connection)
284298
lhs_mql = process_lhs(self, compiler, connection)
285299
return {

django_mongodb_backend/fields/json.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
)
1919

2020
from ..lookups import builtin_lookup_expr, builtin_lookup_path
21-
from ..query_utils import is_simple_column, is_simple_expression, process_lhs, process_rhs
21+
from ..query_utils import process_lhs, process_rhs
2222

2323

2424
def build_json_mql_path(lhs, key_transforms, as_path=False):
@@ -73,7 +73,8 @@ def _has_key_predicate(path, root_column=None, negated=False, as_path=False):
7373

7474

7575
def has_key_check_simple_expression(self):
76-
return is_simple_expression(self) and all("." not in v for v in self.rhs)
76+
rhs = [self.rhs] if not isinstance(self.rhs, (list, tuple)) else self.rhs
77+
return self.is_simple_column and all(key.isalnum() for key in rhs)
7778

7879

7980
def has_key_lookup(self, compiler, connection, as_path=False):
@@ -225,17 +226,27 @@ def key_transform_exact_expr(self, compiler, connection):
225226
return builtin_lookup_expr(self, compiler, connection)
226227

227228

229+
def keytransform_is_simple_column(self):
230+
previous = self
231+
while isinstance(previous, KeyTransform):
232+
if not previous.key_name.isalnum():
233+
return False
234+
previous = previous.lhs
235+
return previous.is_simple_column
236+
237+
228238
def register_json_field():
229239
ContainedBy.as_mql = contained_by
230240
DataContains.as_mql = data_contains
231241
HasAnyKeys.mongo_operator = "$or"
232242
HasKey.mongo_operator = None
233-
HasKeyLookup.is_simple_expression = has_key_check_simple_expression
234243
HasKeyLookup.as_mql_path = partialmethod(has_key_lookup, as_path=True)
235244
HasKeyLookup.as_mql_expr = partialmethod(has_key_lookup, as_path=False)
245+
HasKeyLookup.is_simple_expression = has_key_check_simple_expression
236246
HasKeys.mongo_operator = "$and"
237247
JSONExact.process_rhs = json_exact_process_rhs
238-
KeyTransform.is_simple_expression = is_simple_column
248+
KeyTransform.is_simple_column = property(keytransform_is_simple_column)
249+
KeyTransform.is_simple_expression = keytransform_is_simple_column
239250
KeyTransform.as_mql_path = partialmethod(key_transform, as_path=True)
240251
KeyTransform.as_mql_expr = partialmethod(key_transform, as_path=False)
241252
KeyTransformIn.as_mql_path = key_transform_in_path

0 commit comments

Comments
 (0)