Add SearchIndex and VectorSearchIndex #264
base: main
@@ -0,0 +1,27 @@
from itertools import chain

from django.apps import apps
from django.core.checks import Tags, register
from django.db import connections, router

from django_mongodb_backend.indexes import VectorSearchIndex


@register(Tags.models)
def check_vector_search_indexes(app_configs, databases=None, **kwargs):  # noqa: ARG001
    # Validate vector search indexes for models.
    errors = []
    if app_configs is None:
        models = apps.get_models()
    else:
        models = chain.from_iterable(app_config.get_models() for app_config in app_configs)
    for model in models:
        for db in databases or ():
            if not router.allow_migrate_model(db, model):
                continue
            connection = connections[db]
            for model_index in model._meta.indexes:
                if not isinstance(model_index, VectorSearchIndex):
                    continue
                errors.extend(model_index.check(model, connection))
    return errors
@@ -10,6 +10,7 @@
from ..forms import SimpleArrayField
from ..query_utils import process_lhs, process_rhs
from ..utils import prefix_validation_error
from .validators import LengthValidator

__all__ = ["ArrayField"]


@@ -27,13 +28,21 @@ class ArrayField(CheckFieldDefaultMixin, Field):
    }
    _default_hint = ("list", "[]")

-    def __init__(self, base_field, size=None, **kwargs):
+    def __init__(self, base_field, max_size=None, size=None, **kwargs):
Review comment: If this is just meant for vector searching, can this be "fixed_size=True" rather than max_size? Then the validator would just use a ternary between [...]

Reply: ArrayField isn't only for vector search; it's general purpose, e.g. an array of arrays, an array of char fields, or an array of embedded fields.
        self.base_field = base_field
        self.max_size = max_size
        self.size = size
        if size and max_size:
            raise ValueError("Cannot define both, size and max_size")
        if self.max_size:
            self.default_validators = [
                *self.default_validators,
                ArrayMaxLengthValidator(self.max_size),
            ]
        if self.size:
            self.default_validators = [
                *self.default_validators,
-                ArrayMaxLengthValidator(self.size),
+                LengthValidator(self.size),
            ]
        # For performance, only add a from_db_value() method if the base field
        # implements it.
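To make the difference between the two parameters concrete, here is a hedged usage sketch (the model and field names are hypothetical): per the code above, max_size only caps the list length via ArrayMaxLengthValidator, size enforces an exact length via the new LengthValidator, and passing both raises ValueError.

from django.db import models

from django_mongodb_backend.fields import ArrayField


class Product(models.Model):
    # Up to 5 tags; shorter lists are fine.
    tags = ArrayField(models.CharField(max_length=30), max_size=5)
    # Exactly 384 floats; any other length fails validation. The fixed size is
    # what VectorSearchIndex later uses as numDimensions.
    embedding = ArrayField(models.FloatField(), size=384)


# ArrayField(models.FloatField(), size=10, max_size=20) raises
# ValueError("Cannot define both, size and max_size").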
@@ -0,0 +1,19 @@
from django.core.validators import BaseValidator
from django.utils.deconstruct import deconstructible
from django.utils.translation import ngettext_lazy


@deconstructible
class LengthValidator(BaseValidator):
    message = ngettext_lazy(
        "List contains %(show_value)d item, it should contain %(limit_value)d.",
        "List contains %(show_value)d items, it should contain %(limit_value)d.",
        "show_value",
    )
    code = "length"

    def compare(self, a, b):
        return a != b

    def clean(self, x):
        return len(x)
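As a quick illustration of how BaseValidator ties these two methods together (a sketch, not part of the PR; it assumes LengthValidator is importable from this new validators module and that Django is configured): clean() reduces the value to its length, and compare() flags a mismatch with limit_value, so validation fails whenever the list length differs from the expected size.

from django.core.exceptions import ValidationError

validator = LengthValidator(3)  # limit_value=3

validator([1.0, 2.0, 3.0])  # passes: clean() -> 3, compare(3, 3) is False

try:
    validator([1.0, 2.0])  # clean() -> 2, compare(2, 3) is True
except ValidationError as error:
    print(error.code)  # "length"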
@@ -1,12 +1,16 @@
import itertools
from collections import defaultdict

from django.core.checks import Error
from django.db import NotSupportedError
-from django.db.models import Index
+from django.db.models import DecimalField, FloatField, Index
from django.db.models.lookups import BuiltinLookup
from django.db.models.sql.query import Query
from django.db.models.sql.where import AND, XOR, WhereNode
from pymongo import ASCENDING, DESCENDING
-from pymongo.operations import IndexModel
+from pymongo.operations import IndexModel, SearchIndexModel

from django_mongodb_backend.fields import ArrayField

from .query_utils import process_rhs
@@ -101,6 +105,136 @@ def where_node_idx(self, compiler, connection):
    return mql


class SearchIndex(Index):
    suffix = "six"

    # Maps Django internal type to atlas search index type.
    # Reference: https://www.mongodb.com/docs/atlas/atlas-search/define-field-mappings/#data-types
    def search_index_data_types(self, field, db_type):
        if field.get_internal_type() == "UUIDField":
            return "uuid"
        if field.get_internal_type() in ("ObjectIdAutoField", "ObjectIdField"):
            return "ObjectId"
        if field.get_internal_type() == "EmbeddedModelField":
            return "embeddedDocuments"
        if db_type in ("int", "long"):
            return "number"
        if db_type == "binData":
            return "string"
        if db_type == "bool":
            return "boolean"
        if db_type == "object":
            return "document"
        return db_type

    def get_pymongo_index_model(
        self, model, schema_editor, field=None, unique=False, column_prefix=""
    ):
        fields = {}
        for field_name, _ in self.fields_orders:
            field_ = model._meta.get_field(field_name)
            type_ = self.search_index_data_types(field_, field_.db_type(schema_editor.connection))
            field_path = column_prefix + model._meta.get_field(field_name).column
            fields[field_path] = {"type": type_}
        return SearchIndexModel(
            definition={"mappings": {"dynamic": False, "fields": fields}}, name=self.name
        )
class VectorSearchIndex(SearchIndex):
    suffix = "vsi"
    ALLOWED_SIMILARITY_FUNCTIONS = frozenset(("euclidean", "cosine", "dotProduct"))

    def __init__(self, *expressions, similarities="cosine", **kwargs):
        super().__init__(*expressions, **kwargs)
        # validate the similarities types
        self.similarities = similarities

    def check(self, model, connection):
        errors = []
        error_id_prefix = "django_mongodb_backend.indexes.VectorSearchIndex"
        similarities = (
            self.similarities if isinstance(self.similarities, list) else [self.similarities]
        )
        for func in similarities:
            if func not in self.ALLOWED_SIMILARITY_FUNCTIONS:
                errors.append(
                    Error(
                        f"{func} isn't a valid similarity function, options "
                        f"are {', '.join(sorted(self.ALLOWED_SIMILARITY_FUNCTIONS))}",
                        obj=self,
                        id=f"{error_id_prefix}.E004",
                    )
                )
        for field_name, _ in self.fields_orders:
            field_ = model._meta.get_field(field_name)
            if isinstance(field_, ArrayField):
                try:
                    int(field_.size)
                except (ValueError, TypeError):
                    errors.append(
                        Error(
                            "Atlas vector search requires size.",
                            obj=self,
                            id=f"{error_id_prefix}.E001",
                        )
                    )
                if not isinstance(field_.base_field, FloatField | DecimalField):
                    errors.append(
                        Error(
                            "Base type must be Float or Decimal.",
                            obj=self,
                            id=f"{error_id_prefix}.E002",
                        )
                    )
            else:
                field_type = field_.db_type(connection)
                search_type = self.search_index_data_types(field_, field_type)
                # filter - for fields that contain boolean, date, objectId,
                # numeric, string, or UUID values. Reference:
                # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/#atlas-vector-search-index-fields
                if search_type not in ("number", "string", "boolean", "objectId", "uuid", "date"):
                    errors.append(
                        Error(
                            f"Unsupported filter of type {field_.get_internal_type()}.",
                            obj=self,
                            id="django_mongodb_backend.indexes.VectorSearchIndex.E003",
                        )
                    )
        return errors

    def deconstruct(self):
        path, args, kwargs = super().deconstruct()
        kwargs["similarities"] = self.similarities
        return path, args, kwargs

    def get_pymongo_index_model(
        self, model, schema_editor, field=None, unique=False, column_prefix=""
    ):
        similarities = (
            itertools.cycle([self.similarities])
            if isinstance(self.similarities, str)
            else iter(self.similarities)
        )
        fields = []
        for field_name, _ in self.fields_orders:
            field_ = model._meta.get_field(field_name)
            field_path = column_prefix + model._meta.get_field(field_name).column
            mappings = {"path": field_path}
            if isinstance(field_, ArrayField):
                mappings.update(
                    {
                        "type": "vector",
                        "numDimensions": int(field_.size),
                        "similarity": next(similarities),
                    }
                )
            else:
                mappings["type"] = "filter"
            fields.append(mappings)
        return SearchIndexModel(definition={"fields": fields}, name=self.name, type="vectorSearch")

Review comments on the vector mapping above:

Comment: I'm still not sure why [...]. Another nit: this mapping doesn't include the optional quantization. I don't think it will be necessary until we support BSONVector, but I wanted to note it here.

Reply: Well, the code is a bit weird. Maybe it's not easy to read.
def register_indexes():
    BuiltinLookup.as_mql_idx = builtin_lookup_idx
    Index._get_condition_mql = _get_condition_mql
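To show how the two new index classes are meant to be declared, here is a hedged sketch of a model Meta (the model, field names, and index names are hypothetical; the keyword arguments follow the standard Index API plus the similarities option added above):

from django.db import models

from django_mongodb_backend.fields import ArrayField
from django_mongodb_backend.indexes import SearchIndex, VectorSearchIndex


class Review(models.Model):
    title = models.CharField(max_length=200)
    rating = models.IntegerField()
    embedding = ArrayField(models.FloatField(), size=1536)  # size becomes numDimensions

    class Meta:
        indexes = [
            # Atlas Search index: each field is mapped through
            # search_index_data_types(), e.g. CharField -> "string",
            # IntegerField -> "number".
            SearchIndex(fields=["title", "rating"], name="review_six"),
            # Atlas Vector Search index: the ArrayField becomes a "vector"
            # mapping with numDimensions=1536 and similarity="dotProduct";
            # rating becomes a "filter" mapping.
            VectorSearchIndex(
                fields=["embedding", "rating"],
                name="review_vsi",
                similarities="dotProduct",
            ),
        ]

Both classes build a pymongo SearchIndexModel rather than a regular IndexModel, which presumably motivates the pymongo>=4.7 bump below.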
@@ -1,2 +1,2 @@
django>=5.1,<5.2
-pymongo>=4.6,<5.0
+pymongo>=4.7,<5.0
@@ -1,7 +1,21 @@
from django.db import models

from django_mongodb_backend.fields import ArrayField, EmbeddedModelField
from django_mongodb_backend.models import EmbeddedModel


class Data(EmbeddedModel):
    integer = models.IntegerField()


class Article(models.Model):
    headline = models.CharField(max_length=100)
    number = models.IntegerField()
    body = models.TextField()
    data = models.JSONField()
    embedded = EmbeddedModelField(Data)
    auto_now = models.DateTimeField(auto_now=True)
    title_embedded = ArrayField(models.FloatField(), size=10)
    description_embedded = ArrayField(models.FloatField(), size=10)
    number_list = ArrayField(models.FloatField())
    name_list = ArrayField(models.CharField(max_length=30), size=10)
Review comment: Let's add a comment explaining that this parameter is necessary against local Atlas instances running in Docker.

Reply: The Docker site doesn't provide much info about it; it just gives the connection URI format. I think it's for connecting to a single MongoDB instance, like standalone mode, but that needs to be confirmed.