diff --git a/src/serializers/config.rs b/src/serializers/config.rs index 13a833176..ac9d364d1 100644 --- a/src/serializers/config.rs +++ b/src/serializers/config.rs @@ -14,7 +14,7 @@ use crate::tools::SchemaDict; use super::errors::py_err_se_err; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] #[allow(clippy::struct_field_names)] pub(crate) struct SerializationConfig { pub timedelta_mode: TimedeltaMode, @@ -57,6 +57,15 @@ macro_rules! serialization_mode { $($variant,)* } + impl std::fmt::Display for $name { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + $(Self::$variant => write!(f, $value),)* + } + + } + } + impl FromStr for $name { type Err = PyErr; diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs index 1a0405e2c..b209414c1 100644 --- a/src/serializers/mod.rs +++ b/src/serializers/mod.rs @@ -9,7 +9,7 @@ use crate::definitions::{Definitions, DefinitionsBuilder}; use crate::py_gc::PyGcTraverse; pub(crate) use config::BytesMode; -use config::SerializationConfig; +pub(crate) use config::SerializationConfig; pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue}; use extra::{CollectWarnings, SerRecursionState, WarningsMode}; pub(crate) use extra::{DuckTypingSerMode, Extra, SerMode, SerializationState}; diff --git a/src/validators/enum_.rs b/src/validators/enum_.rs index 73589d806..21efc3b6b 100644 --- a/src/validators/enum_.rs +++ b/src/validators/enum_.rs @@ -9,10 +9,12 @@ use pyo3::types::{PyDict, PyFloat, PyInt, PyList, PyString, PyType}; use crate::build_tools::{is_strict, py_schema_err}; use crate::errors::{ErrorType, ValError, ValResult}; use crate::input::Input; +use crate::serializers::{to_jsonable_python, SerializationConfig}; use crate::tools::{safe_repr, SchemaDict}; use super::is_instance::class_repr; use super::literal::{expected_repr_name, LiteralLookup}; +use super::InputType; use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, Exactness, ValidationState, Validator}; #[derive(Debug, Clone)] @@ -33,12 +35,38 @@ impl BuildValidator for BuildEnumValidator { let py = schema.py(); let value_str = intern!(py, "value"); - let expected: Vec<(Bound<'_, PyAny>, PyObject)> = members + let expected_py: Vec<(Bound<'_, PyAny>, PyObject)> = members .iter() .map(|v| Ok((v.getattr(value_str)?, v.into()))) .collect::>()?; + let ser_config = SerializationConfig::from_config(config).unwrap_or_default(); + let expected_json: Vec<(Bound<'_, PyAny>, PyObject)> = members + .iter() + .map(|v| { + Ok(( + to_jsonable_python( + py, + &v.getattr(value_str)?, + None, + None, + false, + false, + false, + &ser_config.timedelta_mode.to_string(), + &ser_config.bytes_mode.to_string(), + &ser_config.inf_nan_mode.to_string(), + false, + None, + true, + None, + )? + .into_bound(py), + v.into(), + )) + }) + .collect::>()?; - let repr_args: Vec = expected + let repr_args: Vec = expected_py .iter() .map(|(k, _)| k.repr()?.extract()) .collect::>()?; @@ -46,14 +74,16 @@ impl BuildValidator for BuildEnumValidator { let class: Bound = schema.get_as_req(intern!(py, "cls"))?; let class_repr = class_repr(schema, &class)?; - let lookup = LiteralLookup::new(py, expected.into_iter())?; + let py_lookup = LiteralLookup::new(py, expected_py.into_iter())?; + let json_lookup = LiteralLookup::new(py, expected_json.into_iter())?; macro_rules! build { ($vv:ty, $name_prefix:literal) => { EnumValidator { phantom: PhantomData::<$vv>, class: class.clone().into(), - lookup, + py_lookup, + json_lookup, missing: schema.get_as(intern!(py, "missing"))?, expected_repr: expected_repr_name(repr_args, "").0, strict: is_strict(schema, config)?, @@ -87,7 +117,8 @@ pub trait EnumValidateValue: std::fmt::Debug + Clone + Send + Sync { pub struct EnumValidator { phantom: PhantomData, class: Py, - lookup: LiteralLookup, + py_lookup: LiteralLookup, + json_lookup: LiteralLookup, missing: Option, expected_repr: String, strict: bool, @@ -120,7 +151,11 @@ impl Validator for EnumValidator { state.floor_exactness(Exactness::Lax); - if let Some(v) = T::validate_value(py, input, &self.lookup, strict)? { + let lookup = match state.extra().input_type { + InputType::Json => &self.json_lookup, + _ => &self.py_lookup, + }; + if let Some(v) = T::validate_value(py, input, lookup, strict)? { return Ok(v); } else if let Ok(res) = class.as_unbound().call1(py, (input.as_python(),)) { return Ok(res); diff --git a/tests/validators/test_enums.py b/tests/validators/test_enums.py index 83e286417..fc3a06403 100644 --- a/tests/validators/test_enums.py +++ b/tests/validators/test_enums.py @@ -1,3 +1,4 @@ +import datetime import re import sys from decimal import Decimal @@ -262,6 +263,57 @@ class MyEnum(Enum): assert v.validate_python([2]) is MyEnum.b +def test_plain_enum_tuple(): + class MyEnum(Enum): + a = 1, 2 + b = 2, 3 + + assert MyEnum((1, 2)) is MyEnum.a + v = SchemaValidator(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values()))) + assert v.validate_python((1, 2)) is MyEnum.a + assert v.validate_python((2, 3)) is MyEnum.b + assert v.validate_json('[1, 2]') is MyEnum.a + + +def test_plain_enum_datetime(): + class MyEnum(Enum): + a = datetime.datetime.fromisoformat('2024-01-01T00:00:00') + + v = SchemaValidator(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values()))) + assert v.validate_python(datetime.datetime.fromisoformat('2024-01-01T00:00:00')) is MyEnum.a + assert v.validate_json('"2024-01-01T00:00:00"') is MyEnum.a + + +def test_plain_enum_complex(): + class MyEnum(Enum): + a = complex(1, 2) + + v = SchemaValidator(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values()))) + assert v.validate_python(complex(1, 2)) is MyEnum.a + assert v.validate_json('"1+2j"') is MyEnum.a + + +def test_plain_enum_identical_serialized_form(): + class MyEnum(Enum): + tuple_ = 1, 2 + list_ = [1, 2] + + v = SchemaValidator(core_schema.enum_schema(MyEnum, list(MyEnum.__members__.values()))) + assert v.validate_python((1, 2)) is MyEnum.tuple_ + assert v.validate_python([1, 2]) is MyEnum.list_ + assert v.validate_json('[1,2]') is MyEnum.tuple_ + + # Change the order of `a` and `b` in MyEnum2; validate_json should pick [1, 2] this time + class MyEnum2(Enum): + list_ = [1, 2] + tuple_ = 1, 2 + + v = SchemaValidator(core_schema.enum_schema(MyEnum2, list(MyEnum2.__members__.values()))) + assert v.validate_python((1, 2)) is MyEnum2.tuple_ + assert v.validate_python([1, 2]) is MyEnum2.list_ + assert v.validate_json('[1,2]') is MyEnum2.list_ + + def test_plain_enum_empty(): class MyEnum(Enum): pass