From f62872b2e1ae6ea37665e0f24f4f9d7e978c5569 Mon Sep 17 00:00:00 2001 From: Joshua Skrzypek Date: Sat, 21 Sep 2024 08:57:13 -0400 Subject: [PATCH] Upgrade multimethod to 1.12 (#1803) * deps: Bump multimethod requirement to 1.12 Signed-off-by: Joshua Skrzypek * Swap out @overload for @multidispatch. This removes the deprecated @overload decorator and replaces it with the @multidispatch decorator, using multimethod.parametric. Signed-off-by: Joshua Skrzypek * lint: Fix lint errors Signed-off-by: Joshua Skrzypek * refactor: Pull out parametric requirement Instead of using multimethod.parametric, we can just use typing Unions, because the predicates that were passed to parametric are just calling isinstance on the type tuples. Signed-off-by: Joshua Skrzypek * Unpin multimethod after taking out parametric dependency. Signed-off-by: Joshua Skrzypek * Appease mypy lint checker There would be cleaner ways to do this, but not in python 3.8 (i.e. no explicit types.UnionType, etc.). Signed-off-by: Joshua Skrzypek --------- Signed-off-by: Joshua Skrzypek --- ...nts-py3.10-pandas1.5.3-pydantic1.10.11.txt | 2 +- ...ments-py3.10-pandas1.5.3-pydantic2.3.0.txt | 2 +- ...nts-py3.10-pandas2.2.2-pydantic1.10.11.txt | 2 +- ...ments-py3.10-pandas2.2.2-pydantic2.3.0.txt | 2 +- ...nts-py3.11-pandas1.5.3-pydantic1.10.11.txt | 2 +- ...ments-py3.11-pandas1.5.3-pydantic2.3.0.txt | 2 +- ...nts-py3.11-pandas2.2.2-pydantic1.10.11.txt | 2 +- ...ments-py3.11-pandas2.2.2-pydantic2.3.0.txt | 2 +- ...ents-py3.9-pandas1.5.3-pydantic1.10.11.txt | 2 +- ...ements-py3.9-pandas1.5.3-pydantic2.3.0.txt | 2 +- ...ents-py3.9-pandas2.2.2-pydantic1.10.11.txt | 2 +- ...ements-py3.9-pandas2.2.2-pydantic2.3.0.txt | 2 +- dev/requirements-3.10.txt | 2 +- dev/requirements-3.11.txt | 2 +- dev/requirements-3.9.txt | 2 +- environment.yml | 4 +- pandera/api/pandas/types.py | 60 +++++++--- pandera/api/polars/types.py | 8 ++ pandera/backends/pandas/checks.py | 103 +++++++++--------- pandera/backends/pandas/hypotheses.py | 20 ++-- pandera/backends/pandas/parsers.py | 28 ++--- pandera/backends/polars/checks.py | 16 +-- reqs-test.txt | 2 +- requirements.in | 2 +- setup.py | 2 +- 25 files changed, 157 insertions(+), 118 deletions(-) diff --git a/ci/requirements-py3.10-pandas1.5.3-pydantic1.10.11.txt b/ci/requirements-py3.10-pandas1.5.3-pydantic1.10.11.txt index 7dc9a5cf8..2aef3ab8b 100644 --- a/ci/requirements-py3.10-pandas1.5.3-pydantic1.10.11.txt +++ b/ci/requirements-py3.10-pandas1.5.3-pydantic1.10.11.txt @@ -80,7 +80,7 @@ mdurl==0.1.2 modin==0.22.3 more-itertools==10.4.0 msgpack==1.0.8 -multimethod==1.10 +multimethod==1.12 mypy==1.10.0 mypy-extensions==1.0.0 myst-nb==1.1.1 diff --git a/ci/requirements-py3.10-pandas1.5.3-pydantic2.3.0.txt b/ci/requirements-py3.10-pandas1.5.3-pydantic2.3.0.txt index 3889e7e98..8a4ea4f2f 100644 --- a/ci/requirements-py3.10-pandas1.5.3-pydantic2.3.0.txt +++ b/ci/requirements-py3.10-pandas1.5.3-pydantic2.3.0.txt @@ -81,7 +81,7 @@ mdurl==0.1.2 modin==0.22.3 more-itertools==10.4.0 msgpack==1.0.8 -multimethod==1.10 +multimethod==1.12 mypy==1.10.0 mypy-extensions==1.0.0 myst-nb==1.1.1 diff --git a/ci/requirements-py3.10-pandas2.2.2-pydantic1.10.11.txt b/ci/requirements-py3.10-pandas2.2.2-pydantic1.10.11.txt index 5084b71c4..52ca0e11b 100644 --- a/ci/requirements-py3.10-pandas2.2.2-pydantic1.10.11.txt +++ b/ci/requirements-py3.10-pandas2.2.2-pydantic1.10.11.txt @@ -81,7 +81,7 @@ mdurl==0.1.2 modin==0.31.0 more-itertools==10.4.0 msgpack==1.0.8 -multimethod==1.10 +multimethod==1.12 mypy==1.10.0 mypy-extensions==1.0.0 myst-nb==1.1.1 diff --git a/ci/requirements-py3.10-pandas2.2.2-pydantic2.3.0.txt b/ci/requirements-py3.10-pandas2.2.2-pydantic2.3.0.txt index b98e8a73b..1b9be9414 100644 --- a/ci/requirements-py3.10-pandas2.2.2-pydantic2.3.0.txt +++ b/ci/requirements-py3.10-pandas2.2.2-pydantic2.3.0.txt @@ -82,7 +82,7 @@ mdurl==0.1.2 modin==0.31.0 more-itertools==10.4.0 msgpack==1.0.8 -multimethod==1.10 +multimethod==1.12 mypy==1.10.0 mypy-extensions==1.0.0 myst-nb==1.1.1 diff --git a/ci/requirements-py3.11-pandas1.5.3-pydantic1.10.11.txt b/ci/requirements-py3.11-pandas1.5.3-pydantic1.10.11.txt index eea6e24cd..1d6d27a54 100644 --- a/ci/requirements-py3.11-pandas1.5.3-pydantic1.10.11.txt +++ b/ci/requirements-py3.11-pandas1.5.3-pydantic1.10.11.txt @@ -79,7 +79,7 @@ mdurl==0.1.2 modin==0.22.3 more-itertools==10.4.0 msgpack==1.0.8 -multimethod==1.10 +multimethod==1.12 mypy==1.10.0 mypy-extensions==1.0.0 myst-nb==1.1.1 diff --git a/ci/requirements-py3.11-pandas1.5.3-pydantic2.3.0.txt b/ci/requirements-py3.11-pandas1.5.3-pydantic2.3.0.txt index 7acd63154..9751608e9 100644 --- a/ci/requirements-py3.11-pandas1.5.3-pydantic2.3.0.txt +++ b/ci/requirements-py3.11-pandas1.5.3-pydantic2.3.0.txt @@ -80,7 +80,7 @@ mdurl==0.1.2 modin==0.22.3 more-itertools==10.4.0 msgpack==1.0.8 -multimethod==1.10 +multimethod==1.12 mypy==1.10.0 mypy-extensions==1.0.0 myst-nb==1.1.1 diff --git a/ci/requirements-py3.11-pandas2.2.2-pydantic1.10.11.txt b/ci/requirements-py3.11-pandas2.2.2-pydantic1.10.11.txt index 6e75fbaaf..4f5069fff 100644 --- a/ci/requirements-py3.11-pandas2.2.2-pydantic1.10.11.txt +++ b/ci/requirements-py3.11-pandas2.2.2-pydantic1.10.11.txt @@ -80,7 +80,7 @@ mdurl==0.1.2 modin==0.31.0 more-itertools==10.4.0 msgpack==1.0.8 -multimethod==1.10 +multimethod==1.12 mypy==1.10.0 mypy-extensions==1.0.0 myst-nb==1.1.1 diff --git a/ci/requirements-py3.11-pandas2.2.2-pydantic2.3.0.txt b/ci/requirements-py3.11-pandas2.2.2-pydantic2.3.0.txt index 2ec41e6d6..22f1cc434 100644 --- a/ci/requirements-py3.11-pandas2.2.2-pydantic2.3.0.txt +++ b/ci/requirements-py3.11-pandas2.2.2-pydantic2.3.0.txt @@ -81,7 +81,7 @@ mdurl==0.1.2 modin==0.31.0 more-itertools==10.4.0 msgpack==1.0.8 -multimethod==1.10 +multimethod==1.12 mypy==1.10.0 mypy-extensions==1.0.0 myst-nb==1.1.1 diff --git a/ci/requirements-py3.9-pandas1.5.3-pydantic1.10.11.txt b/ci/requirements-py3.9-pandas1.5.3-pydantic1.10.11.txt index 14605ecde..d6e94b179 100644 --- a/ci/requirements-py3.9-pandas1.5.3-pydantic1.10.11.txt +++ b/ci/requirements-py3.9-pandas1.5.3-pydantic1.10.11.txt @@ -80,7 +80,7 @@ mdurl==0.1.2 modin==0.22.3 more-itertools==10.4.0 msgpack==1.0.8 -multimethod==1.10 +multimethod==1.12 mypy==1.10.0 mypy-extensions==1.0.0 myst-nb==1.1.1 diff --git a/ci/requirements-py3.9-pandas1.5.3-pydantic2.3.0.txt b/ci/requirements-py3.9-pandas1.5.3-pydantic2.3.0.txt index 94a84daad..3e64e5d34 100644 --- a/ci/requirements-py3.9-pandas1.5.3-pydantic2.3.0.txt +++ b/ci/requirements-py3.9-pandas1.5.3-pydantic2.3.0.txt @@ -81,7 +81,7 @@ mdurl==0.1.2 modin==0.22.3 more-itertools==10.4.0 msgpack==1.0.8 -multimethod==1.10 +multimethod==1.12 mypy==1.10.0 mypy-extensions==1.0.0 myst-nb==1.1.1 diff --git a/ci/requirements-py3.9-pandas2.2.2-pydantic1.10.11.txt b/ci/requirements-py3.9-pandas2.2.2-pydantic1.10.11.txt index ab8ce642c..ffaf43d9b 100644 --- a/ci/requirements-py3.9-pandas2.2.2-pydantic1.10.11.txt +++ b/ci/requirements-py3.9-pandas2.2.2-pydantic1.10.11.txt @@ -81,7 +81,7 @@ mdurl==0.1.2 modin==0.31.0 more-itertools==10.4.0 msgpack==1.0.8 -multimethod==1.10 +multimethod==1.12 mypy==1.10.0 mypy-extensions==1.0.0 myst-nb==1.1.1 diff --git a/ci/requirements-py3.9-pandas2.2.2-pydantic2.3.0.txt b/ci/requirements-py3.9-pandas2.2.2-pydantic2.3.0.txt index bd84392a7..4167b5015 100644 --- a/ci/requirements-py3.9-pandas2.2.2-pydantic2.3.0.txt +++ b/ci/requirements-py3.9-pandas2.2.2-pydantic2.3.0.txt @@ -82,7 +82,7 @@ mdurl==0.1.2 modin==0.31.0 more-itertools==10.4.0 msgpack==1.0.8 -multimethod==1.10 +multimethod==1.12 mypy==1.10.0 mypy-extensions==1.0.0 myst-nb==1.1.1 diff --git a/dev/requirements-3.10.txt b/dev/requirements-3.10.txt index 817543d9d..fd147196a 100644 --- a/dev/requirements-3.10.txt +++ b/dev/requirements-3.10.txt @@ -82,7 +82,7 @@ mdurl==0.1.2 modin==0.31.0 more-itertools==10.4.0 msgpack==1.0.8 -multimethod==1.10 +multimethod==1.12 mypy==1.10.0 mypy-extensions==1.0.0 myst-nb==1.1.1 diff --git a/dev/requirements-3.11.txt b/dev/requirements-3.11.txt index 1c9bbb43f..47a55b9ab 100644 --- a/dev/requirements-3.11.txt +++ b/dev/requirements-3.11.txt @@ -81,7 +81,7 @@ mdurl==0.1.2 modin==0.31.0 more-itertools==10.4.0 msgpack==1.0.8 -multimethod==1.10 +multimethod==1.12 mypy==1.10.0 mypy-extensions==1.0.0 myst-nb==1.1.1 diff --git a/dev/requirements-3.9.txt b/dev/requirements-3.9.txt index 08ace7e87..c3be160e5 100644 --- a/dev/requirements-3.9.txt +++ b/dev/requirements-3.9.txt @@ -82,7 +82,7 @@ mdurl==0.1.2 modin==0.31.0 more-itertools==10.4.0 msgpack==1.0.8 -multimethod==1.10 +multimethod==1.12 mypy==1.10.0 mypy-extensions==1.0.0 myst-nb==1.1.1 diff --git a/environment.yml b/environment.yml index 611dd9d00..1abd2e3b5 100644 --- a/environment.yml +++ b/environment.yml @@ -16,10 +16,10 @@ dependencies: - pyyaml >=5.1 - typing_inspect >= 0.6.0 - typing_extensions >= 3.7.4.3 - - frictionless <= 4.40.8 # v5.* introduces breaking changes + - frictionless <= 4.40.8 # v5.* introduces breaking changes - pyarrow - pydantic - - multimethod <= 1.10.0 + - multimethod # mypy extra - pandas-stubs diff --git a/pandera/api/pandas/types.py b/pandera/api/pandas/types.py index 2006359a9..58fb7a902 100644 --- a/pandera/api/pandas/types.py +++ b/pandera/api/pandas/types.py @@ -32,35 +32,35 @@ def supported_types() -> SupportedTypes: """Get the types supported by pandera schemas.""" # pylint: disable=import-outside-toplevel - table_types = [pd.DataFrame] - field_types = [pd.Series] - index_types = [pd.Index] - multiindex_types = [pd.MultiIndex] + table_types: Tuple[type, ...] = (pd.DataFrame,) + field_types: Tuple[type, ...] = (pd.Series,) + index_types: Tuple[type, ...] = (pd.Index,) + multiindex_types: Tuple[type, ...] = (pd.MultiIndex,) try: import pyspark.pandas as ps - table_types.append(ps.DataFrame) - field_types.append(ps.Series) - index_types.append(ps.Index) - multiindex_types.append(ps.MultiIndex) + table_types += (ps.DataFrame,) + field_types += (ps.Series,) + index_types += (ps.Index,) + multiindex_types += (ps.MultiIndex,) except ImportError: pass try: # pragma: no cover import modin.pandas as mpd - table_types.append(mpd.DataFrame) - field_types.append(mpd.Series) - index_types.append(mpd.Index) - multiindex_types.append(mpd.MultiIndex) + table_types += (mpd.DataFrame,) + field_types += (mpd.Series,) + index_types += (mpd.Index,) + multiindex_types += (mpd.MultiIndex,) except ImportError: pass try: import dask.dataframe as dd - table_types.append(dd.DataFrame) - field_types.append(dd.Series) - index_types.append(dd.Index) + table_types += (dd.DataFrame,) + field_types += (dd.Series,) + index_types += (dd.Index,) except ImportError: pass @@ -72,6 +72,36 @@ def supported_types() -> SupportedTypes: ) +def supported_type_unions(attribute: str): + """Get the type unions for a given attribute.""" + if attribute == "table_types": + return Union[tuple(supported_types().table_types)] + if attribute == "field_types": + return Union[tuple(supported_types().field_types)] + if attribute == "index_types": + return Union[tuple(supported_types().index_types)] + if attribute == "multiindex_types": + return Union[tuple(supported_types().multiindex_types)] + if attribute == "table_or_field_types": + return Union[ + tuple( + ( + *supported_types().table_types, + *supported_types().field_types, + ) + ) + ] + raise ValueError(f"invalid attribute {attribute}") + + +Table = supported_type_unions("table_types") +Field = supported_type_unions("field_types") +Index = supported_type_unions("index_types") +Multiindex = supported_type_unions("multiindex_types") +TableOrField = supported_type_unions("table_or_field_types") +Bool = Union[bool, np.bool_] + + def is_table(obj): """Verifies whether an object is table-like. diff --git a/pandera/api/polars/types.py b/pandera/api/polars/types.py index f038bcf73..a23464f2c 100644 --- a/pandera/api/polars/types.py +++ b/pandera/api/polars/types.py @@ -27,3 +27,11 @@ class CheckResult(NamedTuple): type, pl.datatypes.classes.DataTypeClass, ] + + +def is_bool(x): + """Verifies whether an object is a boolean type.""" + return isinstance(x, (bool, pl.Boolean)) + + +Bool = Union[bool, pl.Boolean] diff --git a/pandera/backends/pandas/checks.py b/pandera/backends/pandas/checks.py index 3c9cc3d61..fc721c82d 100644 --- a/pandera/backends/pandas/checks.py +++ b/pandera/backends/pandas/checks.py @@ -4,15 +4,14 @@ from typing import Dict, List, Optional, Union, cast import pandas as pd -from multimethod import DispatchError, overload - +from multimethod import DispatchError, multidispatch from pandera.api.base.checks import CheckResult, GroupbyObject from pandera.api.checks import Check from pandera.api.pandas.types import ( - is_bool, - is_field, - is_table, - is_table_or_field, + Bool, + Field, + Table, + TableOrField, ) from pandera.backends.base import BaseCheckBackend @@ -78,18 +77,18 @@ def _format_groupby_input( return output # type: ignore[return-value] - @overload + @multidispatch def preprocess(self, check_obj, key) -> pd.Series: """Preprocesses a check object before applying the check function.""" # This handles the case of Series validation, which has no other context except # for the index to groupby on. Right now grouping by the index is not allowed. return check_obj - @overload # type: ignore [no-redef] - def preprocess( + @preprocess.register + def _( self, - check_obj: is_field, # type: ignore [valid-type] - key, + check_obj: Field, # type: ignore [valid-type] + _, ) -> Union[pd.Series, Dict[str, pd.Series]]: if self.check.groupby is None: return check_obj @@ -100,10 +99,10 @@ def preprocess( ), ) - @overload # type: ignore [no-redef] - def preprocess( + @preprocess.register + def _( self, - check_obj: is_table, # type: ignore [valid-type] + check_obj: Table, # type: ignore [valid-type] key, ) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]: if self.check.groupby is None: @@ -115,11 +114,11 @@ def preprocess( ), ) - @overload # type: ignore [no-redef] - def preprocess( + @preprocess.register + def _( self, - check_obj: is_table, # type: ignore [valid-type] - key: None, + check_obj: Table, # type: ignore [valid-type] + _: None, ) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]: if self.check.groupby is None: return check_obj @@ -130,39 +129,39 @@ def preprocess( ), ) - @overload + @multidispatch def apply(self, check_obj): """Apply the check function to a check object.""" raise NotImplementedError - @overload # type: ignore [no-redef] - def apply(self, check_obj: dict): + @apply.register + def _(self, check_obj: dict): return self.check_fn(check_obj) - @overload # type: ignore [no-redef] - def apply(self, check_obj: is_field): # type: ignore [valid-type] + @apply.register + def _(self, check_obj: Field): # type: ignore [valid-type] if self.check.element_wise: return check_obj.map(self.check_fn) return self.check_fn(check_obj) - @overload # type: ignore [no-redef] - def apply(self, check_obj: is_table): # type: ignore [valid-type] + @apply.register + def _(self, check_obj: Table): # type: ignore [valid-type] if self.check.element_wise: return check_obj.apply(self.check_fn, axis=1) return self.check_fn(check_obj) - @overload + @multidispatch def postprocess(self, check_obj, check_output): """Postprocesses the result of applying the check function.""" raise TypeError( f"output type of check_fn not recognized: {type(check_output)}" ) - @overload # type: ignore [no-redef] - def postprocess( + @postprocess.register + def _( self, check_obj, - check_output: is_bool, # type: ignore [valid-type] + check_output: Bool, # type: ignore [valid-type] ) -> CheckResult: """Postprocesses the result of applying the check function.""" return CheckResult( @@ -198,11 +197,11 @@ def _get_series_failure_cases( ) return failure_cases - @overload # type: ignore [no-redef] - def postprocess( + @postprocess.register + def _( self, - check_obj: is_field, # type: ignore [valid-type] - check_output: is_field, # type: ignore [valid-type] + check_obj: Field, # type: ignore [valid-type] + check_output: Field, # type: ignore [valid-type] ) -> CheckResult: """Postprocesses the result of applying the check function.""" if check_obj.index.equals(check_output.index) and self.check.ignore_na: @@ -214,11 +213,11 @@ def postprocess( self._get_series_failure_cases(check_obj, check_output), ) - @overload # type: ignore [no-redef] - def postprocess( + @postprocess.register + def _( self, - check_obj: is_table, # type: ignore [valid-type] - check_output: is_field, # type: ignore [valid-type] + check_obj: Table, # type: ignore [valid-type] + check_output: Field, # type: ignore [valid-type] ) -> CheckResult: """Postprocesses the result of applying the check function.""" if check_obj.index.equals(check_output.index) and self.check.ignore_na: @@ -230,11 +229,11 @@ def postprocess( self._get_series_failure_cases(check_obj, check_output), ) - @overload # type: ignore [no-redef] - def postprocess( + @postprocess.register + def _( self, - check_obj: is_table, # type: ignore [valid-type] - check_output: is_table, # type: ignore [valid-type] + check_obj: Table, # type: ignore [valid-type] + check_output: Table, # type: ignore [valid-type] ) -> CheckResult: """Postprocesses the result of applying the check function.""" assert check_obj.shape == check_output.shape @@ -244,19 +243,19 @@ def postprocess( # collect failure cases across all columns. Flse values in check_output # are nulls. select_failure_cases = check_obj[~check_output] - failure_cases = [] + _failure_cases = [] for col in select_failure_cases.columns: cases = select_failure_cases[col].rename("failure_case").dropna() if len(cases) == 0: continue - failure_cases.append( + _failure_cases.append( cases.to_frame() .assign(column=col) .rename_axis("index") .reset_index() ) - if failure_cases: - failure_cases = pd.concat(failure_cases, axis=0) + if _failure_cases: + failure_cases = pd.concat(_failure_cases, axis=0) # convert to a dataframe where each row is a failure case at # a particular index, and failure case values are dictionaries # indicating which column and value failed in that row. @@ -279,11 +278,11 @@ def postprocess( failure_cases, ) - @overload # type: ignore [no-redef] - def postprocess( + @postprocess.register + def _( self, - check_obj: is_table_or_field, # type: ignore [valid-type] - check_output: is_bool, # type: ignore [valid-type] + check_obj: TableOrField, # type: ignore [valid-type] + check_output: Bool, # type: ignore [valid-type] ) -> CheckResult: """Postprocesses the result of applying the check function.""" check_output = bool(check_output) @@ -294,11 +293,11 @@ def postprocess( None, ) - @overload # type: ignore [no-redef] - def postprocess( + @postprocess.register + def _( self, check_obj: dict, - check_output: is_field, # type: ignore [valid-type] + check_output: Field, # type: ignore [valid-type] ) -> CheckResult: """Postprocesses the result of applying the check function.""" return CheckResult( diff --git a/pandera/backends/pandas/hypotheses.py b/pandera/backends/pandas/hypotheses.py index a9bdc6f8d..2c7b9ebd5 100644 --- a/pandera/backends/pandas/hypotheses.py +++ b/pandera/backends/pandas/hypotheses.py @@ -4,11 +4,11 @@ from typing import Any, Callable, Dict, Union, cast import pandas as pd -from multimethod import overload +from multimethod import multidispatch from pandera import errors from pandera.api.hypotheses import Hypothesis -from pandera.api.pandas.types import is_field, is_table +from pandera.api.pandas.types import is_field, Table from pandera.backends.pandas.checks import PandasCheckBackend @@ -48,6 +48,8 @@ def equal(stat, pvalue, alpha=DEFAULT_ALPHA) -> bool: class PandasHypothesisBackend(PandasCheckBackend): """Hypothesis backend implementation for pandas.""" + check: Hypothesis + RELATIONSHIP_FUNCTIONS = { "greater_than": greater_than, "less_than": less_than, @@ -106,15 +108,15 @@ def is_one_sample_test(self): """Return True if hypothesis is a one-sample test.""" return len(self.check.samples) <= 1 - @overload # type: ignore [no-redef] + @multidispatch def preprocess(self, check_obj, key) -> Any: self.check.groups = self.check.samples return super().preprocess(check_obj, key) - @overload # type: ignore [no-redef] - def preprocess( + @preprocess.register + def _( self, - check_obj: is_table, # type: ignore [valid-type] + check_obj: Table, # type: ignore [valid-type] key, ) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]: if self.check.groupby is None: @@ -126,10 +128,10 @@ def preprocess( ), ) - @overload # type: ignore [no-redef] - def preprocess( + @preprocess.register + def _( self, - check_obj: is_table, # type: ignore [valid-type] + check_obj: Table, # type: ignore [valid-type] key: None, ) -> pd.Series: """Preprocesses a check object before applying the check function.""" diff --git a/pandera/backends/pandas/parsers.py b/pandera/backends/pandas/parsers.py index 5d5ac1368..a80517468 100644 --- a/pandera/backends/pandas/parsers.py +++ b/pandera/backends/pandas/parsers.py @@ -4,10 +4,10 @@ from typing import Dict, Optional, Union import pandas as pd -from multimethod import overload +from multimethod import multidispatch from pandera.api.base.parsers import ParserResult -from pandera.api.pandas.types import is_field, is_table +from pandera.api.pandas.types import Field, Table from pandera.api.parsers import Parser from pandera.backends.base import BaseParserBackend @@ -22,40 +22,40 @@ def __init__(self, parser: Parser): self.parser = parser self.parser_fn = partial(parser._parser_fn, **parser._parser_kwargs) - @overload + @multidispatch def preprocess( self, parse_obj, key # pylint:disable=unused-argument ) -> pd.Series: # pylint:disable=unused-argument """Preprocesses a parser object before applying the parse function.""" return parse_obj - @overload # type: ignore [no-redef] - def preprocess( + @preprocess.register + def _( self, - parse_obj: is_table, # type: ignore [valid-type] + parse_obj: Table, # type: ignore [valid-type] key, ) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]: return parse_obj[key] - @overload # type: ignore [no-redef] - def preprocess( - self, parse_obj: is_table, key: None # type: ignore [valid-type] # pylint:disable=unused-argument + @preprocess.register + def _( + self, parse_obj: Table, key: None # type: ignore [valid-type] # pylint:disable=unused-argument ) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]: return parse_obj - @overload + @multidispatch def apply(self, parse_obj): """Apply the parse function to a parser object.""" raise NotImplementedError - @overload # type: ignore [no-redef] - def apply(self, parse_obj: is_field): # type: ignore [valid-type] + @apply.register + def _(self, parse_obj: Field): # type: ignore [valid-type] if self.parser.element_wise: return parse_obj.map(self.parser_fn) return self.parser_fn(parse_obj) - @overload # type: ignore [no-redef] - def apply(self, parse_obj: is_table): # type: ignore [valid-type] + @apply.register + def _(self, parse_obj: Table): # type: ignore [valid-type] if self.parser.element_wise: return getattr(parse_obj, "map", parse_obj.applymap)( self.parser_fn diff --git a/pandera/backends/polars/checks.py b/pandera/backends/polars/checks.py index 26b599690..203a7e8df 100644 --- a/pandera/backends/polars/checks.py +++ b/pandera/backends/polars/checks.py @@ -4,12 +4,12 @@ from typing import Optional import polars as pl -from multimethod import overload +from multimethod import multidispatch from polars.lazyframe.group_by import LazyGroupBy from pandera.api.base.checks import CheckResult from pandera.api.checks import Check -from pandera.api.polars.types import PolarsData +from pandera.api.polars.types import PolarsData, Bool from pandera.api.polars.utils import ( get_lazyframe_schema, get_lazyframe_column_names, @@ -76,15 +76,15 @@ def apply(self, check_obj: PolarsData): return out - @overload + @multidispatch def postprocess(self, check_obj, check_output): """Postprocesses the result of applying the check function.""" raise TypeError( # pragma: no cover f"output type of check_fn not recognized: {type(check_output)}" ) - @overload # type: ignore [no-redef] - def postprocess( + @postprocess.register + def _( self, check_obj: PolarsData, check_output: pl.LazyFrame, @@ -105,11 +105,11 @@ def postprocess( failure_cases=failure_cases, ) - @overload # type: ignore [no-redef] - def postprocess( + @postprocess.register + def _( self, check_obj: PolarsData, - check_output: bool, + check_output: Bool, # type: ignore [valid-type] ) -> CheckResult: """Postprocesses the result of applying the check function.""" ldf_output = pl.LazyFrame({CHECK_OUTPUT_KEY: [check_output]}) diff --git a/reqs-test.txt b/reqs-test.txt index 8e51b5f14..cefc87bfe 100644 --- a/reqs-test.txt +++ b/reqs-test.txt @@ -272,7 +272,7 @@ msgpack==1.0.5 # via # distributed # ray -multimethod==1.9.1 +multimethod==1.12 # via -r requirements.in mypy==0.982 # via -r requirements.in diff --git a/requirements.in b/requirements.in index df53991bf..f4e2217e5 100644 --- a/requirements.in +++ b/requirements.in @@ -14,7 +14,7 @@ typing_extensions >= 3.7.4.3 frictionless <= 4.40.8 pyarrow pydantic -multimethod <= 1.10.0 +multimethod pandas-stubs pyspark[connect] >= 3.2.0 polars >= 0.20.0 diff --git a/setup.py b/setup.py index 749915fe6..d5bff6a07 100644 --- a/setup.py +++ b/setup.py @@ -46,7 +46,7 @@ packages=find_packages(include=["pandera*"]), package_data={"pandera": ["py.typed"]}, install_requires=[ - "multimethod <= 1.10.0", + "multimethod", "numpy >= 1.19.0", "packaging >= 20.0", "pandas >= 1.2.0",