Skip to content

Commit d184b56

Browse files
committed
Allow querying an EmbeddedModelField by model instance
1 parent 8a57a06 commit d184b56

File tree

5 files changed

+167
-11
lines changed

5 files changed

+167
-11
lines changed

django_mongodb_backend/fields/embedded_model.py

+66-2
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
from django.core import checks
44
from django.core.exceptions import FieldDoesNotExist
55
from django.db import models
6+
from django.db.models import lookups
67
from django.db.models.fields.related import lazy_related_operation
78
from django.db.models.lookups import Transform
89

910
from .. import forms
11+
from ..query_utils import process_lhs, process_rhs
1012
from .json import build_json_mql_path
1113

1214

@@ -149,6 +151,62 @@ def formfield(self, **kwargs):
149151
)
150152

151153

154+
@EmbeddedModelField.register_lookup
155+
class EMFExact(lookups.Exact):
156+
def model_to_dict(self, instance):
157+
"""
158+
Return a dict containing the data in a model instance, as well as a
159+
dict containing the data for any embedded model fields.
160+
"""
161+
data = {}
162+
emf_data = {}
163+
for f in instance._meta.concrete_fields:
164+
value = f.value_from_object(instance)
165+
if isinstance(f, EmbeddedModelField):
166+
emf_data[f"{f.name}"] = (
167+
self.model_to_dict(value) if value is not None else (None, {})
168+
)
169+
continue
170+
# Unless explicitly set, primary keys aren't included in embedded
171+
# models.
172+
if f.primary_key and value is None:
173+
continue
174+
data[f"{f.name}"] = value
175+
return data, emf_data
176+
177+
def get_conditions(self, emf_data, prefix):
178+
"""
179+
Recursively transform a dictionary of {"field_name": {<model_to_dict>}}
180+
into MQL lookups. `prefix` tracks the string that must be appended to
181+
nested fields.
182+
"""
183+
conditions = []
184+
for k, v in emf_data.items():
185+
v, emf_data = v
186+
subprefix = f"{prefix}.{k}"
187+
conditions += self.get_conditions(emf_data, subprefix)
188+
if v is not None:
189+
# Match all field of the EmbeddedModelField.
190+
conditions += [{"$eq": [f"{subprefix}.{x}", y]} for x, y in v.items()]
191+
else:
192+
# Match a null EmbeddedModelField.
193+
conditions += [{"$eq": [f"{subprefix}", None]}]
194+
return conditions
195+
196+
def as_mql(self, compiler, connection):
197+
lhs_mql = process_lhs(self, compiler, connection)
198+
value = process_rhs(self, compiler, connection)
199+
if isinstance(value, models.Model):
200+
value, emf_data = self.model_to_dict(value)
201+
prefix = self.lhs.as_mql(compiler, connection)
202+
# Get conditions for top-level EmbeddedModelField.
203+
conditions = [{"$eq": [f"{prefix}.{k}", v]} for k, v in value.items()]
204+
# Get conditions for any nested EmbeddedModelFields.
205+
conditions += self.get_conditions(emf_data, prefix)
206+
return {"$and": conditions}
207+
return connection.mongo_operators[self.lookup_name](lhs_mql, value)
208+
209+
152210
class KeyTransform(Transform):
153211
def __init__(self, key_name, ref_field, *args, **kwargs):
154212
super().__init__(*args, **kwargs)
@@ -192,8 +250,14 @@ def preprocess_lhs(self, compiler, connection):
192250
json_key_transforms.insert(0, previous.key_name)
193251
previous = previous.lhs
194252
mql = previous.as_mql(compiler, connection)
195-
# The first json_key_transform is the field name.
196-
embedded_key_transforms.append(json_key_transforms.pop(0))
253+
try:
254+
# The first json_key_transform is the field name.
255+
field_name = json_key_transforms.pop(0)
256+
except IndexError:
257+
# This is a lookup of the embedded model itself.
258+
pass
259+
else:
260+
embedded_key_transforms.append(field_name)
197261
return mql, embedded_key_transforms, json_key_transforms
198262

199263
def as_mql(self, compiler, connection):

docs/source/releases/5.1.x.rst

+8
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@
22
Django MongoDB Backend 5.1.x
33
============================
44

5+
5.1.0 beta 2
6+
============
7+
8+
*Unreleased*
9+
10+
- Added support to ``EmbeddedModelField`` for the ``QuerySet`` ``exact`` lookup
11+
using a model instance.
12+
513
5.1.0 beta 1
614
============
715

docs/source/topics/embedded-models.rst

+9
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,12 @@ as relational fields. For example, to retrieve all customers who have an
5454
address with the city "New York"::
5555

5656
>>> Customer.objects.filter(address__city="New York")
57+
58+
You can also query using a model instance. Unlike a normal relational lookup
59+
which does the lookup by primary key, since embedded models typically don't
60+
have a primary key set, the query requires that every field match. For example,
61+
this query gives customers with addresses with the city "New York" and all
62+
other fields of the address equal to their default (:attr:`Field.default
63+
<django.db.models.Field.default>`, ``None``, or an empty string).
64+
65+
>>> Customer.objects.filter(address=Address(city="New York"))

