Skip to content

Commit 2f184a5

Browse files
authored
Add Index.get_pymongo_index_model() hook
1 parent 64b1c10 commit 2f184a5

File tree

2 files changed

+49
-40
lines changed

2 files changed

+49
-40
lines changed

django_mongodb_backend/indexes.py

+44
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
from collections import defaultdict
2+
13
from django.db import NotSupportedError
24
from django.db.models import Index
35
from django.db.models.lookups import BuiltinLookup
46
from django.db.models.sql.query import Query
57
from django.db.models.sql.where import AND, XOR, WhereNode
8+
from pymongo import ASCENDING, DESCENDING
9+
from pymongo.operations import IndexModel
610

711
from .query_utils import process_rhs
812

@@ -36,6 +40,45 @@ def builtin_lookup_idx(self, compiler, connection):
3640
return {lhs_mql: {operator: value}}
3741

3842

43+
def get_pymongo_index_model(self, model, schema_editor, field=None, unique=False, column_prefix=""):
44+
"""Return a pymongo IndexModel for this Django Index."""
45+
if self.contains_expressions:
46+
return None
47+
kwargs = {}
48+
filter_expression = defaultdict(dict)
49+
if self.condition:
50+
filter_expression.update(self._get_condition_mql(model, schema_editor))
51+
if unique:
52+
kwargs["unique"] = True
53+
# Indexing on $type matches the value of most SQL databases by
54+
# allowing multiple null values for the unique constraint.
55+
if field:
56+
column = column_prefix + field.column
57+
filter_expression[column].update({"$type": field.db_type(schema_editor.connection)})
58+
else:
59+
for field_name, _ in self.fields_orders:
60+
field_ = model._meta.get_field(field_name)
61+
filter_expression[field_.column].update(
62+
{"$type": field_.db_type(schema_editor.connection)}
63+
)
64+
if filter_expression:
65+
kwargs["partialFilterExpression"] = filter_expression
66+
index_orders = (
67+
[(column_prefix + field.column, ASCENDING)]
68+
if field
69+
else [
70+
# order is "" if ASCENDING or "DESC" if DESCENDING (see
71+
# django.db.models.indexes.Index.fields_orders).
72+
(
73+
column_prefix + model._meta.get_field(field_name).column,
74+
ASCENDING if order == "" else DESCENDING,
75+
)
76+
for field_name, order in self.fields_orders
77+
]
78+
)
79+
return IndexModel(index_orders, name=self.name, **kwargs)
80+
81+
3982
def where_node_idx(self, compiler, connection):
4083
if self.connector == AND:
4184
operator = "$and"
@@ -61,4 +104,5 @@ def where_node_idx(self, compiler, connection):
61104
def register_indexes():
62105
BuiltinLookup.as_mql_idx = builtin_lookup_idx
63106
Index._get_condition_mql = _get_condition_mql
107+
Index.get_pymongo_index_model = get_pymongo_index_model
64108
WhereNode.as_mql_idx = where_node_idx

django_mongodb_backend/schema.py

+5-40
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
1-
from collections import defaultdict
2-
31
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
42
from django.db.models import Index, UniqueConstraint
5-
from pymongo import ASCENDING, DESCENDING
6-
from pymongo.operations import IndexModel
73

84
from .fields import EmbeddedModelField
95
from .query import wrap_database_errors
@@ -264,43 +260,12 @@ def alter_unique_together(
264260
def add_index(
265261
self, model, index, *, field=None, unique=False, column_prefix="", parent_model=None
266262
):
267-
if index.contains_expressions:
268-
return
269-
kwargs = {}
270-
filter_expression = defaultdict(dict)
271-
if index.condition:
272-
filter_expression.update(index._get_condition_mql(model, self))
273-
if unique:
274-
kwargs["unique"] = True
275-
# Indexing on $type matches the value of most SQL databases by
276-
# allowing multiple null values for the unique constraint.
277-
if field:
278-
column = column_prefix + field.column
279-
filter_expression[column].update({"$type": field.db_type(self.connection)})
280-
else:
281-
for field_name, _ in index.fields_orders:
282-
field_ = model._meta.get_field(field_name)
283-
filter_expression[field_.column].update(
284-
{"$type": field_.db_type(self.connection)}
285-
)
286-
if filter_expression:
287-
kwargs["partialFilterExpression"] = filter_expression
288-
index_orders = (
289-
[(column_prefix + field.column, ASCENDING)]
290-
if field
291-
else [
292-
# order is "" if ASCENDING or "DESC" if DESCENDING (see
293-
# django.db.models.indexes.Index.fields_orders).
294-
(
295-
column_prefix + model._meta.get_field(field_name).column,
296-
ASCENDING if order == "" else DESCENDING,
297-
)
298-
for field_name, order in index.fields_orders
299-
]
263+
idx = index.get_pymongo_index_model(
264+
model, schema_editor=self, field=field, unique=unique, column_prefix=column_prefix
300265
)
301-
idx = IndexModel(index_orders, name=index.name, **kwargs)
302-
model = parent_model or model
303-
self.get_collection(model._meta.db_table).create_indexes([idx])
266+
if idx:
267+
model = parent_model or model
268+
self.get_collection(model._meta.db_table).create_indexes([idx])
304269

305270
def _add_composed_index(self, model, field_names, column_prefix="", parent_model=None):
306271
"""Add an index on the given list of field_names."""

0 commit comments

Comments
 (0)