|
3 | 3 | from django.core import checks
|
4 | 4 | from django.core.exceptions import FieldDoesNotExist
|
5 | 5 | from django.db import models
|
| 6 | +from django.db.models import lookups |
6 | 7 | from django.db.models.fields.related import lazy_related_operation
|
7 | 8 | from django.db.models.lookups import Transform
|
8 | 9 |
|
9 | 10 | from .. import forms
|
| 11 | +from ..query_utils import process_lhs, process_rhs |
10 | 12 | from .json import build_json_mql_path
|
11 | 13 |
|
12 | 14 |
|
@@ -149,6 +151,62 @@ def formfield(self, **kwargs):
|
149 | 151 | )
|
150 | 152 |
|
151 | 153 |
|
| 154 | +@EmbeddedModelField.register_lookup |
| 155 | +class EMFExact(lookups.Exact): |
| 156 | + def model_to_dict(self, instance): |
| 157 | + """ |
| 158 | + Return a dict containing the data in a model instance, as well as a |
| 159 | + dict containing the data for any embedded model fields. |
| 160 | + """ |
| 161 | + data = {} |
| 162 | + emf_data = {} |
| 163 | + for f in instance._meta.concrete_fields: |
| 164 | + value = f.value_from_object(instance) |
| 165 | + if isinstance(f, EmbeddedModelField): |
| 166 | + emf_data[f"{f.name}"] = ( |
| 167 | + self.model_to_dict(value) if value is not None else (None, {}) |
| 168 | + ) |
| 169 | + continue |
| 170 | + # Unless explicitly set, primary keys aren't included in embedded |
| 171 | + # models. |
| 172 | + if f.primary_key and value is None: |
| 173 | + continue |
| 174 | + data[f"{f.name}"] = value |
| 175 | + return data, emf_data |
| 176 | + |
| 177 | + def get_conditions(self, emf_data, prefix): |
| 178 | + """ |
| 179 | + Recursively transform a dictionary of {"field_name": {<model_to_dict>}} |
| 180 | + into MQL lookups. `prefix` tracks the string that must be appended to |
| 181 | + nested fields. |
| 182 | + """ |
| 183 | + conditions = [] |
| 184 | + for k, v in emf_data.items(): |
| 185 | + v, emf_data = v |
| 186 | + subprefix = f"{prefix}.{k}" |
| 187 | + conditions += self.get_conditions(emf_data, subprefix) |
| 188 | + if v is not None: |
| 189 | + # Match all field of the EmbeddedModelField. |
| 190 | + conditions += [{"$eq": [f"{subprefix}.{x}", y]} for x, y in v.items()] |
| 191 | + else: |
| 192 | + # Match a null EmbeddedModelField. |
| 193 | + conditions += [{"$eq": [f"{subprefix}", None]}] |
| 194 | + return conditions |
| 195 | + |
| 196 | + def as_mql(self, compiler, connection): |
| 197 | + lhs_mql = process_lhs(self, compiler, connection) |
| 198 | + value = process_rhs(self, compiler, connection) |
| 199 | + if isinstance(value, models.Model): |
| 200 | + value, emf_data = self.model_to_dict(value) |
| 201 | + prefix = self.lhs.as_mql(compiler, connection) |
| 202 | + # Get conditions for top-level EmbeddedModelField. |
| 203 | + conditions = [{"$eq": [f"{prefix}.{k}", v]} for k, v in value.items()] |
| 204 | + # Get conditions for any nested EmbeddedModelFields. |
| 205 | + conditions += self.get_conditions(emf_data, prefix) |
| 206 | + return {"$and": conditions} |
| 207 | + return connection.mongo_operators[self.lookup_name](lhs_mql, value) |
| 208 | + |
| 209 | + |
152 | 210 | class KeyTransform(Transform):
|
153 | 211 | def __init__(self, key_name, ref_field, *args, **kwargs):
|
154 | 212 | super().__init__(*args, **kwargs)
|
@@ -192,8 +250,14 @@ def preprocess_lhs(self, compiler, connection):
|
192 | 250 | json_key_transforms.insert(0, previous.key_name)
|
193 | 251 | previous = previous.lhs
|
194 | 252 | mql = previous.as_mql(compiler, connection)
|
195 |
| - # The first json_key_transform is the field name. |
196 |
| - embedded_key_transforms.append(json_key_transforms.pop(0)) |
| 253 | + try: |
| 254 | + # The first json_key_transform is the field name. |
| 255 | + field_name = json_key_transforms.pop(0) |
| 256 | + except IndexError: |
| 257 | + # This is a lookup of the embedded model itself. |
| 258 | + pass |
| 259 | + else: |
| 260 | + embedded_key_transforms.append(field_name) |
197 | 261 | return mql, embedded_key_transforms, json_key_transforms
|
198 | 262 |
|
199 | 263 | def as_mql(self, compiler, connection):
|
|
0 commit comments