tests/model_fields_/models.py

+28
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,31 @@ class Library(models.Model):
130130

131131
def __str__(self):
132132
return self.name
133+
134+
135+
class A(models.Model):
136+
b = EmbeddedModelField("B")
137+
138+
139+
class B(EmbeddedModel):
140+
c = EmbeddedModelField("C")
141+
name = models.CharField(max_length=100)
142+
value = models.IntegerField()
143+
144+
145+
class C(EmbeddedModel):
146+
d = EmbeddedModelField("D")
147+
name = models.CharField(max_length=100)
148+
value = models.IntegerField()
149+
150+
151+
class D(EmbeddedModel):
152+
e = EmbeddedModelField("E")
153+
nullable_e = EmbeddedModelField("E", null=True, blank=True)
154+
name = models.CharField(max_length=100)
155+
value = models.IntegerField()
156+
157+
158+
class E(EmbeddedModel):
159+
name = models.CharField(max_length=100)
160+
value = models.IntegerField()

tests/model_fields_/test_embedded_model.py

+56-9
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from datetime import timedelta
33

44
from django.core.exceptions import FieldDoesNotExist, ValidationError
5-
from django.db import models
5+
from django.db import connection, models
66
from django.db.models import (
77
Exists,
88
ExpressionWrapper,
@@ -17,14 +17,7 @@
1717
from django_mongodb_backend.fields import EmbeddedModelField
1818
from django_mongodb_backend.models import EmbeddedModel
1919

20-
from .models import (
21-
Address,
22-
Author,
23-
Book,
24-
Data,
25-
Holder,
26-
Library,
27-
)
20+
from .models import A, Address, Author, B, Book, C, D, Data, E, Holder, Library
2821
from .utils import truncate_ms
2922

3023

@@ -117,6 +110,60 @@ def test_order_by_embedded_field(self):
117110
qs = Holder.objects.filter(data__integer__gt=3).order_by("-data__integer")
118111
self.assertSequenceEqual(qs, list(reversed(self.objs[4:])))
119112

113+
def test_exact_with_model(self):
114+
data = Holder.objects.first().data
115+
self.assertEqual(
116+
Holder.objects.filter(data=data).get().data.integer, self.objs[0].data.integer
117+
)
118+
119+
def test_exact_with_model_ignores_key_order(self):
120+
# Due to the possibility of schema changes or the reordering of a
121+
# model's fields, a lookup must work if an embedded document has its
122+
# keys in a different order than what's declared on the embedded model.
123+
connection.get_collection("model_fields__holder").insert_one(
124+
{
125+
"data": {
126+
"auto_now": None,
127+
"auto_now_add": None,
128+
"json_value": None,
129+
"integer": 100,
130+
}
131+
}
132+
)
133+
self.assertEqual(Holder.objects.filter(data=Data(integer=100)).get().data.integer, 100)
134+
135+
def test_exact_with_nested_model(self):
136+
address = Address(city="NYC", state="NY")
137+
author = Author(name="Shakespeare", age=55, address=address)
138+
obj = Book.objects.create(author=author)
139+
self.assertCountEqual(Book.objects.filter(author=address), [obj])
140+
self.assertCountEqual(Book.objects.filter(author__address=address), [obj])
141+
142+
def test_exact_with_deeply_nested_model(self):
143+
e1 = E(name="E1", value=5)
144+
d1 = D(name="D1", value=4, e=e1)
145+
c1 = C(name="C1", value=3, d=d1)
146+
b1 = B(name="B1", value=2, c=c1)
147+
a1 = A.objects.create(b=b1)
148+
e2 = E(name="E2", value=6)
149+
d2 = D(name="D2", value=4, e=e1, nullable_e=e2)
150+
c2 = C(name="C2", value=3, d=d2)
151+
b2 = B(name="B2", value=2, c=c2)
152+
a2 = A.objects.create(b=b2)
153+
self.assertCountEqual(A.objects.filter(b=b1), [a1])
154+
self.assertCountEqual(A.objects.filter(b__c=c1), [a1])
155+
self.assertCountEqual(A.objects.filter(b__c__d=d1), [a1])
156+
self.assertCountEqual(A.objects.filter(b__c__d__e=e1), [a1, a2])
157+
self.assertCountEqual(A.objects.filter(b=b2), [a2])
158+
self.assertCountEqual(A.objects.filter(b__c=c2), [a2])
159+
self.assertCountEqual(A.objects.filter(b__c__d=d2), [a2])
160+
self.assertCountEqual(A.objects.filter(b__c__d__nullable_e=e2), [a2])
161+
162+
def test_exact_with_model_with_embedded_model(self):
163+
author = Author(name="Shakespeare", age=55, address=Address(city="NYC", state="NY"))
164+
obj = Book.objects.create(author=author)
165+
self.assertCountEqual(Book.objects.filter(author=author), [obj])
166+
120167
def test_embedded_json_field_lookups(self):
121168
objs = [
122169
Holder.objects.create(

0 commit comments

Comments
 (0)