Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Join style queries for dict, set, list, and embedded fields #178

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Contributions by
* Sabin Iacob (https://github.com/m0n5t3r)
* kryton (https://github.com/kryton)
* Brandon Pedersen (https://github.com/bpedman)
* Brian Gontowski (https://github.com/Molanda)

(For an up-to-date list of contributors, see
https://github.com/django-mongodb-engine/mongodb-engine/contributors.)
3 changes: 3 additions & 0 deletions django_mongodb_engine/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,9 @@ def insert(self, docs, return_id=False):
doc.clear()
else:
raise DatabaseError("Can't save entity with _id set to None")
for d in doc.keys():
if '.' in d:
del doc[d]

collection = self.get_collection()
options = self.connection.operation_flags.get('save', {})
Expand Down
120 changes: 120 additions & 0 deletions django_mongodb_engine/contrib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,24 @@
import sys
import re
import copy
import django

from django.db import models, connections
from django.db.models.query import QuerySet
from django.db.models.sql.query import Query as SQLQuery
from django.db.models.query_utils import Q
from django_mongodb_engine.compiler import OPERATORS_MAP, NEGATED_OPERATORS_MAP
from djangotoolbox.fields import AbstractIterableField

if django.VERSION >= (1, 5):
from django.db.models.constants import LOOKUP_SEP
else:
from django.db.models.sql.constants import LOOKUP_SEP


ON_PYPY = hasattr(sys, 'pypy_version_info')
ALL_OPERATORS = dict(list(OPERATORS_MAP.items() + NEGATED_OPERATORS_MAP.items())).keys()
MONGO_DOT_FIELDS = ('DictField', 'ListField', 'SetField', 'EmbeddedModelField')


def _compiler_for_queryset(qs, which='SQLCompiler'):
Expand Down Expand Up @@ -85,6 +98,113 @@ def __repr__(self):

class MongoDBQuerySet(QuerySet):

def _filter_or_exclude(self, negate, *args, **kwargs):
if args or kwargs:
assert self.query.can_filter(), \
'Cannot filter a query once a slice has been taken.'

clone = self._clone()
clone._process_arg_filters(args, kwargs)
if negate:
clone.query.add_q(~Q(*args, **kwargs))
else:
clone.query.add_q(Q(*args, **kwargs))
return clone

def _get_mongo_field_names(self):
if not hasattr(self, '_mongo_field_names'):
self._mongo_field_names = []
for name in self.model._meta.get_all_field_names():
field = self.model._meta.get_field_by_name(name)[0]
if '.' not in name and field.get_internal_type() in MONGO_DOT_FIELDS:
self._mongo_field_names.append(name)

return self._mongo_field_names

def _process_arg_filters(self, args, kwargs):
for key, val in kwargs.items():
del kwargs[key]
key = self._maybe_add_dot_field(key)
kwargs[key] = val

for a in args:
if isinstance(a, Q):
self._process_q_filters(a)

def _process_q_filters(self, q):
for c in range(len(q.children)):
child = q.children[c]
if isinstance(child, Q):
self._process_q_filters(child)
elif isinstance(child, tuple):
key, val = child
key = self._maybe_add_dot_field(key)
q.children[c] = (key, val)

def _maybe_add_dot_field(self, name):
if LOOKUP_SEP in name and name.split(LOOKUP_SEP)[0] in self._get_mongo_field_names():
for op in ALL_OPERATORS:
if name.endswith(LOOKUP_SEP + op):
name = re.sub(LOOKUP_SEP + op + '$', '#' + op, name)
break
name = name.replace(LOOKUP_SEP, '.').replace('#', LOOKUP_SEP)

parts1 = name.split(LOOKUP_SEP)
if '.' in parts1[0] and parts1[0] not in self.model._meta.get_all_field_names():
parts2 = parts1[0].split('.')
parts3 = []
parts4 = []
model = self.model

while len(parts2) > 0:
part = parts2.pop(0)
field = model._meta.get_field_by_name(part)[0]
field_type = field.get_internal_type()
column = field.db_column
if column:
part = column
parts3.append(part)
if field_type == 'ListField':
list_type = field.item_field.get_internal_type()
if list_type == 'EmbeddedModelField':
field = field.item_field
field_type = list_type
if field_type == 'EmbeddedModelField':
model = field.embedded_model()
else:
while len(parts2) > 0:
part = parts2.pop(0)
if field_type in MONGO_DOT_FIELDS:
parts3.append(part)
else:
parts4.append(part)

db_column = '.'.join(parts3)

if field_type in MONGO_DOT_FIELDS:
field = AbstractIterableField(
db_column=db_column,
blank=True,
null=True,
editable=False,
)
else:
field = copy.deepcopy(field)
field.name = None
field.db_column = db_column
field.blank = True
field.null = True
field.editable = False
if hasattr(field, '_related_fields'):
delattr(field, '_related_fields')

parts5 = parts1[0].split('.')[0:len(parts3)]
name = '.'.join(parts5)
self.model.add_to_class(name, field)
name = LOOKUP_SEP.join([name] + parts4 + parts1[1:])

return name

def map_reduce(self, *args, **kwargs):
"""
Performs a Map/Reduce operation on all documents matching the query,
Expand Down
Empty file added tests/dotquery/__init__.py
Empty file.
40 changes: 40 additions & 0 deletions tests/dotquery/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from django.db import models
from djangotoolbox.fields import ListField, DictField, EmbeddedModelField
from django_mongodb_engine.contrib import MongoDBManager


class DotQueryForeignModel(models.Model):
objects = MongoDBManager()

f_char = models.CharField(max_length=200, db_column='dbc_char')


class DotQueryEmbeddedModel(models.Model):
objects = MongoDBManager()

f_int = models.IntegerField(db_column='dbc_int')
f_foreign = models.ForeignKey(
DotQueryForeignModel,
null=True,
blank=True,
db_column='dbc_foreign'
)


class DotQueryTestModel(models.Model):
objects = MongoDBManager()

f_id = models.IntegerField()
f_dict = DictField(db_column='dbc_dict')
f_list = ListField(db_column='dbc_list')
f_embedded = EmbeddedModelField(
DotQueryEmbeddedModel,
db_column='dbc_embedded',
)
f_embedded_list = ListField(
EmbeddedModelField(
DotQueryEmbeddedModel,
db_column='dbc_embedded',
),
db_column='dbc_embedded_list',
)
112 changes: 112 additions & 0 deletions tests/dotquery/tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from __future__ import with_statement
from django.db.models import Q
from models import *
from utils import *


class DotQueryTests(TestCase):
"""Tests for querying on foo.bar using join syntax."""

def setUp(self):
fm = DotQueryForeignModel.objects.create(
f_char='hello',
)
DotQueryTestModel.objects.create(
f_id=51,
f_dict={'numbers': [1, 2, 3], 'letters': 'abc'},
f_list=[{'color': 'red'}, {'color': 'blue'}],
f_embedded=DotQueryEmbeddedModel(f_int=10, f_foreign=fm),
f_embedded_list=[
DotQueryEmbeddedModel(f_int=100),
DotQueryEmbeddedModel(f_int=101),
],
)
DotQueryTestModel.objects.create(
f_id=52,
f_dict={'numbers': [2, 3], 'letters': 'bc'},
f_list=[{'color': 'red'}, {'color': 'green'}],
f_embedded=DotQueryEmbeddedModel(f_int=11),
f_embedded_list=[
DotQueryEmbeddedModel(f_int=110, f_foreign=fm),
DotQueryEmbeddedModel(f_int=111, f_foreign=fm),
],
)
DotQueryTestModel.objects.create(
f_id=53,
f_dict={'numbers': [3, 4], 'letters': 'cd'},
f_list=[{'color': 'yellow'}, {'color': 'orange'}],
f_embedded=DotQueryEmbeddedModel(f_int=12),
f_embedded_list=[
DotQueryEmbeddedModel(f_int=120),
DotQueryEmbeddedModel(f_int=121),
],
)

def tearDown(self):
DotQueryTestModel.objects.all().delete()
DotQueryForeignModel.objects.all().delete()

def test_dict_queries(self):
qs = DotQueryTestModel.objects.filter(f_dict__numbers=2)
self.assertEqual(qs.count(), 2)
self.assertEqual(qs[0].f_id, 51)
self.assertEqual(qs[1].f_id, 52)
qs = DotQueryTestModel.objects.filter(f_dict__letters__contains='b')
self.assertEqual(qs.count(), 2)
self.assertEqual(qs[0].f_id, 51)
self.assertEqual(qs[1].f_id, 52)
qs = DotQueryTestModel.objects.exclude(f_dict__letters__contains='b')
self.assertEqual(qs.count(), 1)
self.assertEqual(qs[0].f_id, 53)
qs = DotQueryTestModel.objects.exclude(f_dict__letters__icontains='B')
self.assertEqual(qs.count(), 1)
self.assertEqual(qs[0].f_id, 53)

def test_list_queries(self):
qs = DotQueryTestModel.objects.filter(f_list__color='red')
qs = qs.exclude(f_list__color='green')
qs = qs.exclude(f_list__color='purple')
self.assertEqual(qs.count(), 1)
self.assertEqual(qs[0].f_id, 51)

def test_embedded_queries(self):
qs = DotQueryTestModel.objects.exclude(f_embedded__f_int__in=[10, 12])
self.assertEqual(qs.count(), 1)
self.assertEqual(qs[0].f_id, 52)

def test_embedded_list_queries(self):
qs = DotQueryTestModel.objects.get(f_embedded_list__f_int=120)
self.assertEqual(qs.f_id, 53)

def test_foreign_queries(self):
fm = DotQueryForeignModel.objects.get(f_char='hello')
qs = DotQueryTestModel.objects.get(f_embedded__f_foreign=fm)
self.assertEqual(qs.f_id, 51)
qs = DotQueryTestModel.objects.get(f_embedded_list__f_foreign=fm)
self.assertEqual(qs.f_id, 52)
qs = DotQueryTestModel.objects.get(f_embedded__f_foreign__pk=fm.pk)
self.assertEqual(qs.f_id, 51)
qs = DotQueryTestModel.objects.get(f_embedded_list__f_foreign__pk__exact=fm.pk)
self.assertEqual(qs.f_id, 52)

def test_q_queries(self):
q = Q(f_dict__numbers=1) | Q(f_dict__numbers=4)
q = q & Q(f_dict__numbers=3)
qs = DotQueryTestModel.objects.filter(q)
self.assertEqual(qs.count(), 2)
self.assertEqual(qs[0].f_id, 51)
self.assertEqual(qs[1].f_id, 53)

def test_save_after_query(self):
qs = DotQueryTestModel.objects.get(f_dict__letters='cd')
self.assertEqual(qs.f_id, 53)
qs.f_id = 1053
qs.clean()
qs.save()
qs = DotQueryTestModel.objects.get(f_dict__letters='cd')
self.assertEqual(qs.f_id, 1053)
qs.f_id = 53
qs.clean()
qs.save()
qs = DotQueryTestModel.objects.get(f_dict__letters='cd')
self.assertEqual(qs.f_id, 53)
35 changes: 35 additions & 0 deletions tests/dotquery/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from django.conf import settings
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

None of the code in this file seems to be used anywhere. Am I missing something?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copied utils.py from the other tests. It's in aggregations, contrib, embedded, etc... It appears that it sets up a custom version of TestCase with some extra debugging support (settings.DEBUG = True). I didn't go deeply into what it does, but wanted to model my tests after the existing ones.

from django.db import connections
from django.db.models import Model
from django.test import TestCase
from django.utils.unittest import skip


class TestCase(TestCase):

def setUp(self):
super(TestCase, self).setUp()
if getattr(settings, 'TEST_DEBUG', False):
settings.DEBUG = True

def assertEqualLists(self, a, b):
self.assertEqual(list(a), list(b))


def skip_all_except(*tests):

class meta(type):

def __new__(cls, name, bases, dict):
for attr in dict.keys():
if attr.startswith('test_') and attr not in tests:
del dict[attr]
return type.__new__(cls, name, bases, dict)

return meta


def get_collection(model_or_name):
if isinstance(model_or_name, type) and issubclass(model_or_name, Model):
model_or_name = model_or_name._meta.db_table
return connections['default'].get_collection(model_or_name)
1 change: 1 addition & 0 deletions tests/settings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@
'aggregations',
'contrib',
'storage',
'dotquery',
]