diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py index d9dd5b6c..ee5c0c30 100644 --- a/django_mongodb_backend/fields/embedded_model.py +++ b/django_mongodb_backend/fields/embedded_model.py @@ -3,10 +3,13 @@ from django.core import checks from django.core.exceptions import FieldDoesNotExist from django.db import models +from django.db.models import lookups +from django.db.models.expressions import Col from django.db.models.fields.related import lazy_related_operation from django.db.models.lookups import Transform from .. import forms +from ..query_utils import process_lhs, process_rhs from .json import build_json_mql_path @@ -149,6 +152,66 @@ def formfield(self, **kwargs): ) +@EmbeddedModelField.register_lookup +class EMFExact(lookups.Exact): + def model_to_dict(self, instance): + """ + Return a dict containing the data in a model instance, as well as a + dict containing the data for any embedded model fields. + """ + data = {} + emf_data = {} + for f in instance._meta.concrete_fields: + value = f.value_from_object(instance) + if isinstance(f, EmbeddedModelField): + emf_data[f.name] = self.model_to_dict(value) if value is not None else (None, {}) + continue + # Unless explicitly set, primary keys aren't included in embedded + # models. + if f.primary_key and value is None: + continue + data[f.name] = value + return data, emf_data + + def get_conditions(self, emf_data, prefix=None): + """ + Recursively transform a dictionary of {"field_name": {}} + lookups into MQL. `prefix` tracks the string that must be appended to + nested fields. + """ + conditions = [] + for k, v in emf_data.items(): + v, emf_data = v + subprefix = f"{prefix}.{k}" if prefix else k + conditions += self.get_conditions(emf_data, subprefix) + if v is not None: + # Match all field of the EmbeddedModelField. + conditions += [{"$eq": [f"{subprefix}.{x}", y]} for x, y in v.items()] + else: + # Match a null EmbeddedModelField. + conditions += [{"$eq": [f"{subprefix}", None]}] + return conditions + + def as_mql(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection) + value = process_rhs(self, compiler, connection) + if isinstance(self.lhs, Col) or ( + isinstance(self.lhs, KeyTransform) + and isinstance(self.lhs.ref_field, EmbeddedModelField) + ): + if isinstance(value, models.Model): + value, emf_data = self.model_to_dict(value) + prefix = self.lhs.as_mql(compiler, connection) + # Get conditions for any nested EmbeddedModelFields. + conditions = self.get_conditions({prefix: (value, emf_data)}) + return {"$and": conditions} + raise TypeError( + "An EmbeddedModelField must be queried using a model instance, got %s." + % type(value) + ) + return connection.mongo_operators[self.lookup_name](lhs_mql, value) + + class KeyTransform(Transform): def __init__(self, key_name, ref_field, *args, **kwargs): super().__init__(*args, **kwargs) @@ -192,8 +255,14 @@ def preprocess_lhs(self, compiler, connection): json_key_transforms.insert(0, previous.key_name) previous = previous.lhs mql = previous.as_mql(compiler, connection) - # The first json_key_transform is the field name. - embedded_key_transforms.append(json_key_transforms.pop(0)) + try: + # The first json_key_transform is the field name. + field_name = json_key_transforms.pop(0) + except IndexError: + # This is a lookup of the embedded model itself. + pass + else: + embedded_key_transforms.append(field_name) return mql, embedded_key_transforms, json_key_transforms def as_mql(self, compiler, connection): diff --git a/docs/source/releases/5.1.x.rst b/docs/source/releases/5.1.x.rst index d13a59a5..8a94c91e 100644 --- a/docs/source/releases/5.1.x.rst +++ b/docs/source/releases/5.1.x.rst @@ -2,6 +2,14 @@ Django MongoDB Backend 5.1.x ============================ +5.1.0 beta 2 +============ + +*Unreleased* + +- Added support to ``EmbeddedModelField`` for the ``QuerySet`` ``exact`` lookup + using a model instance. + 5.1.0 beta 1 ============ diff --git a/docs/source/topics/embedded-models.rst b/docs/source/topics/embedded-models.rst index 94abecfd..2a406312 100644 --- a/docs/source/topics/embedded-models.rst +++ b/docs/source/topics/embedded-models.rst @@ -54,3 +54,12 @@ as relational fields. For example, to retrieve all customers who have an address with the city "New York":: >>> Customer.objects.filter(address__city="New York") + +You can also query using a model instance. Unlike a normal relational lookup +which does the lookup by primary key, since embedded models typically don't +have a primary key set, the query requires that every field match. For example, +this query gives customers with addresses with the city "New York" and all +other fields of the address equal to their default (:attr:`Field.default +`, ``None``, or an empty string). + + >>> Customer.objects.filter(address=Address(city="New York")) diff --git a/tests/model_fields_/models.py b/tests/model_fields_/models.py index b25b94a1..265271c0 100644 --- a/tests/model_fields_/models.py +++ b/tests/model_fields_/models.py @@ -130,3 +130,31 @@ class Library(models.Model): def __str__(self): return self.name + + +class A(models.Model): + b = EmbeddedModelField("B") + + +class B(EmbeddedModel): + c = EmbeddedModelField("C") + name = models.CharField(max_length=100) + value = models.IntegerField() + + +class C(EmbeddedModel): + d = EmbeddedModelField("D") + name = models.CharField(max_length=100) + value = models.IntegerField() + + +class D(EmbeddedModel): + e = EmbeddedModelField("E") + nullable_e = EmbeddedModelField("E", null=True, blank=True) + name = models.CharField(max_length=100) + value = models.IntegerField() + + +class E(EmbeddedModel): + name = models.CharField(max_length=100) + value = models.IntegerField() diff --git a/tests/model_fields_/test_embedded_model.py b/tests/model_fields_/test_embedded_model.py index eee0dd1a..4ac6ab93 100644 --- a/tests/model_fields_/test_embedded_model.py +++ b/tests/model_fields_/test_embedded_model.py @@ -2,7 +2,7 @@ from datetime import timedelta from django.core.exceptions import FieldDoesNotExist, ValidationError -from django.db import models +from django.db import connection, models from django.db.models import ( Exists, ExpressionWrapper, @@ -17,14 +17,7 @@ from django_mongodb_backend.fields import EmbeddedModelField from django_mongodb_backend.models import EmbeddedModel -from .models import ( - Address, - Author, - Book, - Data, - Holder, - Library, -) +from .models import A, Address, Author, B, Book, C, D, Data, E, Holder, Library from .utils import truncate_ms @@ -117,6 +110,66 @@ def test_order_by_embedded_field(self): qs = Holder.objects.filter(data__integer__gt=3).order_by("-data__integer") self.assertSequenceEqual(qs, list(reversed(self.objs[4:]))) + def test_exact_with_model(self): + data = Holder.objects.first().data + self.assertEqual( + Holder.objects.filter(data=data).get().data.integer, self.objs[0].data.integer + ) + + def test_exact_with_model_ignores_key_order(self): + # Due to the possibility of schema changes or the reordering of a + # model's fields, a lookup must work if an embedded document has its + # keys in a different order than what's declared on the embedded model. + connection.get_collection("model_fields__holder").insert_one( + { + "data": { + "auto_now": None, + "auto_now_add": None, + "json_value": None, + "integer": 100, + } + } + ) + self.assertEqual(Holder.objects.filter(data=Data(integer=100)).get().data.integer, 100) + + def test_exact_with_nested_model(self): + address = Address(city="NYC", state="NY") + author = Author(name="Shakespeare", age=55, address=address) + obj = Book.objects.create(author=author) + self.assertCountEqual(Book.objects.filter(author=author), [obj]) + self.assertCountEqual(Book.objects.filter(author__address=address), [obj]) + + def test_exact_with_deeply_nested_models(self): + e1 = E(name="E1", value=5) + d1 = D(name="D1", value=4, e=e1) + c1 = C(name="C1", value=3, d=d1) + b1 = B(name="B1", value=2, c=c1) + a1 = A.objects.create(b=b1) + e2 = E(name="E2", value=6) + d2 = D(name="D2", value=4, e=e1, nullable_e=e2) + c2 = C(name="C2", value=3, d=d2) + b2 = B(name="B2", value=2, c=c2) + a2 = A.objects.create(b=b2) + self.assertCountEqual(A.objects.filter(b=b1), [a1]) + self.assertCountEqual(A.objects.filter(b__c=c1), [a1]) + self.assertCountEqual(A.objects.filter(b__c__d=d1), [a1]) + self.assertCountEqual(A.objects.filter(b__c__d__e=e1), [a1, a2]) + self.assertCountEqual(A.objects.filter(b=b2), [a2]) + self.assertCountEqual(A.objects.filter(b__c=c2), [a2]) + self.assertCountEqual(A.objects.filter(b__c__d=d2), [a2]) + self.assertCountEqual(A.objects.filter(b__c__d__nullable_e=e2), [a2]) + + def test_exact_validates_argument(self): + msg = "An EmbeddedModelField must be queried using a model instance, got ." + with self.assertRaisesMessage(TypeError, msg): + str(A.objects.filter(b={})) + with self.assertRaisesMessage(TypeError, msg): + str(A.objects.filter(b__c={})) + with self.assertRaisesMessage(TypeError, msg): + str(A.objects.filter(b__c__d={})) + with self.assertRaisesMessage(TypeError, msg): + str(A.objects.filter(b__c__d__e={})) + def test_embedded_json_field_lookups(self): objs = [ Holder.objects.create(