diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index a45d82f9c..a23e264f1 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -96,6 +96,8 @@ class CoreConfig(TypedDict, total=False): validate_default: bool # used on typed-dicts and arguments populate_by_name: bool # replaces `allow_population_by_field_name` in pydantic v1 + # stop validation on a first error, used with typed-dict + fail_fast: bool # fields related to string fields only str_max_length: int str_min_length: int @@ -1886,6 +1888,7 @@ class DictSchema(TypedDict, total=False): values_schema: CoreSchema # default: AnySchema min_length: int max_length: int + fail_fast: bool strict: bool ref: str metadata: Dict[str, Any] @@ -1898,6 +1901,7 @@ def dict_schema( *, min_length: int | None = None, max_length: int | None = None, + fail_fast: bool | None = None, strict: bool | None = None, ref: str | None = None, metadata: Dict[str, Any] | None = None, @@ -1921,6 +1925,7 @@ def dict_schema( values_schema: The value must be a dict with values that match this schema min_length: The value must be a dict with at least this many items max_length: The value must be a dict with at most this many items + fail_fast: Stop validation on the first error strict: Whether the keys and values should be validated with strict mode ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core @@ -1932,6 +1937,7 @@ def dict_schema( values_schema=values_schema, min_length=min_length, max_length=max_length, + fail_fast=fail_fast, strict=strict, ref=ref, metadata=metadata, @@ -2893,6 +2899,7 @@ class TypedDictSchema(TypedDict, total=False): extra_behavior: ExtraBehavior total: bool # default: True populate_by_name: bool # replaces `allow_population_by_field_name` in pydantic v1 + fail_fast: bool # default: False ref: str metadata: Dict[str, Any] serialization: SerSchema @@ -2909,6 +2916,7 @@ def typed_dict_schema( extra_behavior: ExtraBehavior | None = None, total: bool | None = None, populate_by_name: bool | None = None, + fail_fast: bool | None = None, ref: str | None = None, metadata: Dict[str, Any] | None = None, serialization: SerSchema | None = None, @@ -2943,6 +2951,7 @@ class MyTypedDict(TypedDict): extra_behavior: The extra behavior to use for the typed dict total: Whether the typed dict is total, otherwise uses `typed_dict_total` from config populate_by_name: Whether the typed dict should populate by name + fail_fast: Stop validation on the first error serialization: Custom serialization schema """ return _dict_not_none( @@ -2955,6 +2964,7 @@ class MyTypedDict(TypedDict): extra_behavior=extra_behavior, total=total, populate_by_name=populate_by_name, + fail_fast=fail_fast, ref=ref, metadata=metadata, serialization=serialization, diff --git a/src/validators/dict.rs b/src/validators/dict.rs index b65b38fb1..1aff08ab7 100644 --- a/src/validators/dict.rs +++ b/src/validators/dict.rs @@ -21,6 +21,7 @@ pub struct DictValidator { value_validator: Box, min_length: Option, max_length: Option, + fail_fast: bool, name: String, } @@ -53,6 +54,7 @@ impl BuildValidator for DictValidator { value_validator, min_length: schema.get_as(intern!(py, "min_length"))?, max_length: schema.get_as(intern!(py, "max_length"))?, + fail_fast: schema.get_as(intern!(py, "fail_fast"))?.unwrap_or(false), name, } .into()) @@ -78,6 +80,7 @@ impl Validator for DictValidator { input, min_length: self.min_length, max_length: self.max_length, + fail_fast: self.fail_fast, key_validator: &self.key_validator, value_validator: &self.value_validator, state, @@ -94,6 +97,7 @@ struct ValidateToDict<'a, 's, 'py, I: Input<'py> + ?Sized> { input: &'a I, min_length: Option, max_length: Option, + fail_fast: bool, key_validator: &'a CombinedValidator, value_validator: &'a CombinedValidator, state: &'a mut ValidationState<'s, 'py>, @@ -111,6 +115,12 @@ where let mut errors: Vec = Vec::new(); let allow_partial = self.state.allow_partial; + macro_rules! should_fail_fast { + () => { + self.fail_fast && !errors.is_empty() + }; + } + for (_, is_last_partial, item_result) in self.state.enumerate_last_partial(iterator) { self.state.allow_partial = false.into(); let (key, value) = item_result?; @@ -130,6 +140,11 @@ where true => allow_partial, false => false.into(), }; + + if should_fail_fast!() { + break; + } + let output_value = match self.value_validator.validate(self.py, value.borrow_input(), self.state) { Ok(value) => value, Err(ValError::LineErrors(line_errors)) => { @@ -141,6 +156,11 @@ where Err(ValError::Omit) => continue, Err(err) => return Err(err), }; + + if should_fail_fast!() { + break; + } + if let Some(key) = output_key { output.set_item(key, output_value)?; } diff --git a/src/validators/typed_dict.rs b/src/validators/typed_dict.rs index 9c6523189..900ba285b 100644 --- a/src/validators/typed_dict.rs +++ b/src/validators/typed_dict.rs @@ -34,6 +34,7 @@ pub struct TypedDictValidator { extra_behavior: ExtraBehavior, extras_validator: Option>, strict: bool, + fail_fast: bool, loc_by_alias: bool, } @@ -56,6 +57,7 @@ impl BuildValidator for TypedDictValidator { let total = schema_or_config(schema, config, intern!(py, "total"), intern!(py, "typed_dict_total"))?.unwrap_or(true); let populate_by_name = schema_or_config_same(schema, config, intern!(py, "populate_by_name"))?.unwrap_or(false); + let fail_fast = schema_or_config_same(schema, config, intern!(py, "fail_fast"))?.unwrap_or(false); let extra_behavior = ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Ignore)?; @@ -129,6 +131,7 @@ impl BuildValidator for TypedDictValidator { extra_behavior, extras_validator, strict, + fail_fast, loc_by_alias: config.get_as(intern!(py, "loc_by_alias"))?.unwrap_or(true), } .into()) @@ -174,6 +177,10 @@ impl Validator for TypedDictValidator { let mut fields_set_count: usize = 0; for field in &self.fields { + if self.fail_fast && !errors.is_empty() { + break; + } + let op_key_value = match dict.get_item(&field.lookup_key) { Ok(v) => v, Err(ValError::LineErrors(line_errors)) => { @@ -265,6 +272,7 @@ impl Validator for TypedDictValidator { extra_behavior: ExtraBehavior, partial_last_key: Option, allow_partial: PartialMode, + fail_fast: bool, } impl<'py, Key, Value> ConsumeIterator> for ValidateExtras<'_, '_, 'py> @@ -275,6 +283,10 @@ impl Validator for TypedDictValidator { type Output = ValResult<()>; fn consume_iterator(self, iterator: impl Iterator>) -> ValResult<()> { for item_result in iterator { + if self.fail_fast && !self.errors.is_empty() { + break; + } + let (raw_key, value) = item_result?; let either_str = match raw_key .borrow_input() @@ -354,6 +366,7 @@ impl Validator for TypedDictValidator { extra_behavior: self.extra_behavior, partial_last_key, allow_partial, + fail_fast: self.fail_fast, })??; } diff --git a/tests/validators/test_dict.py b/tests/validators/test_dict.py index 4057ce76e..7443a0c42 100644 --- a/tests/validators/test_dict.py +++ b/tests/validators/test_dict.py @@ -258,3 +258,61 @@ def test_json_dict_complex_key(): assert v.validate_json('{"1+2j": 2, "infj": 4}') == {complex(1, 2): 2, complex(0, float('inf')): 4} with pytest.raises(ValidationError, match='Input should be a valid complex string'): v.validate_json('{"1+2j": 2, "": 4}') == {complex(1, 2): 2, complex(0, float('inf')): 4} + + +@pytest.mark.parametrize( + ('fail_fast', 'expected'), + [ + pytest.param( + True, + [ + { + 'type': 'int_parsing', + 'loc': ('a', '[key]'), + 'msg': 'Input should be a valid integer, unable to parse string as an integer', + 'input': 'a', + }, + ], + id='fail_fast', + ), + pytest.param( + False, + [ + { + 'type': 'int_parsing', + 'loc': ('a', '[key]'), + 'msg': 'Input should be a valid integer, unable to parse string as an integer', + 'input': 'a', + }, + { + 'type': 'int_parsing', + 'loc': ('a',), + 'msg': 'Input should be a valid integer, unable to parse string as an integer', + 'input': 'b', + }, + { + 'type': 'int_parsing', + 'loc': ('c', '[key]'), + 'msg': 'Input should be a valid integer, unable to parse string as an integer', + 'input': 'c', + }, + { + 'type': 'int_parsing', + 'loc': ('c',), + 'msg': 'Input should be a valid integer, unable to parse string as an integer', + 'input': 'd', + }, + ], + id='not_fail_fast', + ), + ], +) +def test_dict_fail_fast(fail_fast, expected): + v = SchemaValidator( + {'type': 'dict', 'keys_schema': {'type': 'int'}, 'values_schema': {'type': 'int'}, 'fail_fast': fail_fast} + ) + + with pytest.raises(ValidationError) as exc_info: + v.validate_python({'a': 'b', 'c': 'd'}) + + assert exc_info.value.errors(include_url=False) == expected diff --git a/tests/validators/test_typed_dict.py b/tests/validators/test_typed_dict.py index dc18cd86e..630f4c45b 100644 --- a/tests/validators/test_typed_dict.py +++ b/tests/validators/test_typed_dict.py @@ -1196,3 +1196,62 @@ def validate(v, info): gc.collect() assert ref() is None + + +@pytest.mark.parametrize( + ('fail_fast', 'expected'), + [ + pytest.param( + True, + [ + { + 'input': 'c', + 'loc': ('a',), + 'msg': 'Input should be a valid integer, unable to parse string as an integer', + 'type': 'int_parsing', + }, + ], + id='fail_fast', + ), + pytest.param( + False, + [ + { + 'input': 'c', + 'loc': ('a',), + 'msg': 'Input should be a valid integer, unable to parse string as an integer', + 'type': 'int_parsing', + }, + { + 'input': 'd', + 'loc': ('b',), + 'msg': 'Input should be a valid integer, unable to parse string as an integer', + 'type': 'int_parsing', + }, + ], + id='not_fail_fast', + ), + ], +) +def test_typed_dict_fail_fast(fail_fast, expected): + v = SchemaValidator( + { + 'type': 'typed-dict', + 'fields': { + 'a': { + 'type': 'typed-dict-field', + 'schema': {'type': 'int'}, + }, + 'b': { + 'type': 'typed-dict-field', + 'schema': {'type': 'int'}, + }, + }, + 'fail_fast': fail_fast, + } + ) + + with pytest.raises(ValidationError) as exc_info: + v.validate_python({'a': 'c', 'b': 'd'}) + + assert exc_info.value.errors(include_url=False) == expected