|
20 | 20 | isna,
|
21 | 21 | )
|
22 | 22 | import pandas._testing as tm
|
| 23 | +from pandas.tests.groupby import get_groupby_method_args |
23 | 24 | from pandas.util import _test_decorators as td
|
24 | 25 |
|
25 | 26 |
|
@@ -956,17 +957,95 @@ def test_min_empty_string_dtype(func, string_dtype_no_object):
|
956 | 957 |
|
957 | 958 |
|
@pytest.mark.parametrize("min_count", [0, 1])
@pytest.mark.parametrize("test_series", [True, False])
def test_string_dtype_all_na(
    string_dtype_no_object, reduction_func, skipna, min_count, test_series
):
    """
    Check groupby reductions on a string-dtype column whose only value is NA.

    A one-row frame (key ``"x"``, value ``pd.NA``) is grouped by its key and
    ``reduction_func`` is called on either the ``"b"`` Series or the whole
    frame.  Depending on the reduction, the call must either raise a
    specific error or return a one-element result with a specific dtype
    and value.
    """
    # https://github.com/pandas-dev/pandas/issues/60985
    if reduction_func == "corrwith":
        # corrwith is deprecated.
        return

    dtype = string_dtype_no_object
    is_size = reduction_func == "size"
    is_idx = reduction_func in ("idxmin", "idxmax")

    # Keyword arguments forwarded to the reduction: some are called with
    # only ``skipna``, "kurt" with only ``min_count``, some with neither,
    # and everything else with both.
    skipna_only = ("any", "all", "idxmin", "idxmax", "mean", "median", "std", "var")
    no_kwargs = ("count", "nunique", "quantile", "sem", "size")
    kwargs = {}
    if reduction_func != "kurt" and reduction_func not in no_kwargs:
        kwargs["skipna"] = skipna
    if reduction_func not in skipna_only and reduction_func not in no_kwargs:
        kwargs["min_count"] = min_count

    # Default expectation: an NA value that keeps the input string dtype.
    exp_dtype = dtype
    exp_value = pd.NA
    if reduction_func in ("all", "any"):
        exp_dtype = "bool"
        # TODO: For skipna=False, bool(pd.NA) raises; should groupby?
        exp_value = True if reduction_func == "all" else not skipna
    elif reduction_func in ("count", "nunique", "size"):
        # TODO: Should be more consistent - return Int64 when dtype.na_value is pd.NA?
        nullable_size = (
            test_series
            and is_size
            and dtype.storage == "pyarrow"
            and dtype.na_value is pd.NA
        )
        exp_dtype = "Int64" if nullable_size else "int64"
        # The single group holds one row, and that row's value is NA.
        exp_value = 1 if is_size else 0
    elif is_idx:
        exp_dtype = "float64"
        exp_value = np.nan
    elif reduction_func == "sum" and skipna and min_count <= 0:
        # https://github.com/pandas-dev/pandas/pull/60936
        exp_value = ""

    df = DataFrame({"a": ["x"], "b": [pd.NA]}, dtype=dtype)
    obj = df["b"] if test_series else df
    args = get_groupby_method_args(reduction_func, obj)
    grouped = obj.groupby(df["a"])
    method = getattr(grouped, reduction_func)

    # Work out whether this combination is expected to raise, and how.
    unsupported = (
        "mean",
        "median",
        "kurt",
        "prod",
        "quantile",
        "sem",
        "skew",
        "std",
        "var",
    )
    expected_error = None
    if reduction_func in unsupported:
        expected_error = TypeError
        msg = f"dtype '{dtype}' does not support operation '{reduction_func}'"
    elif is_idx and not skipna:
        expected_error = ValueError
        msg = f"{reduction_func} with skipna=False encountered an NA value."

    if expected_error is not None:
        with pytest.raises(expected_error, match=msg):
            method(*args, **kwargs)
        return

    result = method(*args, **kwargs)
    index = pd.Index(["x"], name="a", dtype=dtype)
    if test_series:
        expected = Series(exp_value, index=index, dtype=exp_dtype, name="b")
    elif is_size:
        # size on a DataFrame groupby is compared against an unnamed Series.
        expected = Series(exp_value, index=index, dtype=exp_dtype, name=None)
    else:
        expected = DataFrame({"b": exp_value}, index=index, dtype=exp_dtype)
    tm.assert_equal(result, expected)
970 | 1049 |
|
971 | 1050 |
|
972 | 1051 | def test_max_nan_bug():
|
|
0 commit comments