diff --git a/pandas-stubs/core/strings.pyi b/pandas-stubs/core/strings.pyi index f78adf1a..fafdb9b3 100644 --- a/pandas-stubs/core/strings.pyi +++ b/pandas-stubs/core/strings.pyi @@ -83,12 +83,17 @@ class StringMethods( ) -> _T_STR: ... @overload def split( - self, pat: str = ..., *, n: int = ..., expand: Literal[True], regex: bool = ... + self, + pat: str | re.Pattern[str] = ..., + *, + n: int = ..., + expand: Literal[True], + regex: bool = ..., ) -> _T_EXPANDING: ... @overload def split( self, - pat: str = ..., + pat: str | re.Pattern[str] = ..., *, n: int = ..., expand: Literal[False] = ..., @@ -133,11 +138,15 @@ class StringMethods( regex: bool = ..., ) -> _T_BOOL: ... def match( - self, pat: str, case: bool = ..., flags: int = ..., na: Any = ... + self, + pat: str | re.Pattern[str], + case: bool = ..., + flags: int = ..., + na: Any = ..., ) -> _T_BOOL: ... def replace( self, - pat: str, + pat: str | re.Pattern[str], repl: str | Callable[[re.Match[str]], str], n: int = ..., case: bool | None = ..., @@ -180,18 +189,26 @@ class StringMethods( def count(self, pat: str, flags: int = ...) -> _T_INT: ... def startswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _T_BOOL: ... def endswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _T_BOOL: ... - def findall(self, pat: str, flags: int = ...) -> _T_LIST_STR: ... + def findall(self, pat: str | re.Pattern[str], flags: int = ...) -> _T_LIST_STR: ... @overload def extract( - self, pat: str, flags: int = ..., *, expand: Literal[True] = ... + self, + pat: str | re.Pattern[str], + flags: int = ..., + *, + expand: Literal[True] = ..., ) -> pd.DataFrame: ... @overload - def extract(self, pat: str, flags: int, expand: Literal[False]) -> _T_OBJECT: ... + def extract( + self, pat: str | re.Pattern[str], flags: int, expand: Literal[False] + ) -> _T_OBJECT: ... @overload def extract( - self, pat: str, flags: int = ..., *, expand: Literal[False] + self, pat: str | re.Pattern[str], flags: int = ..., *, expand: Literal[False] ) -> _T_OBJECT: ... - def extractall(self, pat: str, flags: int = ...) -> pd.DataFrame: ... + def extractall( + self, pat: str | re.Pattern[str], flags: int = ... + ) -> pd.DataFrame: ... def find(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ... def rfind(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ... def normalize(self, form: Literal["NFC", "NFKC", "NFD", "NFKD"]) -> _T_STR: ... @@ -214,7 +231,11 @@ class StringMethods( def isnumeric(self) -> _T_BOOL: ... def isdecimal(self) -> _T_BOOL: ... def fullmatch( - self, pat: str, case: bool = ..., flags: int = ..., na: Any = ... + self, + pat: str | re.Pattern[str], + case: bool = ..., + flags: int = ..., + na: Any = ..., ) -> _T_BOOL: ... def removeprefix(self, prefix: str) -> _T_STR: ... def removesuffix(self, suffix: str) -> _T_STR: ... diff --git a/tests/test_string_accessors.py b/tests/test_string_accessors.py index 649dae78..7ce9ccfd 100644 --- a/tests/test_string_accessors.py +++ b/tests/test_string_accessors.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd +import pytest from typing_extensions import assert_type from tests import ( @@ -44,6 +45,7 @@ def test_string_accessors_boolean_series(): _check(assert_type(s.str.endswith("e"), "pd.Series[bool]")) _check(assert_type(s.str.endswith(("e", "f")), "pd.Series[bool]")) _check(assert_type(s.str.fullmatch("apple"), "pd.Series[bool]")) + _check(assert_type(s.str.fullmatch(re.compile(r"apple")), "pd.Series[bool]")) _check(assert_type(s.str.isalnum(), "pd.Series[bool]")) _check(assert_type(s.str.isalpha(), "pd.Series[bool]")) _check(assert_type(s.str.isdecimal(), "pd.Series[bool]")) @@ -54,6 +56,7 @@ def test_string_accessors_boolean_series(): _check(assert_type(s.str.istitle(), "pd.Series[bool]")) _check(assert_type(s.str.isupper(), "pd.Series[bool]")) _check(assert_type(s.str.match("pp"), "pd.Series[bool]")) + _check(assert_type(s.str.match(re.compile(r"pp")), "pd.Series[bool]")) def test_string_accessors_boolean_index(): @@ -72,6 +75,7 @@ def test_string_accessors_boolean_index(): _check(assert_type(idx.str.endswith("e"), np_ndarray_bool)) _check(assert_type(idx.str.endswith(("e", "f")), np_ndarray_bool)) _check(assert_type(idx.str.fullmatch("apple"), np_ndarray_bool)) + _check(assert_type(idx.str.fullmatch(re.compile(r"apple")), np_ndarray_bool)) _check(assert_type(idx.str.isalnum(), np_ndarray_bool)) _check(assert_type(idx.str.isalpha(), np_ndarray_bool)) _check(assert_type(idx.str.isdecimal(), np_ndarray_bool)) @@ -82,6 +86,7 @@ def test_string_accessors_boolean_index(): _check(assert_type(idx.str.istitle(), np_ndarray_bool)) _check(assert_type(idx.str.isupper(), np_ndarray_bool)) _check(assert_type(idx.str.match("pp"), np_ndarray_bool)) + _check(assert_type(idx.str.match(re.compile(r"pp")), np_ndarray_bool)) def test_string_accessors_integer_series(): @@ -94,6 +99,10 @@ def test_string_accessors_integer_series(): _check(assert_type(s.str.count("pp"), "pd.Series[int]")) _check(assert_type(s.str.len(), "pd.Series[int]")) + # unlike findall, find doesn't accept a compiled pattern + with pytest.raises(TypeError): + s.str.find(re.compile(r"p")) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + def test_string_accessors_integer_index(): idx = pd.Index(DATA) @@ -105,6 +114,10 @@ def test_string_accessors_integer_index(): _check(assert_type(idx.str.count("pp"), "pd.Index[int]")) _check(assert_type(idx.str.len(), "pd.Index[int]")) + # unlike findall, find doesn't accept a compiled pattern + with pytest.raises(TypeError): + idx.str.find(re.compile(r"p")) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + def test_string_accessors_string_series(): s = pd.Series(DATA) @@ -123,6 +136,9 @@ def test_string_accessors_string_series(): _check(assert_type(s.str.removesuffix("e"), "pd.Series[str]")) _check(assert_type(s.str.repeat(2), "pd.Series[str]")) _check(assert_type(s.str.replace("a", "X"), "pd.Series[str]")) + _check( + assert_type(s.str.replace(re.compile(r"a"), "X", regex=True), "pd.Series[str]") + ) _check(assert_type(s.str.rjust(80), "pd.Series[str]")) _check(assert_type(s.str.rstrip(), "pd.Series[str]")) _check(assert_type(s.str.slice_replace(0, 2, "XX"), "pd.Series[str]")) @@ -158,6 +174,9 @@ def test_string_accessors_string_index(): _check(assert_type(idx.str.removesuffix("e"), "pd.Index[str]")) _check(assert_type(idx.str.repeat(2), "pd.Index[str]")) _check(assert_type(idx.str.replace("a", "X"), "pd.Index[str]")) + _check( + assert_type(idx.str.replace(re.compile(r"a"), "X", regex=True), "pd.Index[str]") + ) _check(assert_type(idx.str.rjust(80), "pd.Index[str]")) _check(assert_type(idx.str.rstrip(), "pd.Index[str]")) _check(assert_type(idx.str.slice_replace(0, 2, "XX"), "pd.Index[str]")) @@ -190,29 +209,49 @@ def test_string_accessors_list_series(): s = pd.Series(DATA) _check = functools.partial(check, klass=pd.Series, dtype=list) _check(assert_type(s.str.findall("pp"), "pd.Series[list[str]]")) + _check(assert_type(s.str.findall(re.compile(r"pp")), "pd.Series[list[str]]")) _check(assert_type(s.str.split("a"), "pd.Series[list[str]]")) + _check(assert_type(s.str.split(re.compile(r"a")), "pd.Series[list[str]]")) # GH 194 _check(assert_type(s.str.split("a", expand=False), "pd.Series[list[str]]")) _check(assert_type(s.str.rsplit("a"), "pd.Series[list[str]]")) _check(assert_type(s.str.rsplit("a", expand=False), "pd.Series[list[str]]")) + # rsplit doesn't accept compiled pattern + # it doesn't raise at runtime but produces a nan + bad_rsplit_result = s.str.rsplit( + re.compile(r"a") # type: ignore[call-overload] # pyright: ignore[reportArgumentType] + ) + assert bad_rsplit_result.isna().all() + def test_string_accessors_list_index(): idx = pd.Index(DATA) _check = functools.partial(check, klass=pd.Index, dtype=list) _check(assert_type(idx.str.findall("pp"), "pd.Index[list[str]]")) + _check(assert_type(idx.str.findall(re.compile(r"pp")), "pd.Index[list[str]]")) _check(assert_type(idx.str.split("a"), "pd.Index[list[str]]")) + _check(assert_type(idx.str.split(re.compile(r"a")), "pd.Index[list[str]]")) # GH 194 _check(assert_type(idx.str.split("a", expand=False), "pd.Index[list[str]]")) _check(assert_type(idx.str.rsplit("a"), "pd.Index[list[str]]")) _check(assert_type(idx.str.rsplit("a", expand=False), "pd.Index[list[str]]")) + # rsplit doesn't accept compiled pattern + # it doesn't raise at runtime but produces a nan + bad_rsplit_result = idx.str.rsplit( + re.compile(r"a") # type: ignore[call-overload] # pyright: ignore[reportArgumentType] + ) + assert bad_rsplit_result.isna().all() + def test_string_accessors_expanding_series(): s = pd.Series(["a1", "b2", "c3"]) _check = functools.partial(check, klass=pd.DataFrame) _check(assert_type(s.str.extract(r"([ab])?(\d)"), pd.DataFrame)) + _check(assert_type(s.str.extract(re.compile(r"([ab])?(\d)")), pd.DataFrame)) _check(assert_type(s.str.extractall(r"([ab])?(\d)"), pd.DataFrame)) + _check(assert_type(s.str.extractall(re.compile(r"([ab])?(\d)")), pd.DataFrame)) _check(assert_type(s.str.get_dummies(), pd.DataFrame)) _check(assert_type(s.str.partition("p"), pd.DataFrame)) _check(assert_type(s.str.rpartition("p"), pd.DataFrame)) @@ -231,7 +270,15 @@ def test_string_accessors_expanding_index(): # These ones are the odd ones out? check(assert_type(idx.str.extractall(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame) + check( + assert_type(idx.str.extractall(re.compile(r"([ab])?(\d)")), pd.DataFrame), + pd.DataFrame, + ) check(assert_type(idx.str.extract(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame) + check( + assert_type(idx.str.extract(re.compile(r"([ab])?(\d)")), pd.DataFrame), + pd.DataFrame, + ) def test_series_overloads_partition():