Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow querying an EmbeddedModelField by model instance #267

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 71 additions & 2 deletions django_mongodb_backend/fields/embedded_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
from django.core import checks
from django.core.exceptions import FieldDoesNotExist
from django.db import models
from django.db.models import lookups
from django.db.models.expressions import Col
from django.db.models.fields.related import lazy_related_operation
from django.db.models.lookups import Transform

from .. import forms
from ..query_utils import process_lhs, process_rhs
from .json import build_json_mql_path


Expand Down Expand Up @@ -149,6 +152,66 @@ def formfield(self, **kwargs):
)


@EmbeddedModelField.register_lookup
class EMFExact(lookups.Exact):
def model_to_dict(self, instance):
"""
Return a dict containing the data in a model instance, as well as a
dict containing the data for any embedded model fields.
"""
data = {}
emf_data = {}
for f in instance._meta.concrete_fields:
value = f.value_from_object(instance)
if isinstance(f, EmbeddedModelField):
emf_data[f.name] = self.model_to_dict(value) if value is not None else (None, {})
continue
# Unless explicitly set, primary keys aren't included in embedded
# models.
if f.primary_key and value is None:
continue
data[f.name] = value
return data, emf_data

def get_conditions(self, emf_data, prefix=None):
"""
Recursively transform a dictionary of {"field_name": {<model_to_dict>}}
lookups into MQL. `prefix` tracks the string that must be appended to
nested fields.
"""
conditions = []
for k, v in emf_data.items():
v, emf_data = v
subprefix = f"{prefix}.{k}" if prefix else k
conditions += self.get_conditions(emf_data, subprefix)
if v is not None:
# Match all field of the EmbeddedModelField.
conditions += [{"$eq": [f"{subprefix}.{x}", y]} for x, y in v.items()]
else:
# Match a null EmbeddedModelField.
conditions += [{"$eq": [f"{subprefix}", None]}]
return conditions

def as_mql(self, compiler, connection):
lhs_mql = process_lhs(self, compiler, connection)
value = process_rhs(self, compiler, connection)
if isinstance(self.lhs, Col) or (
isinstance(self.lhs, KeyTransform)
and isinstance(self.lhs.ref_field, EmbeddedModelField)
):
if isinstance(value, models.Model):
value, emf_data = self.model_to_dict(value)
prefix = self.lhs.as_mql(compiler, connection)
# Get conditions for any nested EmbeddedModelFields.
conditions = self.get_conditions({prefix: (value, emf_data)})
return {"$and": conditions}
raise TypeError(
"An EmbeddedModelField must be queried using a model instance, got %s."
% type(value)
)
return connection.mongo_operators[self.lookup_name](lhs_mql, value)


