Skip to content

Commit 27928ed

Browse files
BUG: fix fill value for gouped sum in case of unobserved categories for string dtype (empty string instead of 0) (#61909)
1 parent 962168f commit 27928ed

File tree

6 files changed

+37
-10
lines changed

6 files changed

+37
-10
lines changed

pandas/_libs/groupby.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def group_sum(
6767
result_mask: np.ndarray | None = ...,
6868
min_count: int = ...,
6969
is_datetimelike: bool = ...,
70+
initial: object = ...,
7071
skipna: bool = ...,
7172
) -> None: ...
7273
def group_prod(

pandas/_libs/groupby.pyx

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -707,6 +707,7 @@ def group_sum(
707707
uint8_t[:, ::1] result_mask=None,
708708
Py_ssize_t min_count=0,
709709
bint is_datetimelike=False,
710+
object initial=0,
710711
bint skipna=True,
711712
) -> None:
712713
"""
@@ -725,9 +726,15 @@ def group_sum(
725726
raise ValueError("len(index) != len(labels)")
726727

727728
nobs = np.zeros((<object>out).shape, dtype=np.int64)
728-
# the below is equivalent to `np.zeros_like(out)` but faster
729-
sumx = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)
730-
compensation = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)
729+
if initial == 0:
730+
# the below is equivalent to `np.zeros_like(out)` but faster
731+
sumx = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)
732+
compensation = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)
733+
else:
734+
# in practice this path is only taken for strings to use empty string as initial
735+
assert sum_t is object
736+
sumx = np.full((<object>out).shape, initial, dtype=object)
737+
# object code path does not use `compensation`
731738

732739
N, K = (<object>values).shape
733740
if uses_mask:

pandas/core/arrays/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2608,6 +2608,7 @@ def _groupby_op(
26082608
kind = WrappedCythonOp.get_kind_from_how(how)
26092609
op = WrappedCythonOp(how=how, kind=kind, has_dropped_na=has_dropped_na)
26102610

2611+
initial: Any = 0
26112612
# GH#43682
26122613
if isinstance(self.dtype, StringDtype):
26132614
# StringArray
@@ -2632,6 +2633,7 @@ def _groupby_op(
26322633

26332634
arr = self
26342635
if op.how == "sum":
2636+
initial = ""
26352637
# https://github.com/pandas-dev/pandas/issues/60229
26362638
# All NA should result in the empty string.
26372639
assert "skipna" in kwargs
@@ -2649,6 +2651,7 @@ def _groupby_op(
26492651
ngroups=ngroups,
26502652
comp_ids=ids,
26512653
mask=None,
2654+
initial=initial,
26522655
**kwargs,
26532656
)
26542657

pandas/core/groupby/ops.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import functools
1313
from typing import (
1414
TYPE_CHECKING,
15+
Any,
1516
Generic,
1617
final,
1718
)
@@ -319,6 +320,7 @@ def _cython_op_ndim_compat(
319320
comp_ids: np.ndarray,
320321
mask: npt.NDArray[np.bool_] | None = None,
321322
result_mask: npt.NDArray[np.bool_] | None = None,
323+
initial: Any = 0,
322324
**kwargs,
323325
) -> np.ndarray:
324326
if values.ndim == 1:
@@ -335,6 +337,7 @@ def _cython_op_ndim_compat(
335337
comp_ids=comp_ids,
336338
mask=mask,
337339
result_mask=result_mask,
340+
initial=initial,
338341
**kwargs,
339342
)
340343
if res.shape[0] == 1:
@@ -350,6 +353,7 @@ def _cython_op_ndim_compat(
350353
comp_ids=comp_ids,
351354
mask=mask,
352355
result_mask=result_mask,
356+
initial=initial,
353357
**kwargs,
354358
)
355359

@@ -363,6 +367,7 @@ def _call_cython_op(
363367
comp_ids: np.ndarray,
364368
mask: npt.NDArray[np.bool_] | None,
365369
result_mask: npt.NDArray[np.bool_] | None,
370+
initial: Any = 0,
366371
**kwargs,
367372
) -> np.ndarray: # np.ndarray[ndim=2]
368373
orig_values = values
@@ -420,6 +425,10 @@ def _call_cython_op(
420425
"sum",
421426
"median",
422427
]:
428+
if self.how == "sum":
429+
# pass in through kwargs only for sum (other functions don't have
430+
# the keyword)
431+
kwargs["initial"] = initial
423432
func(
424433
out=result,
425434
counts=counts,

pandas/tests/groupby/test_categorical.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,14 @@ def f(a):
3232
return a
3333

3434
index = MultiIndex.from_product(map(f, args), names=names)
35+
if isinstance(fill_value, dict):
36+
# fill_value is a dict mapping column names to fill values
37+
# -> reindex column by column (reindex itself does not support this)
38+
res = {}
39+
for col in result.columns:
40+
res[col] = result[col].reindex(index, fill_value=fill_value[col])
41+
return DataFrame(res, index=index).sort_index()
42+
3543
return result.reindex(index, fill_value=fill_value).sort_index()
3644

3745

@@ -317,18 +325,14 @@ def test_apply(ordered):
317325
tm.assert_series_equal(result, expected)
318326

319327

320-
def test_observed(request, using_infer_string, observed):
328+
def test_observed(observed, using_infer_string):
321329
# multiple groupers, don't re-expand the output space
322330
# of the grouper
323331
# gh-14942 (implement)
324332
# gh-10132 (back-compat)
325333
# gh-8138 (back-compat)
326334
# gh-8869
327335

328-
if using_infer_string and not observed:
329-
# TODO(infer_string) this fails with filling the string column with 0
330-
request.applymarker(pytest.mark.xfail(reason="TODO(infer_string)"))
331-
332336
cat1 = Categorical(["a", "a", "b", "b"], categories=["a", "b", "z"], ordered=True)
333337
cat2 = Categorical(["c", "d", "c", "d"], categories=["c", "d", "y"], ordered=True)
334338
df = DataFrame({"A": cat1, "B": cat2, "values": [1, 2, 3, 4]})
@@ -356,7 +360,10 @@ def test_observed(request, using_infer_string, observed):
356360
result = gb.sum()
357361
if not observed:
358362
expected = cartesian_product_for_groupers(
359-
expected, [cat1, cat2], list("AB"), fill_value=0
363+
expected,
364+
[cat1, cat2],
365+
list("AB"),
366+
fill_value={"values": 0, "C": ""} if using_infer_string else 0,
360367
)
361368

362369
tm.assert_frame_equal(result, expected)

pandas/tests/groupby/test_timegrouper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def test_groupby_with_timegrouper(self, using_infer_string):
108108
unit=df.index.unit,
109109
)
110110
expected = DataFrame(
111-
{"Buyer": 0, "Quantity": 0},
111+
{"Buyer": "" if using_infer_string else 0, "Quantity": 0},
112112
index=exp_dti,
113113
)
114114
# Cast to object/str to avoid implicit cast when setting

0 commit comments

Comments
 (0)