diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index 60400d940..bddcbd1e7 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -25,6 +25,7 @@ Value, When, ) +from django.db.models.indexes import IndexExpression from django.db.models.sql import Query from django_mongodb_backend.query_utils import process_lhs @@ -106,6 +107,19 @@ def expression_wrapper(self, compiler, connection): return self.expression.as_mql(compiler, connection, as_expr=True) +def index_expression(self, compiler, connection, as_expr=False): # noqa: ARG001 + result = [] + for expr in self.get_source_expressions(): + if expr is None: + continue + for sub_expr in expr.get_source_expressions(): + try: + result.append(sub_expr.as_mql(compiler, connection)) + except FullResultSet: + result.append(Value(True).as_mql(compiler, connection)) + return result + + def negated_expression(self, compiler, connection): return {"$not": expression_wrapper(self, compiler, connection)} @@ -244,6 +258,7 @@ def register_expressions(): Exists.as_mql_expr = exists ExpressionList.as_mql = process_lhs ExpressionWrapper.as_mql_expr = expression_wrapper + IndexExpression.as_mql = index_expression NegatedExpression.as_mql_expr = negated_expression OrderBy.as_mql_expr = partialmethod(order_by, as_expr=True) OrderBy.as_mql_path = partialmethod(order_by, as_expr=False) diff --git a/django_mongodb_backend/features.py b/django_mongodb_backend/features.py index 18a048bf6..12087cf7c 100644 --- a/django_mongodb_backend/features.py +++ b/django_mongodb_backend/features.py @@ -99,6 +99,9 @@ class DatabaseFeatures(GISFeatures, BaseDatabaseFeatures): "model_fields.test_jsonfield.TestSaveLoad.test_bulk_update_custom_get_prep_value", # To debug: https://github.com/mongodb/django-mongodb-backend/issues/362 "constraints.tests.UniqueConstraintTests.test_validate_case_when", + # Simple expression index are supported + "schema.tests.SchemaTests.test_func_unique_constraint_unsupported", + "schema.tests.SchemaTests.test_func_index_unsupported", } # $bitAnd, #bitOr, and $bitXor are new in MongoDB 6.3. _django_test_expected_failures_bitwise = { diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py index 8c712596a..53a22b157 100644 --- a/django_mongodb_backend/fields/embedded_model.py +++ b/django_mongodb_backend/fields/embedded_model.py @@ -211,6 +211,9 @@ def as_mql_path(self, compiler, connection): def output_field(self): return self._field + def db_type(self, connection): + return self.output_field.db_type(connection) + @property def can_use_path(self): return self.is_simple_column diff --git a/django_mongodb_backend/indexes.py b/django_mongodb_backend/indexes.py index 806d2c70a..9b79c6dde 100644 --- a/django_mongodb_backend/indexes.py +++ b/django_mongodb_backend/indexes.py @@ -4,6 +4,7 @@ from django.core.checks import Error, Warning from django.db import NotSupportedError from django.db.models import FloatField, Index, IntegerField +from django.db.models.expressions import OrderBy from django.db.models.lookups import BuiltinLookup from django.db.models.sql.query import Query from django.db.models.sql.where import AND, XOR, WhereNode @@ -46,10 +47,29 @@ def builtin_lookup_idx(self, compiler, connection): def get_pymongo_index_model(self, model, schema_editor, field=None, unique=False, column_prefix=""): """Return a pymongo IndexModel for this Django Index.""" + filter_expression = defaultdict(dict) + expressions_fields = [] if self.contains_expressions: - return None + query = Query(model=model, alias_cols=False) + compiler = query.get_compiler(connection=schema_editor.connection) + for expression in self.expressions: + field_ = expression.resolve_expression(query) + column = field_.as_mql(compiler, schema_editor.connection) + db_type = ( + field_.expression.db_type(schema_editor.connection) + if isinstance(field_, OrderBy) + else field_.output_field.db_type(schema_editor.connection) + ) + if unique: + filter_expression[column].update({"$type": db_type}) + order = ( + DESCENDING + if isinstance(expression, OrderBy) and expression.descending + else ASCENDING + ) + expressions_fields.append((column, order)) + kwargs = {} - filter_expression = defaultdict(dict) if self.condition: filter_expression.update(self._get_condition_mql(model, schema_editor)) if unique: @@ -80,7 +100,7 @@ def get_pymongo_index_model(self, model, schema_editor, field=None, unique=False for field_name, order in self.fields_orders ] ) - return IndexModel(index_orders, name=self.name, **kwargs) + return IndexModel(expressions_fields + index_orders, name=self.name, **kwargs) def where_node_idx(self, compiler, connection): diff --git a/django_mongodb_backend/schema.py b/django_mongodb_backend/schema.py index 9bcaecc63..7abcb0f7a 100644 --- a/django_mongodb_backend/schema.py +++ b/django_mongodb_backend/schema.py @@ -2,6 +2,7 @@ from django.db.backends.base.schema import BaseDatabaseSchemaEditor from django.db.models import Index, UniqueConstraint +from django.db.models.expressions import F, OrderBy from pymongo.operations import SearchIndexModel from django_mongodb_backend.indexes import SearchIndex @@ -351,6 +352,35 @@ def _remove_field_index(self, model, field, column_prefix=""): ) collection.drop_index(index_names[0]) + def _check_expression_indexes_applicable(self, expressions): + return all( + isinstance(expression.expression if isinstance(expression, OrderBy) else expression, F) + for expression in expressions + ) + + def _unique_supported( + self, + condition=None, + deferrable=None, + include=None, + expressions=None, + nulls_distinct=None, + ): + return ( + (not condition or self.connection.features.supports_partial_indexes) + and (not deferrable or self.connection.features.supports_deferrable_unique_constraints) + and (not include or self.connection.features.supports_covering_indexes) + and ( + not expressions + or self.connection.features.supports_expression_indexes + or self._check_expression_indexes_applicable(expressions) + ) + and ( + nulls_distinct is None + or self.connection.features.supports_nulls_distinct_unique_constraints + ) + ) + @ignore_embedded_models def add_constraint(self, model, constraint, field=None, column_prefix="", parent_model=None): if isinstance(constraint, UniqueConstraint) and self._unique_supported( @@ -361,6 +391,7 @@ def add_constraint(self, model, constraint, field=None, column_prefix="", parent nulls_distinct=constraint.nulls_distinct, ): idx = Index( + *constraint.expressions, fields=constraint.fields, name=constraint.name, condition=constraint.condition, @@ -391,6 +422,7 @@ def remove_constraint(self, model, constraint): nulls_distinct=constraint.nulls_distinct, ): idx = Index( + *constraint.expressions, fields=constraint.fields, name=constraint.name, condition=constraint.condition, diff --git a/docs/releases/5.2.x.rst b/docs/releases/5.2.x.rst index bed6a6cac..f02a79c9c 100644 --- a/docs/releases/5.2.x.rst +++ b/docs/releases/5.2.x.rst @@ -10,7 +10,11 @@ Django MongoDB Backend 5.2.x New features ------------ -- ... +- Added support for creating indexes from expressions. + Currently, only ``F()`` expressions are supported to reference top-level + model fields inside embedded models. + + Bug fixes --------- diff --git a/docs/topics/index.rst b/docs/topics/index.rst index 6e06b8125..a8b796a0f 100644 --- a/docs/topics/index.rst +++ b/docs/topics/index.rst @@ -11,3 +11,4 @@ know: embedded-models transactions known-issues + indexes diff --git a/docs/topics/indexes.rst b/docs/topics/indexes.rst new file mode 100644 index 000000000..d297039ca --- /dev/null +++ b/docs/topics/indexes.rst @@ -0,0 +1,22 @@ +Indexes from Expressions +======================== + +Django MongoDB Backend now supports creating indexes from expressions. +Currently, only ``F()`` expressions are supported, which allows referencing +fields from the top-level model inside embedded fields. + +Example:: + + from django.db import models + from django.db.models import F + + class Author(models.EmbeddedModel): + name = models.CharField() + + class Book(models.Model): + author = models.EmbeddedField(Author) + + class Meta: + indexes = [ + models.Index(F("author__name")), + ] diff --git a/tests/schema_/test_embedded_model.py b/tests/schema_/test_embedded_model.py index c6c926031..5497b5bef 100644 --- a/tests/schema_/test_embedded_model.py +++ b/tests/schema_/test_embedded_model.py @@ -1,6 +1,7 @@ import itertools from django.db import connection, models +from django.db.models.expressions import F from django.test import TransactionTestCase, skipUnlessDBFeature from django.test.utils import isolate_apps @@ -519,6 +520,167 @@ class Meta: self.assertTableNotExists(Author) +class EmbeddedModelsTopLevelIndexTest(TestMixin, TransactionTestCase): + @isolate_apps("schema_") + def test_unique_together(self): + """Meta.unique_together defined at the top-level for embedded fields.""" + + class Address(EmbeddedModel): + unique_together_one = models.CharField(max_length=10) + unique_together_two = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + + class Author(EmbeddedModel): + address = EmbeddedModelField(Address) + unique_together_three = models.CharField(max_length=10) + unique_together_four = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + + class Book(models.Model): + author = EmbeddedModelField(Author) + + class Meta: + app_label = "schema_" + constraints = [ + models.UniqueConstraint( + F("author__unique_together_three").asc(), + F("author__unique_together_four").desc(), + name="unique_together_34", + ), + ( + models.UniqueConstraint( + F("author__address__unique_together_one"), + F("author__address__unique_together_two").asc(), + name="unique_together_12", + ) + ), + ] + + with connection.schema_editor() as editor: + editor.create_model(Book) + self.assertTableExists(Book) + # Embedded uniques are created from top-level definition. + self.assertEqual( + self.get_constraints_for_columns( + Book, ["author.unique_together_three", "author.unique_together_four"] + ), + ["unique_together_34"], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.address.unique_together_one", "author.address.unique_together_two"], + ), + ["unique_together_12"], + ) + editor.delete_model(Book) + self.assertTableNotExists(Book) + + @isolate_apps("schema_") + def test_add_remove_field_indexes(self): + """AddField/RemoveField + EmbeddedModelField + Meta.indexes at top-level.""" + + class Address(EmbeddedModel): + indexed_one = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + + class Author(EmbeddedModel): + address = EmbeddedModelField(Address) + indexed_two = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + + class Book(models.Model): + author = EmbeddedModelField(Author) + + class Meta: + app_label = "schema_" + indexes = [ + models.Index(F("author__indexed_two").asc(), name="indexed_two"), + models.Index(F("author__address__indexed_one").asc(), name="indexed_one"), + ] + + new_field = EmbeddedModelField(Author) + new_field.set_attributes_from_name("author") + + with connection.schema_editor() as editor: + # Create the table and add the field. + editor.create_model(Book) + editor.add_field(Book, new_field) + # Embedded indexes are created. + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.indexed_two"]), + ["indexed_two"], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.address.indexed_one"], + ), + ["indexed_one"], + ) + editor.delete_model(Book) + self.assertTableNotExists(Book) + + @isolate_apps("schema_") + def test_add_remove_field_constraints(self): + """AddField/RemoveField + EmbeddedModelField + Meta.constraints at top-level.""" + + class Address(EmbeddedModel): + unique_constraint_one = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + + class Author(EmbeddedModel): + address = EmbeddedModelField(Address) + unique_constraint_two = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + + class Book(models.Model): + author = EmbeddedModelField(Author) + + class Meta: + app_label = "schema_" + constraints = [ + models.UniqueConstraint(F("author__unique_constraint_two"), name="unique_two"), + models.UniqueConstraint( + F("author__address__unique_constraint_one"), name="unique_one" + ), + ] + + new_field = EmbeddedModelField(Author) + new_field.set_attributes_from_name("author") + + with connection.schema_editor() as editor: + # Create the table and add the field. + editor.create_model(Book) + editor.add_field(Book, new_field) + # Embedded constraints are created. + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.unique_constraint_two"]), + ["unique_two"], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.address.unique_constraint_one"], + ), + ["unique_one"], + ) + editor.delete_model(Book) + self.assertTableNotExists(Book) + + class EmbeddedModelsIgnoredTests(TestMixin, TransactionTestCase): def test_embedded_not_created(self): """create_model() and delete_model() ignore EmbeddedModel."""