Add SearchIndex and VectorSearchIndex #264

Merged 2 commits on May 5, 2025
Changes from all commits
3 changes: 3 additions & 0 deletions .github/workflows/mongodb_settings.py
@@ -17,10 +17,13 @@
    "default": {
        "ENGINE": "django_mongodb_backend",
        "NAME": "djangotests",
        # Required when connecting to the Atlas image in Docker.
        "OPTIONS": {"directConnection": True},
    },
    "other": {
        "ENGINE": "django_mongodb_backend",
        "NAME": "djangotests-other",
        "OPTIONS": {"directConnection": True},
    },
}

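For context, the OPTIONS dict is forwarded to pymongo's MongoClient, so the directConnection setting above amounts to roughly the following sketch (illustration only, not part of the diff; the URL assumes the default port mapping used by the workflow below):

# Illustration only, not part of this PR.
from pymongo import MongoClient

# directConnection=True skips replica-set member discovery and talks to the
# single mongod exposed by the mongodb-atlas-local container.
client = MongoClient("mongodb://localhost:27017", directConnection=True)
client.admin.command("ping")  # raises if the server is unreachable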
24 changes: 24 additions & 0 deletions .github/workflows/start_local_atlas.sh
@@ -0,0 +1,24 @@
#!/bin/bash
set -eu

echo "Starting the container"

IMAGE=${1:-mongodb/mongodb-atlas-local:latest}
DOCKER=$(which docker || which podman)

$DOCKER pull $IMAGE

$DOCKER kill mongodb_atlas_local || true

CONTAINER_ID=$($DOCKER run --rm -d --name mongodb_atlas_local -p 27017:27017 $IMAGE)

function wait() {
  CONTAINER_ID=$1
  echo "waiting for container to become healthy..."
  $DOCKER logs mongodb_atlas_local
}

wait "$CONTAINER_ID"

# Sleep for a bit to let all services start.
sleep 5
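The script relies on a fixed five-second sleep after printing the container logs. A minimal sketch of a more deterministic readiness probe, written in Python for consistency with the rest of the test tooling (hypothetical helper, not part of this PR; it only confirms mongod is answering, and the Atlas Search process may need a little longer, which is what the sleep covers):

# readiness_probe.py -- hypothetical helper, not part of this PR.
import time

from pymongo import MongoClient
from pymongo.errors import PyMongoError


def wait_for_mongod(uri="mongodb://localhost:27017", timeout=60):
    """Poll the server with `ping` until it answers or the timeout expires."""
    client = MongoClient(uri, directConnection=True, serverSelectionTimeoutMS=1000)
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            client.admin.command("ping")
            return
        except PyMongoError:
            time.sleep(1)
    raise RuntimeError(f"mongod did not become ready within {timeout} seconds")


if __name__ == "__main__":
    wait_for_mongod()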
56 changes: 56 additions & 0 deletions .github/workflows/test-python-atlas.yml
@@ -0,0 +1,56 @@
name: Python Tests on Atlas

on:
  pull_request:
    paths:
      - '**.py'
      - '!setup.py'
      - '.github/workflows/test-python-atlas.yml'
  workflow_dispatch:

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

defaults:
  run:
    shell: bash -eux {0}

jobs:
  build:
    name: Django Test Suite
    runs-on: ubuntu-latest
    steps:
      - name: Checkout django-mongodb-backend
        uses: actions/checkout@v4
        with:
          persist-credentials: false
      - name: Install django-mongodb-backend
        run: |
          pip3 install --upgrade pip
          pip3 install -e .
      - name: Checkout Django
        uses: actions/checkout@v4
        with:
          repository: 'mongodb-forks/django'
          ref: 'mongodb-5.2.x'
          path: 'django_repo'
          persist-credentials: false
      - name: Install system packages for Django's Python test dependencies
        run: |
          sudo apt-get update
          sudo apt-get install libmemcached-dev
      - name: Install Django and its Python test dependencies
        run: |
          cd django_repo/tests/
          pip3 install -e ..
          pip3 install -r requirements/py3.txt
      - name: Copy the test settings file
        run: cp .github/workflows/mongodb_settings.py django_repo/tests/
      - name: Copy the test runner file
        run: cp .github/workflows/runtests.py django_repo/tests/runtests_.py
      - name: Start local Atlas
        working-directory: .
        run: bash .github/workflows/start_local_atlas.sh mongodb/mongodb-atlas-local:7
      - name: Run tests
        run: python3 django_repo/tests/runtests_.py
2 changes: 2 additions & 0 deletions django_mongodb_backend/__init__.py
@@ -7,6 +7,7 @@
check_django_compatability()

from .aggregates import register_aggregates # noqa: E402
from .checks import register_checks # noqa: E402
from .expressions import register_expressions # noqa: E402
from .fields import register_fields # noqa: E402
from .functions import register_functions # noqa: E402
@@ -17,6 +18,7 @@
__all__ = ["parse_uri"]

register_aggregates()
register_checks()
register_expressions()
register_fields()
register_functions()
32 changes: 32 additions & 0 deletions django_mongodb_backend/checks.py
@@ -0,0 +1,32 @@
from itertools import chain

from django.apps import apps
from django.core.checks import Tags, register
from django.db import connections, router


def check_indexes(app_configs, databases=None, **kwargs):  # noqa: ARG001
    """
    Call Index.check() on all model indexes.

    This function will be obsolete when Django calls Index.check() after
    https://code.djangoproject.com/ticket/36273.
    """
    errors = []
    if app_configs is None:
        models = apps.get_models()
    else:
        models = chain.from_iterable(app_config.get_models() for app_config in app_configs)
    for model in models:
        for db in databases or ():
            if not router.allow_migrate_model(db, model):
                continue
            connection = connections[db]
            for model_index in model._meta.indexes:
                if hasattr(model_index, "check"):
                    errors.extend(model_index.check(model, connection))
    return errors


def register_checks():
    register(check_indexes, Tags.models)
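To see how the registered check surfaces, here is a sketch with a hypothetical model (not from this PR); the check only runs against databases passed to the check framework, e.g. `python manage.py check --database default`:

# Hypothetical example, not part of this PR.
from django.db import models

from django_mongodb_backend.indexes import SearchIndex


class Article(models.Model):
    headline = models.CharField(max_length=200)

    class Meta:
        indexes = [SearchIndex(fields=["headline"])]


# With `manage.py check --database default`, check_indexes() calls
# SearchIndex.check() for each routed database; on a server without Atlas
# Search support this reports django_mongodb_backend.indexes.SearchIndex.W001.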
17 changes: 17 additions & 0 deletions django_mongodb_backend/features.py
@@ -1,5 +1,6 @@
from django.db.backends.base.features import BaseDatabaseFeatures
from django.utils.functional import cached_property
from pymongo.errors import OperationFailure


class DatabaseFeatures(BaseDatabaseFeatures):
@@ -548,3 +549,19 @@ def django_test_expected_failures(self):
    @cached_property
    def is_mongodb_6_3(self):
        return self.connection.get_database_version() >= (6, 3)

    @cached_property
    def supports_atlas_search(self):
        """Does the server support Atlas search queries and search indexes?"""
        try:
            # An existing collection must be used on MongoDB 6, otherwise
            # the operation will not error when unsupported.
            self.connection.get_collection("django_migrations").list_search_indexes()
        except OperationFailure:
            # It would be best to check the error message or error code to
            # avoid hiding some other exception, but the message/code varies
            # across MongoDB versions. Example error message:
            # "$listSearchIndexes stage is only allowed on MongoDB Atlas".
            return False
        else:
            return True
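A sketch of how such a database feature flag is typically consumed in tests, using Django's skipUnlessDBFeature decorator (hypothetical test, not from this PR):

# Hypothetical example, not part of this PR.
from django.test import TestCase, skipUnlessDBFeature


@skipUnlessDBFeature("supports_atlas_search")
class SearchIndexTests(TestCase):
    def test_search_index_created(self):
        # Index-creation assertions would go here; on plain mongod the whole
        # class is skipped rather than failing with an OperationFailure.
        pass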
3 changes: 3 additions & 0 deletions django_mongodb_backend/fields/embedded_model.py
@@ -20,6 +20,9 @@ def __init__(self, embedded_model, *args, **kwargs):
        self.embedded_model = embedded_model
        super().__init__(*args, **kwargs)

    def db_type(self, connection):
        return "embeddedDocuments"

    def check(self, **kwargs):
        from ..models import EmbeddedModel

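The new db_type() return value feeds the search index field mappings defined further down in indexes.py. A sketch with a hypothetical embedded model (not from this PR):

# Hypothetical example, not part of this PR.
from django.db import models

from django_mongodb_backend.fields import EmbeddedModelField
from django_mongodb_backend.indexes import SearchIndex
from django_mongodb_backend.models import EmbeddedModel


class Review(EmbeddedModel):
    title = models.CharField(max_length=255)


class Movie(models.Model):
    reviews = EmbeddedModelField(Review)

    class Meta:
        indexes = [SearchIndex(fields=["reviews"])]


# SearchIndex.search_index_data_types() (below) passes "embeddedDocuments"
# through unchanged, so the generated mapping for the "reviews" field is
# {"type": "embeddedDocuments"}.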
183 changes: 181 additions & 2 deletions django_mongodb_backend/indexes.py
@@ -1,12 +1,16 @@
import itertools
from collections import defaultdict

from django.core.checks import Error, Warning
from django.db import NotSupportedError
from django.db.models import FloatField, Index, IntegerField
from django.db.models.lookups import BuiltinLookup
from django.db.models.sql.query import Query
from django.db.models.sql.where import AND, XOR, WhereNode
from pymongo import ASCENDING, DESCENDING
from pymongo.operations import IndexModel, SearchIndexModel

from django_mongodb_backend.fields import ArrayField

from .query_utils import process_rhs

@@ -101,6 +105,181 @@ def where_node_idx(self, compiler, connection):
    return mql


class SearchIndex(Index):
    suffix = "six"
    _error_id_prefix = "django_mongodb_backend.indexes.SearchIndex"

    def __init__(self, *, fields=(), name=None):
        super().__init__(fields=fields, name=name)

    def check(self, model, connection):
        errors = []
        if not connection.features.supports_atlas_search:
            errors.append(
                Warning(
                    f"This MongoDB server does not support {self.__class__.__name__}.",
                    hint=(
                        "The index won't be created. Use an Atlas-enabled version of MongoDB, "
                        "or silence this warning if you don't care about it."
                    ),
                    obj=model,
                    id=f"{self._error_id_prefix}.W001",
                )
            )
        return errors

    def search_index_data_types(self, db_type):
        """
        Map a model field's type to search index type.
        https://www.mongodb.com/docs/atlas/atlas-search/define-field-mappings/#data-types
        """
        if db_type in {"double", "int", "long"}:
            return "number"
        if db_type == "binData":
            return "string"
        if db_type == "bool":
            return "boolean"
        if db_type == "object":
            return "document"
        if db_type == "array":
            return "embeddedDocuments"
        return db_type

    def get_pymongo_index_model(
        self, model, schema_editor, field=None, unique=False, column_prefix=""
    ):
        if not schema_editor.connection.features.supports_atlas_search:
            return None
        fields = {}
        for field_name, _ in self.fields_orders:
            field = model._meta.get_field(field_name)
            type_ = self.search_index_data_types(field.db_type(schema_editor.connection))
            field_path = column_prefix + model._meta.get_field(field_name).column
            fields[field_path] = {"type": type_}
        return SearchIndexModel(
            definition={"mappings": {"dynamic": False, "fields": fields}}, name=self.name
        )


class VectorSearchIndex(SearchIndex):
    suffix = "vsi"
    _error_id_prefix = "django_mongodb_backend.indexes.VectorSearchIndex"
    VALID_FIELD_TYPES = frozenset(("boolean", "date", "number", "objectId", "string", "uuid"))
    VALID_SIMILARITIES = frozenset(("cosine", "dotProduct", "euclidean"))

    def __init__(self, *, fields=(), name=None, similarities):
        super().__init__(fields=fields, name=name)
        self.similarities = similarities
        self._multiple_similarities = isinstance(similarities, tuple | list)
        for func in similarities if self._multiple_similarities else (similarities,):
            if func not in self.VALID_SIMILARITIES:
                raise ValueError(
                    f"'{func}' isn't a valid similarity function "
                    f"({', '.join(sorted(self.VALID_SIMILARITIES))})."
                )
        seen_fields = set()
        for field_name, _ in self.fields_orders:
            if field_name in seen_fields:
                raise ValueError(f"Field '{field_name}' is duplicated in fields.")
            seen_fields.add(field_name)

    def check(self, model, connection):
        errors = super().check(model, connection)
        num_arrayfields = 0
        for field_name, _ in self.fields_orders:
            field = model._meta.get_field(field_name)
            if isinstance(field, ArrayField):
                num_arrayfields += 1
                try:
                    int(field.size)
                except (ValueError, TypeError):
                    errors.append(
                        Error(
                            f"VectorSearchIndex requires 'size' on field '{field_name}'.",
                            obj=model,
                            id=f"{self._error_id_prefix}.E002",
                        )
                    )
                if not isinstance(field.base_field, FloatField | IntegerField):
                    errors.append(
                        Error(
                            "VectorSearchIndex requires the base field of "
                            f"ArrayField '{field.name}' to be FloatField or "
                            "IntegerField but is "
                            f"{field.base_field.get_internal_type()}.",
                            obj=model,
                            id=f"{self._error_id_prefix}.E003",
                        )
                    )
            else:
                search_type = self.search_index_data_types(field.db_type(connection))
                if search_type not in self.VALID_FIELD_TYPES:
                    errors.append(
                        Error(
                            "VectorSearchIndex does not support field "
                            f"'{field_name}' ({field.get_internal_type()}).",
                            obj=model,
                            id=f"{self._error_id_prefix}.E004",
                            hint=f"Allowed types are {', '.join(sorted(self.VALID_FIELD_TYPES))}.",
                        )
                    )
        if self._multiple_similarities and num_arrayfields != len(self.similarities):
            errors.append(
                Error(
                    f"VectorSearchIndex requires the same number of similarities "
                    f"and vector fields; {model._meta.object_name} has "
                    f"{num_arrayfields} ArrayField(s) but similarities "
                    f"has {len(self.similarities)} element(s).",
                    obj=model,
                    id=f"{self._error_id_prefix}.E005",
                )
            )
        if num_arrayfields == 0:
            errors.append(
                Error(
                    "VectorSearchIndex requires at least one ArrayField to store vector data.",
                    obj=model,
                    id=f"{self._error_id_prefix}.E006",
                    hint="If you want to perform search operations without vectors, "
                    "use SearchIndex instead.",
                )
            )
        return errors

    def deconstruct(self):
        path, args, kwargs = super().deconstruct()
        kwargs["similarities"] = self.similarities
        return path, args, kwargs

    def get_pymongo_index_model(
        self, model, schema_editor, field=None, unique=False, column_prefix=""
    ):
        if not schema_editor.connection.features.supports_atlas_search:
            return None
        similarities = (
            itertools.cycle([self.similarities])
            if not self._multiple_similarities
            else iter(self.similarities)
        )
        fields = []
        for field_name, _ in self.fields_orders:
            field_ = model._meta.get_field(field_name)
            field_path = column_prefix + model._meta.get_field(field_name).column
            mappings = {"path": field_path}
            if isinstance(field_, ArrayField):
                mappings.update(
                    {
                        "type": "vector",
                        "numDimensions": int(field_.size),
                        "similarity": next(similarities),
                    }
                )
            else:
                mappings["type"] = "filter"
            fields.append(mappings)
        return SearchIndexModel(definition={"fields": fields}, name=self.name, type="vectorSearch")


def register_indexes():
    BuiltinLookup.as_mql_idx = builtin_lookup_idx
    Index._get_condition_mql = _get_condition_mql
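Putting the new index classes together, a sketch of how a model might declare them (hypothetical model, field names, and vector size, not from this PR):

# Hypothetical example, not part of this PR.
from django.db import models

from django_mongodb_backend.fields import ArrayField
from django_mongodb_backend.indexes import SearchIndex, VectorSearchIndex


class Movie(models.Model):
    title = models.CharField(max_length=200)
    plot = models.TextField()
    plot_embedding = ArrayField(models.FloatField(), size=1536)

    class Meta:
        indexes = [
            # Static (dynamic=False) field mappings built by
            # SearchIndex.get_pymongo_index_model().
            SearchIndex(fields=["title", "plot"]),
            # The ArrayField becomes a "vector" field with numDimensions=1536
            # and cosine similarity; "title" becomes a "filter" field.
            VectorSearchIndex(
                name="movie_plot_embedding_vsi",
                fields=["plot_embedding", "title"],
                similarities="cosine",
            ),
        ]


# On servers where features.supports_atlas_search is False,
# get_pymongo_index_model() returns None and no search index is created;
# the system check from checks.py reports W001 instead.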