Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Index.get_pymongo_index_model() hook #272

Merged
merged 4 commits into from
Mar 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions django_mongodb_backend/indexes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from collections import defaultdict

from django.db import NotSupportedError
from django.db.models import Index
from django.db.models.lookups import BuiltinLookup
from django.db.models.sql.query import Query
from django.db.models.sql.where import AND, XOR, WhereNode
from pymongo import ASCENDING, DESCENDING
from pymongo.operations import IndexModel

from .query_utils import process_rhs

Expand Down Expand Up @@ -36,6 +40,45 @@ def builtin_lookup_idx(self, compiler, connection):
return {lhs_mql: {operator: value}}


def get_pymongo_index_model(self, model, schema_editor, field=None, unique=False, column_prefix=""):
"""Return a pymongo IndexModel for this Django Index."""
if self.contains_expressions:
return None
kwargs = {}
filter_expression = defaultdict(dict)
if self.condition:
filter_expression.update(self._get_condition_mql(model, schema_editor))
if unique:
kwargs["unique"] = True
# Indexing on $type matches the value of most SQL databases by
# allowing multiple null values for the unique constraint.
if field:
column = column_prefix + field.column
filter_expression[column].update({"$type": field.db_type(schema_editor.connection)})
else:
for field_name, _ in self.fields_orders:
field_ = model._meta.get_field(field_name)
filter_expression[field_.column].update(
{"$type": field_.db_type(schema_editor.connection)}
)
if filter_expression:
kwargs["partialFilterExpression"] = filter_expression
index_orders = (
[(column_prefix + field.column, ASCENDING)]
if field
else [
# order is "" if ASCENDING or "DESC" if DESCENDING (see
# django.db.models.indexes.Index.fields_orders).
(
column_prefix + model._meta.get_field(field_name).column,
ASCENDING if order == "" else DESCENDING,
)
for field_name, order in self.fields_orders
]
)
return IndexModel(index_orders, name=self.name, **kwargs)


def where_node_idx(self, compiler, connection):
if self.connector == AND:
operator = "$and"
Expand All @@ -61,4 +104,5 @@ def where_node_idx(self, compiler, connection):
def register_indexes():
BuiltinLookup.as_mql_idx = builtin_lookup_idx
Index._get_condition_mql = _get_condition_mql
Index.get_pymongo_index_model = get_pymongo_index_model
WhereNode.as_mql_idx = where_node_idx
45 changes: 5 additions & 40 deletions django_mongodb_backend/schema.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
from collections import defaultdict

from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.models import Index, UniqueConstraint
from pymongo import ASCENDING, DESCENDING
from pymongo.operations import IndexModel

from .fields import EmbeddedModelField
from .query import wrap_database_errors
Expand Down Expand Up @@ -264,43 +260,12 @@ def alter_unique_together(
def add_index(
self, model, index, *, field=None, unique=False, column_prefix="", parent_model=None
):
if index.contains_expressions:
return
kwargs = {}
filter_expression = defaultdict(dict)
if index.condition:
filter_expression.update(index._get_condition_mql(model, self))
if unique:
kwargs["unique"] = True
# Indexing on $type matches the value of most SQL databases by
# allowing multiple null values for the unique constraint.
if field:
column = column_prefix + field.column
filter_expression[column].update({"$type": field.db_type(self.connection)})
else:
for field_name, _ in index.fields_orders:
field_ = model._meta.get_field(field_name)
filter_expression[field_.column].update(
{"$type": field_.db_type(self.connection)}
)
if filter_expression:
kwargs["partialFilterExpression"] = filter_expression
index_orders = (
[(column_prefix + field.column, ASCENDING)]
if field
else [
# order is "" if ASCENDING or "DESC" if DESCENDING (see
# django.db.models.indexes.Index.fields_orders).
(
column_prefix + model._meta.get_field(field_name).column,
ASCENDING if order == "" else DESCENDING,
)
for field_name, order in index.fields_orders
]
idx = index.get_pymongo_index_model(
model, schema_editor=self, field=field, unique=unique, column_prefix=column_prefix
)
idx = IndexModel(index_orders, name=index.name, **kwargs)
model = parent_model or model
self.get_collection(model._meta.db_table).create_indexes([idx])
if idx:
model = parent_model or model
self.get_collection(model._meta.db_table).create_indexes([idx])

def _add_composed_index(self, model, field_names, column_prefix="", parent_model=None):
"""Add an index on the given list of field_names."""
Expand Down