class KeyTransform(Transform):
def __init__(self, key_name, ref_field, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -192,8 +255,14 @@ def preprocess_lhs(self, compiler, connection):
json_key_transforms.insert(0, previous.key_name)
previous = previous.lhs
mql = previous.as_mql(compiler, connection)
# The first json_key_transform is the field name.
embedded_key_transforms.append(json_key_transforms.pop(0))
try:
# The first json_key_transform is the field name.
field_name = json_key_transforms.pop(0)
except IndexError:
# This is a lookup of the embedded model itself.
pass
else:
embedded_key_transforms.append(field_name)
return mql, embedded_key_transforms, json_key_transforms

def as_mql(self, compiler, connection):
Expand Down
8 changes: 8 additions & 0 deletions docs/source/releases/5.1.x.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
Django MongoDB Backend 5.1.x
============================

5.1.0 beta 2
============

*Unreleased*

- Added support to ``EmbeddedModelField`` for the ``QuerySet`` ``exact`` lookup
using a model instance.

5.1.0 beta 1
============

Expand Down
9 changes: 9 additions & 0 deletions docs/source/topics/embedded-models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,12 @@ as relational fields. For example, to retrieve all customers who have an
address with the city "New York"::

>>> Customer.objects.filter(address__city="New York")

You can also query using a model instance. Unlike a normal relational lookup
which does the lookup by primary key, since embedded models typically don't
have a primary key set, the query requires that every field match. For example,
this query gives customers with addresses with the city "New York" and all
other fields of the address equal to their default (:attr:`Field.default
<django.db.models.Field.default>`, ``None``, or an empty string).

>>> Customer.objects.filter(address=Address(city="New York"))
28 changes: 28 additions & 0 deletions tests/model_fields_/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,31 @@ class Library(models.Model):

def __str__(self):
return self.name


class A(models.Model):
b = EmbeddedModelField("B")


class B(EmbeddedModel):
c = EmbeddedModelField("C")
name = models.CharField(max_length=100)
value = models.IntegerField()


class C(EmbeddedModel):
d = EmbeddedModelField("D")
name = models.CharField(max_length=100)
value = models.IntegerField()


class D(EmbeddedModel):
e = EmbeddedModelField("E")
nullable_e = EmbeddedModelField("E", null=True, blank=True)
name = models.CharField(max_length=100)
value = models.IntegerField()


class E(EmbeddedModel):
name = models.CharField(max_length=100)
value = models.IntegerField()
71 changes: 62 additions & 9 deletions tests/model_fields_/test_embedded_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from datetime import timedelta

from django.core.exceptions import FieldDoesNotExist, ValidationError
from django.db import models
from django.db import connection, models
from django.db.models import (
Exists,
ExpressionWrapper,
Expand All @@ -17,14 +17,7 @@
from django_mongodb_backend.fields import EmbeddedModelField
from django_mongodb_backend.models import EmbeddedModel

from .models import (
Address,
Author,
Book,
Data,
Holder,
Library,
)
from .models import A, Address, Author, B, Book, C, D, Data, E, Holder, Library
from .utils import truncate_ms


Expand Down Expand Up @@ -117,6 +110,66 @@ def test_order_by_embedded_field(self):
qs = Holder.objects.filter(data__integer__gt=3).order_by("-data__integer")
self.assertSequenceEqual(qs, list(reversed(self.objs[4:])))

def test_exact_with_model(self):
data = Holder.objects.first().data
self.assertEqual(
Holder.objects.filter(data=data).get().data.integer, self.objs[0].data.integer
)

def test_exact_with_model_ignores_key_order(self):
# Due to the possibility of schema changes or the reordering of a
# model's fields, a lookup must work if an embedded document has its
# keys in a different order than what's declared on the embedded model.
connection.get_collection("model_fields__holder").insert_one(
{
"data": {
"auto_now": None,
"auto_now_add": None,
"json_value": None,
"integer": 100,
}
}
)
self.assertEqual(Holder.objects.filter(data=Data(integer=100)).get().data.integer, 100)

def test_exact_with_nested_model(self):
address = Address(city="NYC", state="NY")
author = Author(name="Shakespeare", age=55, address=address)
obj = Book.objects.create(author=author)
self.assertCountEqual(Book.objects.filter(author=author), [obj])
self.assertCountEqual(Book.objects.filter(author__address=address), [obj])

def test_exact_with_deeply_nested_models(self):
e1 = E(name="E1", value=5)
d1 = D(name="D1", value=4, e=e1)
c1 = C(name="C1", value=3, d=d1)
b1 = B(name="B1", value=2, c=c1)
a1 = A.objects.create(b=b1)
e2 = E(name="E2", value=6)
d2 = D(name="D2", value=4, e=e1, nullable_e=e2)
c2 = C(name="C2", value=3, d=d2)
b2 = B(name="B2", value=2, c=c2)
a2 = A.objects.create(b=b2)
self.assertCountEqual(A.objects.filter(b=b1), [a1])
self.assertCountEqual(A.objects.filter(b__c=c1), [a1])
self.assertCountEqual(A.objects.filter(b__c__d=d1), [a1])
self.assertCountEqual(A.objects.filter(b__c__d__e=e1), [a1, a2])
self.assertCountEqual(A.objects.filter(b=b2), [a2])
self.assertCountEqual(A.objects.filter(b__c=c2), [a2])
self.assertCountEqual(A.objects.filter(b__c__d=d2), [a2])
self.assertCountEqual(A.objects.filter(b__c__d__nullable_e=e2), [a2])

def test_exact_validates_argument(self):
msg = "An EmbeddedModelField must be queried using a model instance, got <class 'dict'>."
with self.assertRaisesMessage(TypeError, msg):
str(A.objects.filter(b={}))
with self.assertRaisesMessage(TypeError, msg):
str(A.objects.filter(b__c={}))
with self.assertRaisesMessage(TypeError, msg):
str(A.objects.filter(b__c__d={}))
with self.assertRaisesMessage(TypeError, msg):
str(A.objects.filter(b__c__d__e={}))

def test_embedded_json_field_lookups(self):
objs = [
Holder.objects.create(
Expand Down