diff --git a/django_mongodb_backend/fields/array.py b/django_mongodb_backend/fields/array.py index 49c9e6ad..e21afdcf 100644 --- a/django_mongodb_backend/fields/array.py +++ b/django_mongodb_backend/fields/array.py @@ -1,6 +1,5 @@ import json -from django.contrib.postgres.validators import ArrayMaxLengthValidator from django.core import checks, exceptions from django.db.models import DecimalField, Field, Func, IntegerField, Transform, Value from django.db.models.fields.mixins import CheckFieldDefaultMixin @@ -10,6 +9,7 @@ from ..forms import SimpleArrayField from ..query_utils import process_lhs, process_rhs from ..utils import prefix_validation_error +from ..validators import ArrayMaxLengthValidator, LengthValidator __all__ = ["ArrayField"] @@ -27,13 +27,19 @@ class ArrayField(CheckFieldDefaultMixin, Field): } _default_hint = ("list", "[]") - def __init__(self, base_field, size=None, **kwargs): + def __init__(self, base_field, max_size=None, size=None, **kwargs): self.base_field = base_field + self.max_size = max_size self.size = size + if self.max_size: + self.default_validators = [ + *self.default_validators, + ArrayMaxLengthValidator(self.max_size), + ] if self.size: self.default_validators = [ *self.default_validators, - ArrayMaxLengthValidator(self.size), + LengthValidator(self.size), ] # For performance, only add a from_db_value() method if the base field # implements it. @@ -98,6 +104,14 @@ def check(self, **kwargs): id="django_mongodb_backend.array.W004", ) ) + if self.size and self.max_size: + errors.append( + checks.Error( + "ArrayField cannot specify both size and max_size.", + obj=self, + id="django_mongodb_backend.array.E003", + ) + ) return errors def set_attributes_from_name(self, name): @@ -124,12 +138,11 @@ def deconstruct(self): name, path, args, kwargs = super().deconstruct() if path == "django_mongodb_backend.fields.array.ArrayField": path = "django_mongodb_backend.fields.ArrayField" - kwargs.update( - { - "base_field": self.base_field.clone(), - "size": self.size, - } - ) + kwargs["base_field"] = self.base_field.clone() + if self.max_size is not None: + kwargs["max_size"] = self.max_size + if self.size is not None: + kwargs["size"] = self.size return name, path, args, kwargs def to_python(self, value): @@ -213,7 +226,8 @@ def formfield(self, **kwargs): **{ "form_class": SimpleArrayField, "base_field": self.base_field.formfield(), - "max_length": self.size, + "max_length": self.max_size, + "length": self.size, **kwargs, } ) diff --git a/django_mongodb_backend/forms/fields/array.py b/django_mongodb_backend/forms/fields/array.py index 0de48dff..854508cc 100644 --- a/django_mongodb_backend/forms/fields/array.py +++ b/django_mongodb_backend/forms/fields/array.py @@ -2,11 +2,11 @@ from itertools import chain from django import forms -from django.core.exceptions import ValidationError +from django.core.exceptions import ImproperlyConfigured, ValidationError from django.utils.translation import gettext_lazy as _ from ...utils import prefix_validation_error -from ...validators import ArrayMaxLengthValidator, ArrayMinLengthValidator +from ...validators import ArrayMaxLengthValidator, ArrayMinLengthValidator, LengthValidator class SimpleArrayField(forms.CharField): @@ -14,16 +14,26 @@ class SimpleArrayField(forms.CharField): "item_invalid": _("Item %(nth)s in the array did not validate:"), } - def __init__(self, base_field, *, delimiter=",", max_length=None, min_length=None, **kwargs): + def __init__( + self, base_field, *, delimiter=",", max_length=None, min_length=None, length=None, **kwargs + ): self.base_field = base_field self.delimiter = delimiter super().__init__(**kwargs) + if (min_length is not None or max_length is not None) and length is not None: + invalid_param = "max_length" if max_length is not None else "min_length" + raise ImproperlyConfigured( + f"The length and {invalid_param} parameters are mutually exclusive." + ) if min_length is not None: self.min_length = min_length self.validators.append(ArrayMinLengthValidator(int(min_length))) if max_length is not None: self.max_length = max_length self.validators.append(ArrayMaxLengthValidator(int(max_length))) + if length is not None: + self.length = length + self.validators.append(LengthValidator(int(length))) def clean(self, value): value = super().clean(value) diff --git a/django_mongodb_backend/validators.py b/django_mongodb_backend/validators.py index 6005152e..5ca6cbe2 100644 --- a/django_mongodb_backend/validators.py +++ b/django_mongodb_backend/validators.py @@ -1,4 +1,5 @@ -from django.core.validators import MaxLengthValidator, MinLengthValidator +from django.core.validators import BaseValidator, MaxLengthValidator, MinLengthValidator +from django.utils.deconstruct import deconstructible from django.utils.translation import ngettext_lazy @@ -16,3 +17,19 @@ class ArrayMinLengthValidator(MinLengthValidator): "List contains %(show_value)d items, it should contain no fewer than %(limit_value)d.", "show_value", ) + + +@deconstructible +class LengthValidator(BaseValidator): + message = ngettext_lazy( + "List contains %(show_value)d item, it should contain %(limit_value)d.", + "List contains %(show_value)d items, it should contain %(limit_value)d.", + "show_value", + ) + code = "length" + + def compare(self, a, b): + return a != b + + def clean(self, x): + return len(x) diff --git a/docs/source/ref/forms.rst b/docs/source/ref/forms.rst index 64c42755..934af20e 100644 --- a/docs/source/ref/forms.rst +++ b/docs/source/ref/forms.rst @@ -33,7 +33,7 @@ Stores an :class:`~bson.objectid.ObjectId`. ``SimpleArrayField`` -------------------- -.. class:: SimpleArrayField(base_field, delimiter=',', max_length=None, min_length=None) +.. class:: SimpleArrayField(base_field, delimiter=',', length=None, max_length=None, min_length=None) A field which maps to an array. It is represented by an HTML ````. @@ -91,6 +91,14 @@ Stores an :class:`~bson.objectid.ObjectId`. in cases where the delimiter is a valid character in the underlying field. The delimiter does not need to be only one character. + .. attribute:: length + + This is an optional argument which validates that the array contains + the stated number of items. + + ``length`` may not be specified along with ``max_length`` or + ``min_length``. + .. attribute:: max_length This is an optional argument which validates that the array does not diff --git a/docs/source/ref/models/fields.rst b/docs/source/ref/models/fields.rst index 83bd4848..47a3149c 100644 --- a/docs/source/ref/models/fields.rst +++ b/docs/source/ref/models/fields.rst @@ -8,13 +8,12 @@ Some MongoDB-specific fields are available in ``django_mongodb_backend.fields``. ``ArrayField`` -------------- -.. class:: ArrayField(base_field, size=None, **options) +.. class:: ArrayField(base_field, max_size=None, size=None, **options) A field for storing lists of data. Most field types can be used, and you - pass another field instance as the :attr:`base_field - `. You may also specify a :attr:`size - `. ``ArrayField`` can be nested to store multi-dimensional - arrays. + pass another field instance as the :attr:`~ArrayField.base_field`. You may + also specify a :attr:`~ArrayField.size` or :attr:`~ArrayField.max_size`. + ``ArrayField`` can be nested to store multi-dimensional arrays. If you give the field a :attr:`~django.db.models.Field.default`, ensure it's a callable such as ``list`` (for an empty default) or a callable that @@ -59,12 +58,21 @@ Some MongoDB-specific fields are available in ``django_mongodb_backend.fields``. of data and configuration, and serialization are all delegated to the underlying base field. - .. attribute:: size + .. attribute:: max_size This is an optional argument. If passed, the array will have a maximum size as specified, validated - only by forms. + by forms and model validation, but not enforced by the database. + + The ``max_size`` and ``size`` options are mutually exclusive. + + .. attribute:: size + + This is an optional argument. + + If passed, the array will have size as specified, validated by forms + and model validation, but not enforced by the database. Querying ``ArrayField`` ~~~~~~~~~~~~~~~~~~~~~~~ @@ -168,8 +176,8 @@ Index transforms ^^^^^^^^^^^^^^^^ Index transforms index into the array. Any non-negative integer can be used. -There are no errors if it exceeds the :attr:`size ` of the -array. The lookups available after the transform are those from the +There are no errors if it exceeds the :attr:`max_size ` of +the array. The lookups available after the transform are those from the :attr:`base_field `. For example: .. code-block:: pycon diff --git a/docs/source/releases/5.1.x.rst b/docs/source/releases/5.1.x.rst index 02c2019d..2e36e6d8 100644 --- a/docs/source/releases/5.1.x.rst +++ b/docs/source/releases/5.1.x.rst @@ -7,6 +7,10 @@ Django MongoDB Backend 5.1.x *Unreleased* +- Backward-incompatible: :class:`~django_mongodb_backend.fields.ArrayField`\'s + :attr:`~.ArrayField.size` parameter is renamed to + :attr:`~.ArrayField.max_size`. The :attr:`~.ArrayField.size` parameter is now + used to enforce fixed-length arrays. - Added support for :doc:`database caching `. - Fixed ``QuerySet.raw_aggregate()`` field initialization when the document key order doesn't match the order of the model's fields. diff --git a/tests/forms_tests_/test_array.py b/tests/forms_tests_/test_array.py index 58ab5566..360a68c4 100644 --- a/tests/forms_tests_/test_array.py +++ b/tests/forms_tests_/test_array.py @@ -21,21 +21,15 @@ def test_valid(self): def test_to_python_fail(self): field = SimpleArrayField(forms.IntegerField()) - with self.assertRaises(exceptions.ValidationError) as cm: + msg = "Item 1 in the array did not validate: Enter a whole number." + with self.assertRaisesMessage(exceptions.ValidationError, msg): field.clean("a,b,9") - self.assertEqual( - cm.exception.messages[0], - "Item 1 in the array did not validate: Enter a whole number.", - ) def test_validate_fail(self): field = SimpleArrayField(forms.CharField(required=True)) - with self.assertRaises(exceptions.ValidationError) as cm: + msg = "Item 3 in the array did not validate: This field is required." + with self.assertRaisesMessage(exceptions.ValidationError, msg): field.clean("a,b,") - self.assertEqual( - cm.exception.messages[0], - "Item 3 in the array did not validate: This field is required.", - ) def test_validate_fail_base_field_error_params(self): field = SimpleArrayField(forms.CharField(max_length=2)) @@ -68,12 +62,9 @@ def test_validate_fail_base_field_error_params(self): def test_validators_fail(self): field = SimpleArrayField(forms.RegexField("[a-e]{2}")) - with self.assertRaises(exceptions.ValidationError) as cm: + msg = "Item 1 in the array did not validate: Enter a valid value." + with self.assertRaisesMessage(exceptions.ValidationError, msg): field.clean("a,bc,de") - self.assertEqual( - cm.exception.messages[0], - "Item 1 in the array did not validate: Enter a valid value.", - ) def test_delimiter(self): field = SimpleArrayField(forms.CharField(), delimiter="|") @@ -92,21 +83,15 @@ def test_prepare_value(self): def test_max_length(self): field = SimpleArrayField(forms.CharField(), max_length=2) - with self.assertRaises(exceptions.ValidationError) as cm: + msg = "List contains 3 items, it should contain no more than 2." + with self.assertRaisesMessage(exceptions.ValidationError, msg): field.clean("a,b,c") - self.assertEqual( - cm.exception.messages[0], - "List contains 3 items, it should contain no more than 2.", - ) def test_min_length(self): field = SimpleArrayField(forms.CharField(), min_length=4) - with self.assertRaises(exceptions.ValidationError) as cm: + msg = "List contains 3 items, it should contain no fewer than 4." + with self.assertRaisesMessage(exceptions.ValidationError, msg): field.clean("a,b,c") - self.assertEqual( - cm.exception.messages[0], - "List contains 3 items, it should contain no fewer than 4.", - ) def test_min_length_singular(self): field = SimpleArrayField(forms.IntegerField(), min_length=2) @@ -115,11 +100,34 @@ def test_min_length_singular(self): with self.assertRaisesMessage(exceptions.ValidationError, msg): field.clean([1]) + def test_size_length(self): + field = SimpleArrayField(forms.CharField(max_length=27), length=4) + msg = "List contains 3 items, it should contain 4." + with self.assertRaisesMessage(exceptions.ValidationError, msg): + field.clean(["a", "b", "c"]) + msg = "List contains 5 items, it should contain 4." + with self.assertRaisesMessage(exceptions.ValidationError, msg): + field.clean(["a", "b", "c", "d", "e"]) + + def test_size_length_singular(self): + field = SimpleArrayField(forms.CharField(max_length=27), length=4) + msg = "List contains 1 item, it should contain 4." + with self.assertRaisesMessage(exceptions.ValidationError, msg): + field.clean(["a"]) + def test_required(self): field = SimpleArrayField(forms.CharField(), required=True) - with self.assertRaises(exceptions.ValidationError) as cm: + msg = "This field is required." + with self.assertRaisesMessage(exceptions.ValidationError, msg): field.clean("") - self.assertEqual(cm.exception.messages[0], "This field is required.") + + def test_length_and_max_min_length(self): + msg = "The length and max_length parameters are mutually exclusive." + with self.assertRaisesMessage(exceptions.ImproperlyConfigured, msg): + SimpleArrayField(forms.CharField(), max_length=3, length=2) + msg = "The length and min_length parameters are mutually exclusive." + with self.assertRaisesMessage(exceptions.ImproperlyConfigured, msg): + SimpleArrayField(forms.CharField(), min_length=3, length=2) def test_model_field_formfield(self): model_field = ArrayField(models.CharField(max_length=27)) @@ -128,11 +136,17 @@ def test_model_field_formfield(self): self.assertIsInstance(form_field.base_field, forms.CharField) self.assertEqual(form_field.base_field.max_length, 27) + def test_model_field_formfield_max_size(self): + model_field = ArrayField(models.CharField(max_length=27), max_size=4) + form_field = model_field.formfield() + self.assertIsInstance(form_field, SimpleArrayField) + self.assertEqual(form_field.max_length, 4) + def test_model_field_formfield_size(self): model_field = ArrayField(models.CharField(max_length=27), size=4) form_field = model_field.formfield() self.assertIsInstance(form_field, SimpleArrayField) - self.assertEqual(form_field.max_length, 4) + self.assertEqual(form_field.length, 4) def test_model_field_choices(self): model_field = ArrayField(models.IntegerField(choices=((1, "A"), (2, "B")))) diff --git a/tests/forms_tests_/test_objectidfield.py b/tests/forms_tests_/test_objectidfield.py index ee37401b..64a46193 100644 --- a/tests/forms_tests_/test_objectidfield.py +++ b/tests/forms_tests_/test_objectidfield.py @@ -23,9 +23,8 @@ def test_clean_empty_string(self): def test_clean_invalid(self): field = ObjectIdField() - with self.assertRaises(ValidationError) as cm: + with self.assertRaisesMessage(ValidationError, "Enter a valid Object Id."): field.clean("invalid") - self.assertEqual(cm.exception.messages[0], "Enter a valid Object Id.") def test_prepare_value(self): field = ObjectIdField() diff --git a/tests/model_fields_/array_default_migrations/0001_initial.py b/tests/model_fields_/array_default_migrations/0001_initial.py index 4faaed19..0c759c05 100644 --- a/tests/model_fields_/array_default_migrations/0001_initial.py +++ b/tests/model_fields_/array_default_migrations/0001_initial.py @@ -21,7 +21,7 @@ class Migration(migrations.Migration): ), ( "field", - django_mongodb_backend.fields.ArrayField(models.IntegerField(), size=None), + django_mongodb_backend.fields.ArrayField(models.IntegerField()), ), ], options={}, diff --git a/tests/model_fields_/array_default_migrations/0002_integerarraymodel_field_2.py b/tests/model_fields_/array_default_migrations/0002_integerarraymodel_field_2.py index 90f49499..18bfbf99 100644 --- a/tests/model_fields_/array_default_migrations/0002_integerarraymodel_field_2.py +++ b/tests/model_fields_/array_default_migrations/0002_integerarraymodel_field_2.py @@ -12,9 +12,7 @@ class Migration(migrations.Migration): migrations.AddField( model_name="integerarraydefaultmodel", name="field_2", - field=django_mongodb_backend.fields.ArrayField( - models.IntegerField(), default=[], size=None - ), + field=django_mongodb_backend.fields.ArrayField(models.IntegerField(), default=[]), preserve_default=False, ), ] diff --git a/tests/model_fields_/array_index_migrations/0001_initial.py b/tests/model_fields_/array_index_migrations/0001_initial.py index a32b0529..4113bdd4 100644 --- a/tests/model_fields_/array_index_migrations/0001_initial.py +++ b/tests/model_fields_/array_index_migrations/0001_initial.py @@ -22,7 +22,7 @@ class Migration(migrations.Migration): ( "char", django_mongodb_backend.fields.ArrayField( - models.CharField(max_length=10), db_index=True, size=100 + models.CharField(max_length=10), db_index=True, max_size=100 ), ), ("char2", models.CharField(max_length=11, db_index=True)), diff --git a/tests/model_fields_/test_arrayfield.py b/tests/model_fields_/test_arrayfield.py index 3b3c4c0c..2b457201 100644 --- a/tests/model_fields_/test_arrayfield.py +++ b/tests/model_fields_/test_arrayfield.py @@ -85,15 +85,16 @@ class MyModel(models.Model): def test_deconstruct(self): field = ArrayField(models.IntegerField()) name, path, args, kwargs = field.deconstruct() + self.assertEqual(kwargs.keys(), {"base_field"}) new = ArrayField(*args, **kwargs) self.assertEqual(type(new.base_field), type(field.base_field)) self.assertIsNot(new.base_field, field.base_field) - def test_deconstruct_with_size(self): - field = ArrayField(models.IntegerField(), size=3) + def test_deconstruct_with_max_size(self): + field = ArrayField(models.IntegerField(), max_size=3) name, path, args, kwargs = field.deconstruct() new = ArrayField(*args, **kwargs) - self.assertEqual(new.size, field.size) + self.assertEqual(new.max_size, field.max_size) def test_deconstruct_args(self): field = ArrayField(models.CharField(max_length=20)) @@ -645,6 +646,15 @@ class MyModel(models.Model): self.assertEqual(len(errors), 1) self.assertEqual(errors[0].id, "django_mongodb_backend.array.E002") + def test_both_size_and_max_size(self): + class MyModel(models.Model): + field = ArrayField(models.CharField(max_length=3), size=3, max_size=4) + + model = MyModel() + errors = model.check() + self.assertEqual(len(errors), 1) + self.assertEqual(errors[0].id, "django_mongodb_backend.array.E003") + def test_invalid_default(self): class MyModel(models.Model): field = ArrayField(models.IntegerField(), default=[]) @@ -722,7 +732,7 @@ class MigrationsTests(TransactionTestCase): ) def test_adding_field_with_default(self): class IntegerArrayDefaultModel(models.Model): - field = ArrayField(models.IntegerField(), size=None) + field = ArrayField(models.IntegerField()) table_name = "model_fields__integerarraydefaultmodel" self.assertNotIn(table_name, connection.introspection.table_names(None)) @@ -734,8 +744,8 @@ class IntegerArrayDefaultModel(models.Model): call_command("migrate", "model_fields_", "0002", verbosity=0) class UpdatedIntegerArrayDefaultModel(models.Model): - field = ArrayField(models.IntegerField(), size=None) - field_2 = ArrayField(models.IntegerField(), default=[], size=None) + field = ArrayField(models.IntegerField()) + field_2 = ArrayField(models.IntegerField(), default=[]) class Meta: db_table = "model_fields__integerarraydefaultmodel" @@ -787,43 +797,51 @@ def test_loading(self): class ValidationTests(SimpleTestCase): def test_unbounded(self): field = ArrayField(models.IntegerField()) - with self.assertRaises(exceptions.ValidationError) as cm: + msg = "Item 2 in the array did not validate: This field cannot be null." + with self.assertRaisesMessage(exceptions.ValidationError, msg) as cm: field.clean([1, None], None) self.assertEqual(cm.exception.code, "item_invalid") - self.assertEqual( - cm.exception.message % cm.exception.params, - "Item 2 in the array did not validate: This field cannot be null.", - ) def test_blank_true(self): field = ArrayField(models.IntegerField(blank=True, null=True)) # This should not raise a validation error field.clean([1, None], None) - def test_with_size(self): - field = ArrayField(models.IntegerField(), size=3) + def test_with_max_size(self): + field = ArrayField(models.IntegerField(), max_size=3) field.clean([1, 2, 3], None) - with self.assertRaises(exceptions.ValidationError) as cm: + msg = "List contains 4 items, it should contain no more than 3." + with self.assertRaisesMessage(exceptions.ValidationError, msg): field.clean([1, 2, 3, 4], None) - self.assertEqual( - cm.exception.messages[0], - "List contains 4 items, it should contain no more than 3.", - ) - def test_with_size_singular(self): - field = ArrayField(models.IntegerField(), size=1) + def test_with_max_size_singular(self): + field = ArrayField(models.IntegerField(), max_size=1) field.clean([1], None) msg = "List contains 2 items, it should contain no more than 1." with self.assertRaisesMessage(exceptions.ValidationError, msg): field.clean([1, 2], None) + def test_with_size(self): + field = ArrayField(models.IntegerField(), size=3) + field.clean([1, 2, 3], None) + msg = "List contains 4 items, it should contain 3." + with self.assertRaisesMessage(exceptions.ValidationError, msg): + field.clean([1, 2, 3, 4], None) + + def test_with_size_singular(self): + field = ArrayField(models.IntegerField(), size=2) + field.clean([1, 2], None) + msg = "List contains 1 item, it should contain 2." + with self.assertRaisesMessage(exceptions.ValidationError, msg): + field.clean([1], None) + def test_nested_array_mismatch(self): field = ArrayField(ArrayField(models.IntegerField())) field.clean([[1, 2], [3, 4]], None) - with self.assertRaises(exceptions.ValidationError) as cm: + msg = "Nested arrays must have the same length." + with self.assertRaisesMessage(exceptions.ValidationError, msg) as cm: field.clean([[1, 2], [3, 4, 5]], None) self.assertEqual(cm.exception.code, "nested_array_mismatch") - self.assertEqual(cm.exception.messages[0], "Nested arrays must have the same length.") def test_with_base_field_error_params(self): field = ArrayField(models.CharField(max_length=2)) diff --git a/tests/model_fields_/test_objectidfield.py b/tests/model_fields_/test_objectidfield.py index 13356522..b4b00f01 100644 --- a/tests/model_fields_/test_objectidfield.py +++ b/tests/model_fields_/test_objectidfield.py @@ -117,12 +117,10 @@ def test_loading(self): class ValidationTests(TestCase): def test_invalid_objectid(self): field = ObjectIdField() - with self.assertRaises(ValidationError) as cm: + msg = "“550e8400” is not a valid Object Id." + with self.assertRaisesMessage(ValidationError, msg) as cm: field.clean("550e8400", None) self.assertEqual(cm.exception.code, "invalid") - self.assertEqual( - cm.exception.message % cm.exception.params, "“550e8400” is not a valid Object Id." - ) def test_objectid_instance_ok(self): value = ObjectId() diff --git a/tests/validators_/__init__.py b/tests/validators_/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/validators_/tests.py b/tests/validators_/tests.py new file mode 100644 index 00000000..09a54949 --- /dev/null +++ b/tests/validators_/tests.py @@ -0,0 +1,31 @@ +from django.core.exceptions import ValidationError +from django.test import SimpleTestCase + +from django_mongodb_backend.validators import LengthValidator + + +class TestLengthValidator(SimpleTestCase): + validator = LengthValidator(10) + + def test_empty(self): + msg = "List contains 0 items, it should contain 10." + with self.assertRaisesMessage(ValidationError, msg): + self.validator([]) + + def test_singular(self): + msg = "List contains 1 item, it should contain 10." + with self.assertRaisesMessage(ValidationError, msg): + self.validator([1]) + + def test_too_short(self): + msg = "List contains 9 items, it should contain 10." + with self.assertRaisesMessage(ValidationError, msg): + self.validator([1, 2, 3, 4, 5, 6, 7, 8, 9]) + + def test_too_long(self): + msg = "List contains 11 items, it should contain 10." + with self.assertRaisesMessage(ValidationError, msg): + self.validator(list(range(11))) + + def test_valid(self): + self.assertEqual(self.validator(list(range(10))), None)