diff --git a/django_mongodb_backend/indexes.py b/django_mongodb_backend/indexes.py index 567cd407..8ddceda7 100644 --- a/django_mongodb_backend/indexes.py +++ b/django_mongodb_backend/indexes.py @@ -1,8 +1,12 @@ +from collections import defaultdict + from django.db import NotSupportedError from django.db.models import Index from django.db.models.lookups import BuiltinLookup from django.db.models.sql.query import Query from django.db.models.sql.where import AND, XOR, WhereNode +from pymongo import ASCENDING, DESCENDING +from pymongo.operations import IndexModel from .query_utils import process_rhs @@ -36,6 +40,45 @@ def builtin_lookup_idx(self, compiler, connection): return {lhs_mql: {operator: value}} +def get_pymongo_index_model(self, model, schema_editor, field=None, unique=False, column_prefix=""): + """Return a pymongo IndexModel for this Django Index.""" + if self.contains_expressions: + return None + kwargs = {} + filter_expression = defaultdict(dict) + if self.condition: + filter_expression.update(self._get_condition_mql(model, schema_editor)) + if unique: + kwargs["unique"] = True + # Indexing on $type matches the value of most SQL databases by + # allowing multiple null values for the unique constraint. + if field: + column = column_prefix + field.column + filter_expression[column].update({"$type": field.db_type(schema_editor.connection)}) + else: + for field_name, _ in self.fields_orders: + field_ = model._meta.get_field(field_name) + filter_expression[field_.column].update( + {"$type": field_.db_type(schema_editor.connection)} + ) + if filter_expression: + kwargs["partialFilterExpression"] = filter_expression + index_orders = ( + [(column_prefix + field.column, ASCENDING)] + if field + else [ + # order is "" if ASCENDING or "DESC" if DESCENDING (see + # django.db.models.indexes.Index.fields_orders). + ( + column_prefix + model._meta.get_field(field_name).column, + ASCENDING if order == "" else DESCENDING, + ) + for field_name, order in self.fields_orders + ] + ) + return IndexModel(index_orders, name=self.name, **kwargs) + + def where_node_idx(self, compiler, connection): if self.connector == AND: operator = "$and" @@ -61,4 +104,5 @@ def where_node_idx(self, compiler, connection): def register_indexes(): BuiltinLookup.as_mql_idx = builtin_lookup_idx Index._get_condition_mql = _get_condition_mql + Index.get_pymongo_index_model = get_pymongo_index_model WhereNode.as_mql_idx = where_node_idx diff --git a/django_mongodb_backend/schema.py b/django_mongodb_backend/schema.py index cea89ae2..8ae60918 100644 --- a/django_mongodb_backend/schema.py +++ b/django_mongodb_backend/schema.py @@ -1,9 +1,5 @@ -from collections import defaultdict - from django.db.backends.base.schema import BaseDatabaseSchemaEditor from django.db.models import Index, UniqueConstraint -from pymongo import ASCENDING, DESCENDING -from pymongo.operations import IndexModel from .fields import EmbeddedModelField from .query import wrap_database_errors @@ -264,43 +260,12 @@ def alter_unique_together( def add_index( self, model, index, *, field=None, unique=False, column_prefix="", parent_model=None ): - if index.contains_expressions: - return - kwargs = {} - filter_expression = defaultdict(dict) - if index.condition: - filter_expression.update(index._get_condition_mql(model, self)) - if unique: - kwargs["unique"] = True - # Indexing on $type matches the value of most SQL databases by - # allowing multiple null values for the unique constraint. - if field: - column = column_prefix + field.column - filter_expression[column].update({"$type": field.db_type(self.connection)}) - else: - for field_name, _ in index.fields_orders: - field_ = model._meta.get_field(field_name) - filter_expression[field_.column].update( - {"$type": field_.db_type(self.connection)} - ) - if filter_expression: - kwargs["partialFilterExpression"] = filter_expression - index_orders = ( - [(column_prefix + field.column, ASCENDING)] - if field - else [ - # order is "" if ASCENDING or "DESC" if DESCENDING (see - # django.db.models.indexes.Index.fields_orders). - ( - column_prefix + model._meta.get_field(field_name).column, - ASCENDING if order == "" else DESCENDING, - ) - for field_name, order in index.fields_orders - ] + idx = index.get_pymongo_index_model( + model, schema_editor=self, field=field, unique=unique, column_prefix=column_prefix ) - idx = IndexModel(index_orders, name=index.name, **kwargs) - model = parent_model or model - self.get_collection(model._meta.db_table).create_indexes([idx]) + if idx: + model = parent_model or model + self.get_collection(model._meta.db_table).create_indexes([idx]) def _add_composed_index(self, model, field_names, column_prefix="", parent_model=None): """Add an index on the given list of field_names."""