From 29e9609a0455a49e656311e44d4825828e948a7e Mon Sep 17 00:00:00 2001 From: Jarek-Rolski Date: Sun, 2 Feb 2025 23:38:24 +0000 Subject: [PATCH] fix DataFrame Pydantic compatibility --- pandera/typing/pandas.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/pandera/typing/pandas.py b/pandera/typing/pandas.py index 8e26962e..1328602e 100644 --- a/pandera/typing/pandas.py +++ b/pandera/typing/pandas.py @@ -29,6 +29,7 @@ SeriesBase, ) from pandera.typing.formats import Formats +from pandera.config import config_context try: from typing import get_args @@ -191,12 +192,28 @@ def _get_schema_model(cls, field): def __get_pydantic_core_schema__( cls, _source_type: Any, _handler: GetCoreSchemaHandler ) -> core_schema.CoreSchema: - schema_model = get_args(_source_type)[0] + with config_context(validation_enabled=False): + schema_model = _source_type().__orig_class__.__args__[0] + schema = schema_model.to_schema() + type_map = { + "str": core_schema.str_schema(), + "int64": core_schema.int_schema(), + "float64": core_schema.float_schema(), + "bool": core_schema.bool_schema(), + "datetime64[ns]": core_schema.datetime_schema() + } return core_schema.no_info_plain_validator_function( - functools.partial( + functools.partial( cls.pydantic_validate, schema_model=schema_model, ), + json_schema_input_schema=core_schema.list_schema( + core_schema.typed_dict_schema( + { + i:core_schema.typed_dict_field(type_map[str(j.dtype)]) for i,j in schema.columns.items() + }, + ) + ) ) else: