Skip to content

Commit 2558faf

Browse files
committed
Support index definition on Embedded Models in top level model.
1 parent e227720 commit 2558faf

File tree

6 files changed

+238
-3
lines changed

6 files changed

+238
-3
lines changed

django_mongodb_backend/expressions/builtins.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
Value,
2626
When,
2727
)
28+
from django.db.models.indexes import IndexExpression
2829
from django.db.models.sql import Query
2930

3031
from django_mongodb_backend.query_utils import process_lhs
@@ -106,6 +107,19 @@ def expression_wrapper(self, compiler, connection):
106107
return self.expression.as_mql(compiler, connection, as_expr=True)
107108

108109

110+
def index_expression(self, compiler, connection, as_expr=False): # noqa: ARG001
111+
result = []
112+
for expr in self.get_source_expressions():
113+
if expr is None:
114+
continue
115+
for sub_expr in expr.get_source_expressions():
116+
try:
117+
result.append(sub_expr.as_mql(compiler, connection))
118+
except FullResultSet:
119+
result.append(Value(True).as_mql(compiler, connection))
120+
return result
121+
122+
109123
def negated_expression(self, compiler, connection):
110124
return {"$not": expression_wrapper(self, compiler, connection)}
111125

@@ -244,6 +258,7 @@ def register_expressions():
244258
Exists.as_mql_expr = exists
245259
ExpressionList.as_mql = process_lhs
246260
ExpressionWrapper.as_mql_expr = expression_wrapper
261+
IndexExpression.as_mql = index_expression
247262
NegatedExpression.as_mql_expr = negated_expression
248263
OrderBy.as_mql_expr = partialmethod(order_by, as_expr=True)
249264
OrderBy.as_mql_path = partialmethod(order_by, as_expr=False)

django_mongodb_backend/features.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ class DatabaseFeatures(GISFeatures, BaseDatabaseFeatures):
9999
"model_fields.test_jsonfield.TestSaveLoad.test_bulk_update_custom_get_prep_value",
100100
# To debug: https://github.com/mongodb/django-mongodb-backend/issues/362
101101
"constraints.tests.UniqueConstraintTests.test_validate_case_when",
102+
# Simple expression index are supported
103+
"schema.tests.SchemaTests.test_func_unique_constraint_unsupported",
104+
"schema.tests.SchemaTests.test_func_index_unsupported",
102105
}
103106
# $bitAnd, #bitOr, and $bitXor are new in MongoDB 6.3.
104107
_django_test_expected_failures_bitwise = {

django_mongodb_backend/fields/embedded_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,9 @@ def as_mql_path(self, compiler, connection):
211211
def output_field(self):
212212
return self._field
213213

214+
def db_type(self, connection):
215+
return self.output_field.db_type(connection)
216+
214217
@property
215218
def can_use_path(self):
216219
return self.is_simple_column

django_mongodb_backend/indexes.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from django.core.checks import Error, Warning
55
from django.db import NotSupportedError
66
from django.db.models import FloatField, Index, IntegerField
7+
from django.db.models.expressions import OrderBy
78
from django.db.models.lookups import BuiltinLookup
89
from django.db.models.sql.query import Query
910
from django.db.models.sql.where import AND, XOR, WhereNode
@@ -46,10 +47,29 @@ def builtin_lookup_idx(self, compiler, connection):
4647

4748
def get_pymongo_index_model(self, model, schema_editor, field=None, unique=False, column_prefix=""):
4849
"""Return a pymongo IndexModel for this Django Index."""
50+
filter_expression = defaultdict(dict)
51+
expressions_fields = []
4952
if self.contains_expressions:
50-
return None
53+
query = Query(model=model, alias_cols=False)
54+
compiler = query.get_compiler(connection=schema_editor.connection)
55+
for expression in self.expressions:
56+
field_ = expression.resolve_expression(query)
57+
column = field_.as_mql(compiler, schema_editor.connection)
58+
db_type = (
59+
field_.expression.db_type(schema_editor.connection)
60+
if isinstance(field_, OrderBy)
61+
else field_.output_field.db_type(schema_editor.connection)
62+
)
63+
if unique:
64+
filter_expression[column].update({"$type": db_type})
65+
order = (
66+
DESCENDING
67+
if isinstance(expression, OrderBy) and expression.descending
68+
else ASCENDING
69+
)
70+
expressions_fields.append((column, order))
71+
5172
kwargs = {}
52-
filter_expression = defaultdict(dict)
5373
if self.condition:
5474
filter_expression.update(self._get_condition_mql(model, schema_editor))
5575
if unique:
@@ -80,7 +100,7 @@ def get_pymongo_index_model(self, model, schema_editor, field=None, unique=False
80100
for field_name, order in self.fields_orders
81101
]
82102
)
83-
return IndexModel(index_orders, name=self.name, **kwargs)
103+
return IndexModel(expressions_fields + index_orders, name=self.name, **kwargs)
84104

