diff --git a/README.md b/README.md index 4084941..57412f2 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,9 @@ If either `load_by` or `dump_by` are unset, they will follow from `by_value`. Additionally, there is `EnumField.NAME` to be explicit about the load and dump behavior, this is the same as leaving both `by_value` and either `load_by` and/or `dump_by` unset. +If you want to ensure that the `load_by` and `dump_by` behaviour is always the same you can use +the `StrictEnumField`. + ### Custom Error Message A custom error message can be provided via the `error` keyword argument. It can accept three diff --git a/marshmallow_enum/__init__.py b/marshmallow_enum/__init__.py index 66b1b8e..6d79d64 100644 --- a/marshmallow_enum/__init__.py +++ b/marshmallow_enum/__init__.py @@ -115,3 +115,18 @@ def fail(self, key, **kwargs): raise ValidationError(msg) else: raise super(EnumField, self).make_error(key, **kwargs) + + +class StrictEnumField(EnumField): + """ + Like EnumField but will always load and dump using the same behaviour + Ignores any `load_by` or `dump_by` parameters passed to it + """ + + def __init__( + self, enum, by_value=False, error='', *args, **kwargs + ): + + kwargs.pop('load_by', None) + kwargs.pop('dump_by', None) + super(StrictEnumField, self).__init__(enum, by_value, *args, **kwargs) diff --git a/tests/test_enum_field.py b/tests/test_enum_field.py index f8f690c..62d8575 100644 --- a/tests/test_enum_field.py +++ b/tests/test_enum_field.py @@ -10,7 +10,7 @@ import marshmallow from marshmallow import Schema, ValidationError from marshmallow.fields import List -from marshmallow_enum import EnumField +from marshmallow_enum import EnumField, StrictEnumField PY2 = sys.version_info.major == 2 MARSHMALLOW_VERSION_MAJOR = int(marshmallow.__version__.split('.')[0]) @@ -359,3 +359,73 @@ class MyEnumField(EnumField): EnumField(self.UnicodeEnumTester, error='{values}').fail('by_value') assert exc_info.value.messages[0] == self.values + + +class TestStrictEnumFieldByName(object): + + def setup(self): + self.field = StrictEnumField(EnumTester) + + def test_serialize_enum(self): + assert self.field._serialize(EnumTester.one, None, object()) == 'one' + + def test_serialize_none(self): + assert self.field._serialize(None, None, object()) is None + + def test_deserialize_enum(self): + assert self.field._deserialize('one', None, {}) == EnumTester.one + + def test_deserialize_none(self): + assert self.field._deserialize(None, None, {}) is None + + def test_deserialize_nonexistent_member(self): + with pytest.raises(ValidationError): + self.field._deserialize('fred', None, {}) + + +class TestStrictEnumFieldLoadAndDumpByValueIgnored(object): + + def setup(self): + self.field = StrictEnumField( + EnumTester, + load_by=EnumField.VALUE, + dump_by=EnumField.VALUE + ) + + def test_serialize_enum(self): + assert self.field._serialize(EnumTester.one, None, object()) == 'one' + + def test_serialize_none(self): + assert self.field._serialize(None, None, object()) is None + + def test_deserialize_enum(self): + assert self.field._deserialize('one', None, {}) == EnumTester.one + + def test_deserialize_none(self): + assert self.field._deserialize(None, None, {}) is None + + def test_deserialize_nonexistent_member(self): + with pytest.raises(ValidationError): + self.field._deserialize('fred', None, {}) + + +class TestStrictEnumFieldValue(object): + + def test_deserialize_enum(self): + field = StrictEnumField(EnumTester, by_value=True) + + assert field._deserialize(1, None, {}) == EnumTester.one + + def test_serialize_enum(self): + field = EnumField(EnumTester, by_value=True) + assert field._serialize(EnumTester.one, None, object()) == 1 + + def test_serialize_none(self): + field = EnumField(EnumTester, by_value=True) + assert field._serialize(None, None, object()) is None + + def test_deserialize_nonexistent_member(self): + field = EnumField(EnumTester, by_value=True) + + with pytest.raises(ValidationError): + field._deserialize(4, None, {})