Add SearchIndex and VectorSearchIndex #264

Open. Wants to merge 51 commits into base: main.

Commits
32a8f9c
Refactor index creation
WaVEV Feb 16, 2025
16bee55
stubed up method
WaVEV Feb 22, 2025
65c05ff
Atlas index creation
WaVEV Feb 24, 2025
c175fd6
Test
WaVEV Feb 24, 2025
a6cf7ec
Fix mapping typing issue
WaVEV Feb 25, 2025
7dfe00a
Refactor unit test
WaVEV Feb 26, 2025
114b77d
Add test with data
WaVEV Feb 26, 2025
37b1c76
Add atlas vector search index
WaVEV Mar 3, 2025
1cac435
Change minimun pymongo version
WaVEV Mar 4, 2025
3b80117
ObjectId Fields does not support integers
WaVEV Mar 4, 2025
6fee641
Refactor: remove singledispatchmethod
WaVEV Mar 4, 2025
b6ab85b
Add directConnection in testing settings
WaVEV Mar 7, 2025
8548fb0
Add lenght validator
WaVEV Mar 9, 2025
9b49ec1
Add unit test
WaVEV Mar 9, 2025
47cf5d0
Using fixed_size instead of size
WaVEV Mar 9, 2025
1ac46c2
Testing CI
WaVEV Mar 9, 2025
59d5ca0
[Testing CI] Set mongo image version tag
WaVEV Mar 9, 2025
a3a0cd6
Edits
WaVEV Mar 11, 2025
e3a745f
Testing CI
WaVEV Mar 11, 2025
4de2211
Remove test with data
WaVEV Mar 12, 2025
c17e934
Move mongo_data_types mapping
WaVEV Mar 12, 2025
bb7282e
Pumping mongo version to fix some unit test
WaVEV Mar 20, 2025
a26977d
Add validators test
WaVEV Mar 20, 2025
3771a36
Refactor fixed_size and size in array field
WaVEV Mar 21, 2025
ce94b7f
Add create_search_index to OperationDebugWrapper
WaVEV Mar 21, 2025
28ab176
change tuple for frozenset
WaVEV Mar 21, 2025
1ecaa90
Refactor
WaVEV Mar 21, 2025
1abcc96
Move import to the top
WaVEV Mar 21, 2025
78750ef
Add drop_search_index and list_search_indexes to OperationDebugWrapper
WaVEV Mar 22, 2025
6d8271c
rename size and fixed_size
WaVEV Mar 22, 2025
555d658
Docstring and search type mapping
WaVEV Mar 22, 2025
f8bed0d
Refactor if conditions
WaVEV Mar 22, 2025
e535f98
Add unit tests
WaVEV Mar 22, 2025
3a684d8
Add array field unit test
WaVEV Mar 22, 2025
9851732
Simplify with
WaVEV Mar 22, 2025
9ea7674
Remove blank line
WaVEV Mar 22, 2025
31b8bb1
Add deconstruct method
WaVEV Mar 24, 2025
459d4e0
Remove duplicate function
WaVEV Mar 26, 2025
22b0085
Update django_mongodb_backend/introspection.py
WaVEV Mar 27, 2025
288abce
Remove redundant tests
WaVEV Mar 27, 2025
a5580ba
Index prefix cannot be longer than 3 chars
WaVEV Mar 27, 2025
ab413ec
Check function in VectorSearchIndex.
WaVEV Mar 27, 2025
5271b4d
Add parameter connection to the vector search index check
WaVEV Mar 28, 2025
0a182fb
add system check and unit test
WaVEV Mar 28, 2025
1757203
Refactor imports
WaVEV Mar 28, 2025
c01fc08
Fix invalid similarity checker
WaVEV Mar 28, 2025
4696d09
Add similarities unit test
WaVEV Mar 28, 2025
15e3450
Remove check test, were moved to system check
WaVEV Mar 28, 2025
ddeb24c
Remove valitador get_pymongo_index_model.
WaVEV Apr 1, 2025
2989e4a
Add docstring
WaVEV Apr 1, 2025
de3b03b
remove validators unit test
WaVEV Apr 1, 2025
2 changes: 2 additions & 0 deletions .github/workflows/mongodb_settings.py
@@ -17,10 +17,12 @@
     "default": {
         "ENGINE": "django_mongodb_backend",
         "NAME": "djangotests",
+        "OPTIONS": {"directConnection": True},
Contributor

Let's add a comment explaining that this parameter is necessary when connecting to local Atlas instances running in Docker.

Collaborator Author

The Docker site doesn't provide much information about it; it just gives the connection URI format.
I think it is for connecting to a single MongoDB instance, as in standalone mode, but I need to confirm that.
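
For reference, PyMongo documents `directConnection=True` as forcing the client to connect directly to the host named in the URI rather than discovering the rest of the replica set. A minimal sketch of how the requested comment might read in the test settings follows. The comment wording is hypothetical, and the stated reason (the local Atlas container advertising a member hostname that the test runner cannot reach) is an assumption that this thread has not confirmed.

```python
# Hypothetical version of the CI settings with an explanatory comment.
# Assumption (unconfirmed in this thread): the mongodb-atlas-local container
# runs as a single-node replica set, so without directConnection the driver
# may try to reach the member hostname advertised inside the container.
DIRECT_CONNECTION = {"directConnection": True}

DATABASES = {
    "default": {
        "ENGINE": "django_mongodb_backend",
        "NAME": "djangotests",
        # Connect straight to the given host and skip replica set discovery.
        "OPTIONS": dict(DIRECT_CONNECTION),
    },
    "other": {
        "ENGINE": "django_mongodb_backend",
        "NAME": "djangotests-other",
        "OPTIONS": dict(DIRECT_CONNECTION),
    },
}
```

Each alias gets its own copy of the options dict so that changing one alias later does not silently affect the other.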

     },
     "other": {
         "ENGINE": "django_mongodb_backend",
         "NAME": "djangotests-other",
+        "OPTIONS": {"directConnection": True},
     },
 }

5 changes: 3 additions & 2 deletions .github/workflows/test-python.yml
@@ -50,8 +50,9 @@ jobs:
       - name: Copy the test runner file
         run: cp .github/workflows/runtests.py django_repo/tests/runtests_.py
       - name: Start MongoDB
-        uses: supercharge/mongodb-github-action@1.12.0
+        uses: wavev/mongodb-github-action@atlas
         with:
-          mongodb-version: 6.0
+          mongodb-image: mongodb/mongodb-atlas-local
+          mongodb-version: 7.0.15
       - name: Run tests
         run: python3 django_repo/tests/runtests_.py
27 changes: 27 additions & 0 deletions django_mongodb_backend/checks.py
@@ -0,0 +1,27 @@
+from itertools import chain
+
+from django.apps import apps
+from django.core.checks import Tags, register
+from django.db import connections, router
+
+from django_mongodb_backend.indexes import VectorSearchIndex
+
+
+@register(Tags.models)
+def check_vector_search_indexes(app_configs, databases=None, **kwargs):  # noqa: ARG001
+    # Validate vector search indexes for models.
+    errors = []
+    if app_configs is None:
+        models = apps.get_models()
+    else:
+        models = chain.from_iterable(app_config.get_models() for app_config in app_configs)
+    for model in models:
+        for db in databases or ():
+            if not router.allow_migrate_model(db, model):
+                continue
+            connection = connections[db]
+            for model_index in model._meta.indexes:
+                if not isinstance(model_index, VectorSearchIndex):
+                    continue
+                errors.extend(model_index.check(model, connection))
+    return errors
13 changes: 11 additions & 2 deletions django_mongodb_backend/fields/array.py
@@ -10,6 +10,7 @@
 from ..forms import SimpleArrayField
 from ..query_utils import process_lhs, process_rhs
 from ..utils import prefix_validation_error
+from .validators import LengthValidator

 __all__ = ["ArrayField"]

@@ -27,13 +28,21 @@ class ArrayField(CheckFieldDefaultMixin, Field):
     }
     _default_hint = ("list", "[]")

-    def __init__(self, base_field, size=None, **kwargs):
+    def __init__(self, base_field, max_size=None, size=None, **kwargs):
Contributor

If this is meant only for vector search, can we make this "fixed_size=True" rather than max_size? The validator would then be chosen with a ternary between ArrayMaxLengthValidator and LengthValidator. I think that would make the difference easier to intuit.

Collaborator Author

The array field is not only for vector search; it is general purpose. It can hold, for example, an array of arrays, an array of char fields, or an array of embedded fields.
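
To make the distinction under discussion concrete: in this diff, `max_size` attaches an upper-bound validator and `size` attaches an exact-length validator, and passing both raises `ValueError`. The following is a framework-free sketch of that behaviour only. The helper `length_errors` and its messages are hypothetical and are not part of the PR or of Django.

```python
def length_errors(value, *, size=None, max_size=None):
    """Return validation messages for a list under the two options.

    Hypothetical helper: it mirrors the intent of LengthValidator (exact
    length) and ArrayMaxLengthValidator (upper bound) without using Django.
    """
    if size and max_size:
        raise ValueError("Cannot define both size and max_size")
    errors = []
    count = len(value)
    # max_size: any length up to the limit is acceptable.
    if max_size and count > max_size:
        errors.append(f"expected at most {max_size} items, got {count}")
    # size: the length must match exactly, as a vector index requires.
    if size and count != size:
        errors.append(f"expected exactly {size} items, got {count}")
    return errors
```

With `max_size=3`, a two-item list passes; with `size=3`, the same list is rejected. The exact-length rule is the one a fixed-dimension embedding field needs.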

         self.base_field = base_field
+        self.max_size = max_size
         self.size = size
+        if size and max_size:
+            raise ValueError("Cannot define both, size and max_size")
+        if self.max_size:
+            self.default_validators = [
+                *self.default_validators,
+                ArrayMaxLengthValidator(self.max_size),
+            ]
         if self.size:
             self.default_validators = [
                 *self.default_validators,
-                ArrayMaxLengthValidator(self.size),
+                LengthValidator(self.size),
             ]
         # For performance, only add a from_db_value() method if the base field
         # implements it.
19 changes: 19 additions & 0 deletions django_mongodb_backend/fields/validators.py
@@ -0,0 +1,19 @@
+from django.core.validators import BaseValidator
+from django.utils.deconstruct import deconstructible
+from django.utils.translation import ngettext_lazy
+
+
+@deconstructible
+class LengthValidator(BaseValidator):
+    message = ngettext_lazy(
+        "List contains %(show_value)d item, it should contain %(limit_value)d.",
+        "List contains %(show_value)d items, it should contain %(limit_value)d.",
+        "show_value",
+    )
+    code = "length"
+
+    def compare(self, a, b):
+        return a != b
+
+    def clean(self, x):
+        return len(x)
138 changes: 136 additions & 2 deletions django_mongodb_backend/indexes.py
@@ -1,12 +1,16 @@
+import itertools
 from collections import defaultdict

+from django.core.checks import Error
 from django.db import NotSupportedError
-from django.db.models import Index
+from django.db.models import DecimalField, FloatField, Index
 from django.db.models.lookups import BuiltinLookup
 from django.db.models.sql.query import Query
 from django.db.models.sql.where import AND, XOR, WhereNode
 from pymongo import ASCENDING, DESCENDING
-from pymongo.operations import IndexModel
+from pymongo.operations import IndexModel, SearchIndexModel
+
+from django_mongodb_backend.fields import ArrayField

 from .query_utils import process_rhs

@@ -101,6 +105,136 @@ def where_node_idx(self, compiler, connection):
     return mql


+class SearchIndex(Index):
+    suffix = "six"
+
+    # Maps Django internal type to atlas search index type.
+    # Reference: https://www.mongodb.com/docs/atlas/atlas-search/define-field-mappings/#data-types
+    def search_index_data_types(self, field, db_type):
+        if field.get_internal_type() == "UUIDField":
+            return "uuid"
+        if field.get_internal_type() in ("ObjectIdAutoField", "ObjectIdField"):
+            return "ObjectId"
+        if field.get_internal_type() == "EmbeddedModelField":
+            return "embeddedDocuments"
+        if db_type in ("int", "long"):
+            return "number"
+        if db_type == "binData":
+            return "string"
+        if db_type == "bool":
+            return "boolean"
+        if db_type == "object":
+            return "document"
+        return db_type
+
+    def get_pymongo_index_model(
+        self, model, schema_editor, field=None, unique=False, column_prefix=""
+    ):
+        fields = {}
+        for field_name, _ in self.fields_orders:
+            field_ = model._meta.get_field(field_name)
+            type_ = self.search_index_data_types(field_, field_.db_type(schema_editor.connection))
+            field_path = column_prefix + model._meta.get_field(field_name).column
+            fields[field_path] = {"type": type_}
+        return SearchIndexModel(
+            definition={"mappings": {"dynamic": False, "fields": fields}}, name=self.name
+        )
+
+
+class VectorSearchIndex(SearchIndex):
+    suffix = "vsi"
+    ALLOWED_SIMILARITY_FUNCTIONS = frozenset(("euclidean", "cosine", "dotProduct"))
+
+    def __init__(self, *expressions, similarities="cosine", **kwargs):
+        super().__init__(*expressions, **kwargs)
+        # validate the similarities types
+        self.similarities = similarities
+
+    def check(self, model, connection):
+        errors = []
+        error_id_prefix = "django_mongodb_backend.indexes.VectorSearchIndex"
+        similarities = (
+            self.similarities if isinstance(self.similarities, list) else [self.similarities]
+        )
+        for func in similarities:
+            if func not in self.ALLOWED_SIMILARITY_FUNCTIONS:
+                errors.append(
+                    Error(
+                        f"{func} isn't a valid similarity function, options "
+                        f"are {', '.join(sorted(self.ALLOWED_SIMILARITY_FUNCTIONS))}",
+                        obj=self,
+                        id=f"{error_id_prefix}.E004",
+                    )
+                )
+        for field_name, _ in self.fields_orders:
+            field_ = model._meta.get_field(field_name)
+            if isinstance(field_, ArrayField):
+                try:
+                    int(field_.size)
+                except (ValueError, TypeError):
+                    errors.append(
+                        Error(
+                            "Atlas vector search requires size.",
+                            obj=self,
+                            id=f"{error_id_prefix}.E001",
+                        )
+                    )
+                if not isinstance(field_.base_field, FloatField | DecimalField):
+                    errors.append(
+                        Error(
+                            "Base type must be Float or Decimal.",
+                            obj=self,
+                            id=f"{error_id_prefix}.E002",
+                        )
+                    )
+            else:
+                field_type = field_.db_type(connection)
+                search_type = self.search_index_data_types(field_, field_type)
+                # filter - for fields that contain boolean, date, objectId,
+                # numeric, string, or UUID values. Reference:
+                # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/#atlas-vector-search-index-fields
+                if search_type not in ("number", "string", "boolean", "objectId", "uuid", "date"):
+                    errors.append(
+                        Error(
+                            f"Unsupported filter of type {field_.get_internal_type()}.",
+                            obj=self,
+                            id="django_mongodb_backend.indexes.VectorSearchIndex.E003",
+                        )
+                    )
+        return errors
+
+    def deconstruct(self):
+        path, args, kwargs = super().deconstruct()
+        kwargs["similarities"] = self.similarities
+        return path, args, kwargs
+
+    def get_pymongo_index_model(
+        self, model, schema_editor, field=None, unique=False, column_prefix=""
+    ):
+        similarities = (
+            itertools.cycle([self.similarities])
+            if isinstance(self.similarities, str)
+            else iter(self.similarities)
+        )
+        fields = []
+        for field_name, _ in self.fields_orders:
+            field_ = model._meta.get_field(field_name)
+            field_path = column_prefix + model._meta.get_field(field_name).column
+            mappings = {"path": field_path}
+            if isinstance(field_, ArrayField):
+                mappings.update(
+                    {
+                        "type": "vector",
+                        "numDimensions": int(field_.size),
+                        "similarity": next(similarities),
Contributor

I'm still not sure why similarities needs to be a list.
I saw earlier that it was one value per field, but I still don't see why we can't enforce a single similarity function for the whole index.

Another nit: this mapping doesn't include the optional quantization setting. I don't think it will be necessary until we support BSONVector, but I wanted to note it here.

Collaborator Author

Well, the code is a bit weird, and maybe it's not easy to read.
The idea is that a vector index can contain more than one vector field, for example one using cosine similarity and another using dot product. If you define multiple vector fields, you can assign each a different similarity function. If a single string is passed as the similarity function, it is applied to all vectors.
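
The pairing described here can be sketched in isolation. The snippet below reuses the `itertools.cycle` / `iter` pattern from the diff, but the helper `assign_similarities` and the field names are hypothetical and exist only for illustration. It assumes that a list of similarities has one entry per vector field.

```python
import itertools


def assign_similarities(vector_fields, similarities):
    """Pair each vector field with a similarity function (illustrative only)."""
    # A single string is repeated for every field; a list is consumed in order.
    source = (
        itertools.cycle([similarities])
        if isinstance(similarities, str)
        else iter(similarities)
    )
    assigned = {}
    for name in vector_fields:
        assigned[name] = next(source)
    return assigned


# One string applies to every vector field.
same = assign_similarities(["title_embedded", "description_embedded"], "cosine")
# A list assigns one function per vector field, in declaration order.
mixed = assign_similarities(
    ["title_embedded", "description_embedded"], ["cosine", "dotProduct"]
)
```

Because only vector fields call `next()`, filter fields in the same index do not consume an entry from the list.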

+                    }
+                )
+            else:
+                mappings["type"] = "filter"
+            fields.append(mappings)
+        return SearchIndexModel(definition={"fields": fields}, name=self.name, type="vectorSearch")
+
+
 def register_indexes():
     BuiltinLookup.as_mql_idx = builtin_lookup_idx
     Index._get_condition_mql = _get_condition_mql
33 changes: 32 additions & 1 deletion django_mongodb_backend/introspection.py
@@ -2,14 +2,16 @@
 from django.db.models import Index
 from pymongo import ASCENDING, DESCENDING

+from django_mongodb_backend.indexes import SearchIndex, VectorSearchIndex
+

 class DatabaseIntrospection(BaseDatabaseIntrospection):
     ORDER_DIR = {ASCENDING: "ASC", DESCENDING: "DESC"}

     def table_names(self, cursor=None, include_views=False):
         return sorted([x["name"] for x in self.connection.database.list_collections()])

-    def get_constraints(self, cursor, table_name):
+    def _get_index_info(self, table_name):
         indexes = self.connection.get_collection(table_name).index_information()
         constraints = {}
         for name, details in indexes.items():
@@ -30,3 +32,32 @@ def get_constraints(self, cursor, table_name):
                 "options": {},
             }
         return constraints
+
+    def _get_search_index_info(self, table_name):
+        constraints = {}
+        indexes = self.connection.get_collection(table_name).list_search_indexes()
+        for details in indexes:
+            if details["type"] == "vectorSearch":
+                columns = [field["path"] for field in details["latestDefinition"]["fields"]]
+                type_ = VectorSearchIndex.suffix
+                options = details
+            else:
+                options = details["latestDefinition"]["mappings"]
+                columns = list(options.get("fields", {}).keys())
+                type_ = SearchIndex.suffix
+            constraints[details["name"]] = {
+                "check": False,
+                "columns": columns,
+                "definition": None,
+                "foreign_key": None,
+                "index": True,
+                "orders": [],
+                "primary_key": False,
+                "type": type_,
+                "unique": False,
+                "options": options,
+            }
+        return constraints
+
+    def get_constraints(self, cursor, table_name):
+        return {**self._get_index_info(table_name), **self._get_search_index_info(table_name)}
13 changes: 11 additions & 2 deletions django_mongodb_backend/schema.py
@@ -1,5 +1,8 @@
 from django.db.backends.base.schema import BaseDatabaseSchemaEditor
 from django.db.models import Index, UniqueConstraint
+from pymongo.operations import SearchIndexModel
+
+from django_mongodb_backend.indexes import SearchIndex, VectorSearchIndex

 from .fields import EmbeddedModelField
 from .query import wrap_database_errors
@@ -265,7 +268,10 @@ def add_index(
         )
         if idx:
             model = parent_model or model
-            self.get_collection(model._meta.db_table).create_indexes([idx])
+            if isinstance(idx, SearchIndexModel):
+                self.get_collection(model._meta.db_table).create_search_index(idx)
+            else:
+                self.get_collection(model._meta.db_table).create_indexes([idx])

     def _add_composed_index(self, model, field_names, column_prefix="", parent_model=None):
         """Add an index on the given list of field_names."""
@@ -283,7 +289,10 @@ def _add_field_index(self, model, field, *, column_prefix=""):
     def remove_index(self, model, index):
         if index.contains_expressions:
             return
-        self.get_collection(model._meta.db_table).drop_index(index.name)
+        if isinstance(index, SearchIndex | VectorSearchIndex):
+            self.get_collection(model._meta.db_table).drop_search_index(index.name)
+        else:
+            self.get_collection(model._meta.db_table).drop_index(index.name)

     def _remove_composed_index(
         self, model, field_names, constraint_kwargs, column_prefix="", parent_model=None
3 changes: 3 additions & 0 deletions django_mongodb_backend/utils.py
@@ -107,11 +107,14 @@ class OperationDebugWrapper:
         "aggregate",
         "create_collection",
         "create_indexes",
+        "create_search_index",
         "drop",
         "index_information",
         "insert_many",
         "delete_many",
         "drop_index",
+        "drop_search_index",
+        "list_search_indexes",
         "rename",
         "update_many",
     }
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,2 +1,2 @@
 django>=5.1,<5.2
-pymongo>=4.6,<5.0
+pymongo>=4.7,<5.0
14 changes: 14 additions & 0 deletions tests/indexes_/models.py
@@ -1,7 +1,21 @@
 from django.db import models

+from django_mongodb_backend.fields import ArrayField, EmbeddedModelField
+from django_mongodb_backend.models import EmbeddedModel
+
+
+class Data(EmbeddedModel):
+    integer = models.IntegerField()
+

 class Article(models.Model):
     headline = models.CharField(max_length=100)
     number = models.IntegerField()
     body = models.TextField()
+    data = models.JSONField()
+    embedded = EmbeddedModelField(Data)
+    auto_now = models.DateTimeField(auto_now=True)
+    title_embedded = ArrayField(models.FloatField(), size=10)
+    description_embedded = ArrayField(models.FloatField(), size=10)
+    number_list = ArrayField(models.FloatField())
+    name_list = ArrayField(models.CharField(max_length=30), size=10)