diff --git a/AUTHORS.rst b/AUTHORS.rst index 7a3bd943..aad296d0 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -21,6 +21,7 @@ Contributions by * Sabin Iacob (https://github.com/m0n5t3r) * kryton (https://github.com/kryton) * Brandon Pedersen (https://github.com/bpedman) +* Brian Gontowski (https://github.com/Molanda) (For an up-to-date list of contributors, see https://github.com/django-mongodb-engine/mongodb-engine/contributors.) diff --git a/django_mongodb_engine/compiler.py b/django_mongodb_engine/compiler.py index a26c0af1..b4fe2f06 100644 --- a/django_mongodb_engine/compiler.py +++ b/django_mongodb_engine/compiler.py @@ -383,6 +383,9 @@ def insert(self, docs, return_id=False): doc.clear() else: raise DatabaseError("Can't save entity with _id set to None") + for d in doc.keys(): + if '.' in d: + del doc[d] collection = self.get_collection() options = self.connection.operation_flags.get('save', {}) diff --git a/django_mongodb_engine/contrib/__init__.py b/django_mongodb_engine/contrib/__init__.py index acd3223b..d45bf42f 100644 --- a/django_mongodb_engine/contrib/__init__.py +++ b/django_mongodb_engine/contrib/__init__.py @@ -1,11 +1,24 @@ import sys +import re +import copy +import django from django.db import models, connections from django.db.models.query import QuerySet from django.db.models.sql.query import Query as SQLQuery +from django.db.models.query_utils import Q +from django_mongodb_engine.compiler import OPERATORS_MAP, NEGATED_OPERATORS_MAP +from djangotoolbox.fields import AbstractIterableField + +if django.VERSION >= (1, 5): + from django.db.models.constants import LOOKUP_SEP +else: + from django.db.models.sql.constants import LOOKUP_SEP ON_PYPY = hasattr(sys, 'pypy_version_info') +ALL_OPERATORS = dict(list(OPERATORS_MAP.items() + NEGATED_OPERATORS_MAP.items())).keys() +MONGO_DOT_FIELDS = ('DictField', 'ListField', 'SetField', 'EmbeddedModelField') def _compiler_for_queryset(qs, which='SQLCompiler'): @@ -85,6 +98,113 @@ def __repr__(self): class MongoDBQuerySet(QuerySet): + def _filter_or_exclude(self, negate, *args, **kwargs): + if args or kwargs: + assert self.query.can_filter(), \ + 'Cannot filter a query once a slice has been taken.' + + clone = self._clone() + clone._process_arg_filters(args, kwargs) + if negate: + clone.query.add_q(~Q(*args, **kwargs)) + else: + clone.query.add_q(Q(*args, **kwargs)) + return clone + + def _get_mongo_field_names(self): + if not hasattr(self, '_mongo_field_names'): + self._mongo_field_names = [] + for name in self.model._meta.get_all_field_names(): + field = self.model._meta.get_field_by_name(name)[0] + if '.' not in name and field.get_internal_type() in MONGO_DOT_FIELDS: + self._mongo_field_names.append(name) + + return self._mongo_field_names + + def _process_arg_filters(self, args, kwargs): + for key, val in kwargs.items(): + del kwargs[key] + key = self._maybe_add_dot_field(key) + kwargs[key] = val + + for a in args: + if isinstance(a, Q): + self._process_q_filters(a) + + def _process_q_filters(self, q): + for c in range(len(q.children)): + child = q.children[c] + if isinstance(child, Q): + self._process_q_filters(child) + elif isinstance(child, tuple): + key, val = child + key = self._maybe_add_dot_field(key) + q.children[c] = (key, val) + + def _maybe_add_dot_field(self, name): + if LOOKUP_SEP in name and name.split(LOOKUP_SEP)[0] in self._get_mongo_field_names(): + for op in ALL_OPERATORS: + if name.endswith(LOOKUP_SEP + op): + name = re.sub(LOOKUP_SEP + op + '$', '#' + op, name) + break + name = name.replace(LOOKUP_SEP, '.').replace('#', LOOKUP_SEP) + + parts1 = name.split(LOOKUP_SEP) + if '.' in parts1[0] and parts1[0] not in self.model._meta.get_all_field_names(): + parts2 = parts1[0].split('.') + parts3 = [] + parts4 = [] + model = self.model + + while len(parts2) > 0: + part = parts2.pop(0) + field = model._meta.get_field_by_name(part)[0] + field_type = field.get_internal_type() + column = field.db_column + if column: + part = column + parts3.append(part) + if field_type == 'ListField': + list_type = field.item_field.get_internal_type() + if list_type == 'EmbeddedModelField': + field = field.item_field + field_type = list_type + if field_type == 'EmbeddedModelField': + model = field.embedded_model() + else: + while len(parts2) > 0: + part = parts2.pop(0) + if field_type in MONGO_DOT_FIELDS: + parts3.append(part) + else: + parts4.append(part) + + db_column = '.'.join(parts3) + + if field_type in MONGO_DOT_FIELDS: + field = AbstractIterableField( + db_column=db_column, + blank=True, + null=True, + editable=False, + ) + else: + field = copy.deepcopy(field) + field.name = None + field.db_column = db_column + field.blank = True + field.null = True + field.editable = False + if hasattr(field, '_related_fields'): + delattr(field, '_related_fields') + + parts5 = parts1[0].split('.')[0:len(parts3)] + name = '.'.join(parts5) + self.model.add_to_class(name, field) + name = LOOKUP_SEP.join([name] + parts4 + parts1[1:]) + + return name + def map_reduce(self, *args, **kwargs): """ Performs a Map/Reduce operation on all documents matching the query, diff --git a/tests/dotquery/__init__.py b/tests/dotquery/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/dotquery/models.py b/tests/dotquery/models.py new file mode 100644 index 00000000..f4e06640 --- /dev/null +++ b/tests/dotquery/models.py @@ -0,0 +1,40 @@ +from django.db import models +from djangotoolbox.fields import ListField, DictField, EmbeddedModelField +from django_mongodb_engine.contrib import MongoDBManager + + +class DotQueryForeignModel(models.Model): + objects = MongoDBManager() + + f_char = models.CharField(max_length=200, db_column='dbc_char') + + +class DotQueryEmbeddedModel(models.Model): + objects = MongoDBManager() + + f_int = models.IntegerField(db_column='dbc_int') + f_foreign = models.ForeignKey( + DotQueryForeignModel, + null=True, + blank=True, + db_column='dbc_foreign' + ) + + +class DotQueryTestModel(models.Model): + objects = MongoDBManager() + + f_id = models.IntegerField() + f_dict = DictField(db_column='dbc_dict') + f_list = ListField(db_column='dbc_list') + f_embedded = EmbeddedModelField( + DotQueryEmbeddedModel, + db_column='dbc_embedded', + ) + f_embedded_list = ListField( + EmbeddedModelField( + DotQueryEmbeddedModel, + db_column='dbc_embedded', + ), + db_column='dbc_embedded_list', + ) diff --git a/tests/dotquery/tests.py b/tests/dotquery/tests.py new file mode 100644 index 00000000..ca365861 --- /dev/null +++ b/tests/dotquery/tests.py @@ -0,0 +1,112 @@ +from __future__ import with_statement +from django.db.models import Q +from models import * +from utils import * + + +class DotQueryTests(TestCase): + """Tests for querying on foo.bar using join syntax.""" + + def setUp(self): + fm = DotQueryForeignModel.objects.create( + f_char='hello', + ) + DotQueryTestModel.objects.create( + f_id=51, + f_dict={'numbers': [1, 2, 3], 'letters': 'abc'}, + f_list=[{'color': 'red'}, {'color': 'blue'}], + f_embedded=DotQueryEmbeddedModel(f_int=10, f_foreign=fm), + f_embedded_list=[ + DotQueryEmbeddedModel(f_int=100), + DotQueryEmbeddedModel(f_int=101), + ], + ) + DotQueryTestModel.objects.create( + f_id=52, + f_dict={'numbers': [2, 3], 'letters': 'bc'}, + f_list=[{'color': 'red'}, {'color': 'green'}], + f_embedded=DotQueryEmbeddedModel(f_int=11), + f_embedded_list=[ + DotQueryEmbeddedModel(f_int=110, f_foreign=fm), + DotQueryEmbeddedModel(f_int=111, f_foreign=fm), + ], + ) + DotQueryTestModel.objects.create( + f_id=53, + f_dict={'numbers': [3, 4], 'letters': 'cd'}, + f_list=[{'color': 'yellow'}, {'color': 'orange'}], + f_embedded=DotQueryEmbeddedModel(f_int=12), + f_embedded_list=[ + DotQueryEmbeddedModel(f_int=120), + DotQueryEmbeddedModel(f_int=121), + ], + ) + + def tearDown(self): + DotQueryTestModel.objects.all().delete() + DotQueryForeignModel.objects.all().delete() + + def test_dict_queries(self): + qs = DotQueryTestModel.objects.filter(f_dict__numbers=2) + self.assertEqual(qs.count(), 2) + self.assertEqual(qs[0].f_id, 51) + self.assertEqual(qs[1].f_id, 52) + qs = DotQueryTestModel.objects.filter(f_dict__letters__contains='b') + self.assertEqual(qs.count(), 2) + self.assertEqual(qs[0].f_id, 51) + self.assertEqual(qs[1].f_id, 52) + qs = DotQueryTestModel.objects.exclude(f_dict__letters__contains='b') + self.assertEqual(qs.count(), 1) + self.assertEqual(qs[0].f_id, 53) + qs = DotQueryTestModel.objects.exclude(f_dict__letters__icontains='B') + self.assertEqual(qs.count(), 1) + self.assertEqual(qs[0].f_id, 53) + + def test_list_queries(self): + qs = DotQueryTestModel.objects.filter(f_list__color='red') + qs = qs.exclude(f_list__color='green') + qs = qs.exclude(f_list__color='purple') + self.assertEqual(qs.count(), 1) + self.assertEqual(qs[0].f_id, 51) + + def test_embedded_queries(self): + qs = DotQueryTestModel.objects.exclude(f_embedded__f_int__in=[10, 12]) + self.assertEqual(qs.count(), 1) + self.assertEqual(qs[0].f_id, 52) + + def test_embedded_list_queries(self): + qs = DotQueryTestModel.objects.get(f_embedded_list__f_int=120) + self.assertEqual(qs.f_id, 53) + + def test_foreign_queries(self): + fm = DotQueryForeignModel.objects.get(f_char='hello') + qs = DotQueryTestModel.objects.get(f_embedded__f_foreign=fm) + self.assertEqual(qs.f_id, 51) + qs = DotQueryTestModel.objects.get(f_embedded_list__f_foreign=fm) + self.assertEqual(qs.f_id, 52) + qs = DotQueryTestModel.objects.get(f_embedded__f_foreign__pk=fm.pk) + self.assertEqual(qs.f_id, 51) + qs = DotQueryTestModel.objects.get(f_embedded_list__f_foreign__pk__exact=fm.pk) + self.assertEqual(qs.f_id, 52) + + def test_q_queries(self): + q = Q(f_dict__numbers=1) | Q(f_dict__numbers=4) + q = q & Q(f_dict__numbers=3) + qs = DotQueryTestModel.objects.filter(q) + self.assertEqual(qs.count(), 2) + self.assertEqual(qs[0].f_id, 51) + self.assertEqual(qs[1].f_id, 53) + + def test_save_after_query(self): + qs = DotQueryTestModel.objects.get(f_dict__letters='cd') + self.assertEqual(qs.f_id, 53) + qs.f_id = 1053 + qs.clean() + qs.save() + qs = DotQueryTestModel.objects.get(f_dict__letters='cd') + self.assertEqual(qs.f_id, 1053) + qs.f_id = 53 + qs.clean() + qs.save() + qs = DotQueryTestModel.objects.get(f_dict__letters='cd') + self.assertEqual(qs.f_id, 53) diff --git a/tests/dotquery/utils.py b/tests/dotquery/utils.py new file mode 100644 index 00000000..fada71c8 --- /dev/null +++ b/tests/dotquery/utils.py @@ -0,0 +1,35 @@ +from django.conf import settings +from django.db import connections +from django.db.models import Model +from django.test import TestCase +from django.utils.unittest import skip + + +class TestCase(TestCase): + + def setUp(self): + super(TestCase, self).setUp() + if getattr(settings, 'TEST_DEBUG', False): + settings.DEBUG = True + + def assertEqualLists(self, a, b): + self.assertEqual(list(a), list(b)) + + +def skip_all_except(*tests): + + class meta(type): + + def __new__(cls, name, bases, dict): + for attr in dict.keys(): + if attr.startswith('test_') and attr not in tests: + del dict[attr] + return type.__new__(cls, name, bases, dict) + + return meta + + +def get_collection(model_or_name): + if isinstance(model_or_name, type) and issubclass(model_or_name, Model): + model_or_name = model_or_name._meta.db_table + return connections['default'].get_collection(model_or_name) diff --git a/tests/settings/__init__.py b/tests/settings/__init__.py index eb9f2579..b3e4c323 100644 --- a/tests/settings/__init__.py +++ b/tests/settings/__init__.py @@ -17,4 +17,5 @@ 'aggregations', 'contrib', 'storage', + 'dotquery', ]