85105

86106
def where_node_idx(self, compiler, connection):

django_mongodb_backend/schema.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
44
from django.db.models import Index, UniqueConstraint
5+
from django.db.models.expressions import F, OrderBy
56
from pymongo.operations import SearchIndexModel
67

78
from django_mongodb_backend.indexes import SearchIndex
@@ -351,6 +352,35 @@ def _remove_field_index(self, model, field, column_prefix=""):
351352
)
352353
collection.drop_index(index_names[0])
353354

355+
def _check_expression_indexes_applicable(self, expressions):
356+
return all(
357+
isinstance(expression.expression if isinstance(expression, OrderBy) else expression, F)
358+
for expression in expressions
359+
)
360+
361+
def _unique_supported(
362+
self,
363+
condition=None,
364+
deferrable=None,
365+
include=None,
366+
expressions=None,
367+
nulls_distinct=None,
368+
):
369+
return (
370+
(not condition or self.connection.features.supports_partial_indexes)
371+
and (not deferrable or self.connection.features.supports_deferrable_unique_constraints)
372+
and (not include or self.connection.features.supports_covering_indexes)
373+
and (
374+
not expressions
375+
or self.connection.features.supports_expression_indexes
376+
or self._check_expression_indexes_applicable(expressions)
377+
)
378+
and (
379+
nulls_distinct is None
380+
or self.connection.features.supports_nulls_distinct_unique_constraints
381+
)
382+
)
383+
354384
@ignore_embedded_models
355385
def add_constraint(self, model, constraint, field=None, column_prefix="", parent_model=None):
356386
if isinstance(constraint, UniqueConstraint) and self._unique_supported(
@@ -361,6 +391,7 @@ def add_constraint(self, model, constraint, field=None, column_prefix="", parent
361391
nulls_distinct=constraint.nulls_distinct,
362392
):
363393
idx = Index(
394+
*constraint.expressions,
364395
fields=constraint.fields,
365396
name=constraint.name,
366397
condition=constraint.condition,
@@ -391,6 +422,7 @@ def remove_constraint(self, model, constraint):
391422
nulls_distinct=constraint.nulls_distinct,
392423
):
393424
idx = Index(
425+
*constraint.expressions,
394426
fields=constraint.fields,
395427
name=constraint.name,
396428
condition=constraint.condition,

tests/schema_/test_embedded_model.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import itertools
22

33
from django.db import connection, models
4+
from django.db.models.expressions import F
45
from django.test import TransactionTestCase, skipUnlessDBFeature
56
from django.test.utils import isolate_apps
67

@@ -519,6 +520,167 @@ class Meta:
519520
self.assertTableNotExists(Author)
520521

521522

523+
class EmbeddedModelsTopLevelIndexTest(TestMixin, TransactionTestCase):
524+
@isolate_apps("schema_")
525+
def test_unique_together(self):
526+
"""Meta.unique_together defined at the top-level for embedded fields."""
527+
528+
class Address(EmbeddedModel):
529+
unique_together_one = models.CharField(max_length=10)
530+
unique_together_two = models.CharField(max_length=10)
531+
532+
class Meta:
533+
app_label = "schema_"
534+
535+
class Author(EmbeddedModel):
536+
address = EmbeddedModelField(Address)
537+
unique_together_three = models.CharField(max_length=10)
538+
unique_together_four = models.CharField(max_length=10)
539+
540+
class Meta:
541+
app_label = "schema_"
542+
543+
class Book(models.Model):
544+
author = EmbeddedModelField(Author)
545+
546+
class Meta:
547+
app_label = "schema_"
548+
constraints = [
549+
models.UniqueConstraint(
550+
F("author__unique_together_three").asc(),
551+
F("author__unique_together_four").desc(),
552+
name="unique_together_34",
553+
),
554+
(
555+
models.UniqueConstraint(
556+
F("author__address__unique_together_one"),
557+
F("author__address__unique_together_two").asc(),
558+
name="unique_together_12",
559+
)
560+
),
561+
]
562+
563+
with connection.schema_editor() as editor:
564+
editor.create_model(Book)
565+
self.assertTableExists(Book)
566+
# Embedded uniques are created from top-level definition.
567+
self.assertEqual(
568+
self.get_constraints_for_columns(
569+
Book, ["author.unique_together_three", "author.unique_together_four"]
570+
),
571+
["unique_together_34"],
572+
)
573+
self.assertEqual(
574+
self.get_constraints_for_columns(
575+
Book,
576+
["author.address.unique_together_one", "author.address.unique_together_two"],
577+
),
578+
["unique_together_12"],
579+
)
580+
editor.delete_model(Book)
581+
self.assertTableNotExists(Book)
582+
583+
@isolate_apps("schema_")
584+
def test_add_remove_field_indexes(self):
585+
"""AddField/RemoveField + EmbeddedModelField + Meta.indexes at top-level."""
586+
587+
class Address(EmbeddedModel):
588+
indexed_one = models.CharField(max_length=10)
589+
590+
class Meta:
591+
app_label = "schema_"
592+
593+
class Author(EmbeddedModel):
594+
address = EmbeddedModelField(Address)
595+
indexed_two = models.CharField(max_length=10)
596+
597+
class Meta:
598+
app_label = "schema_"
599+
600+
class Book(models.Model):
601+
author = EmbeddedModelField(Author)
602+
603+
class Meta:
604+
app_label = "schema_"
605+
indexes = [
606+
models.Index(F("author__indexed_two").asc(), name="indexed_two"),
607+
models.Index(F("author__address__indexed_one").asc(), name="indexed_one"),
608+
]
609+
610+
new_field = EmbeddedModelField(Author)
611+
new_field.set_attributes_from_name("author")
612+
613+
with connection.schema_editor() as editor:
614+
# Create the table and add the field.
615+
editor.create_model(Book)
616+
editor.add_field(Book, new_field)
617+
# Embedded indexes are created.
618+
self.assertEqual(
619+
self.get_constraints_for_columns(Book, ["author.indexed_two"]),
620+
["indexed_two"],
621+
)
622+
self.assertEqual(
623+
self.get_constraints_for_columns(
624+
Book,
625+
["author.address.indexed_one"],
626+
),
627+
["indexed_one"],
628+
)
629+
editor.delete_model(Book)
630+
self.assertTableNotExists(Book)
631+
632+
@isolate_apps("schema_")
633+
def test_add_remove_field_constraints(self):
634+
"""AddField/RemoveField + EmbeddedModelField + Meta.constraints at top-level."""
635+
636+
class Address(EmbeddedModel):
637+
unique_constraint_one = models.CharField(max_length=10)
638+
639+
class Meta:
640+
app_label = "schema_"
641+
642+
class Author(EmbeddedModel):
643+
address = EmbeddedModelField(Address)
644+
unique_constraint_two = models.CharField(max_length=10)
645+
646+
class Meta:
647+
app_label = "schema_"
648+
649+
class Book(models.Model):
650+
author = EmbeddedModelField(Author)
651+
652+
class Meta:
653+
app_label = "schema_"
654+
constraints = [
655+
models.UniqueConstraint(F("author__unique_constraint_two"), name="unique_two"),
656+
models.UniqueConstraint(
657+
F("author__address__unique_constraint_one"), name="unique_one"
658+
),
659+
]
660+
661+
new_field = EmbeddedModelField(Author)
662+
new_field.set_attributes_from_name("author")
663+
664+
with connection.schema_editor() as editor:
665+
# Create the table and add the field.
666+
editor.create_model(Book)
667+
editor.add_field(Book, new_field)
668+
# Embedded constraints are created.
669+
self.assertEqual(
670+
self.get_constraints_for_columns(Book, ["author.unique_constraint_two"]),
671+
["unique_two"],
672+
)
673+
self.assertEqual(
674+
self.get_constraints_for_columns(
675+
Book,
676+
["author.address.unique_constraint_one"],
677+
),
678+
["unique_one"],
679+
)
680+
editor.delete_model(Book)
681+
self.assertTableNotExists(Book)
682+
683+
522684
class EmbeddedModelsIgnoredTests(TestMixin, TransactionTestCase):
523685
def test_embedded_not_created(self):
524686
"""create_model() and delete_model() ignore EmbeddedModel."""

0 commit comments

Comments
 (0)