Skip to content

Commit d3a018d

Browse files
ENH: Fix by in DataFrame.plot.hist and DataFrame.plot.box (#28373)
1 parent 921bdfa commit d3a018d

File tree

7 files changed

+660
-17
lines changed

7 files changed

+660
-17
lines changed

doc/source/whatsnew/v1.4.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ enhancement2
2929

3030
Other enhancements
3131
^^^^^^^^^^^^^^^^^^
32+
- Add support for assigning values to the ``by`` argument in :meth:`DataFrame.plot.hist` and :meth:`DataFrame.plot.box` (:issue:`15079`)
3233
- :meth:`Series.sample`, :meth:`DataFrame.sample`, and :meth:`.GroupBy.sample` now accept a ``np.random.Generator`` as input to ``random_state``. A generator will be more performant, especially with ``replace=False`` (:issue:`38100`)
3334
- Additional options added to :meth:`.Styler.bar` to control alignment and display (:issue:`26070`)
3435
- :meth:`Series.ewm`, :meth:`DataFrame.ewm`, now support a ``method`` argument with a ``'table'`` option that performs the windowing operation over an entire :class:`DataFrame`. See :ref:`Window Overview <window.overview>` for performance and functional benefits (:issue:`42273`)

pandas/plotting/_core.py

+20
Original file line numberDiff line numberDiff line change
@@ -1237,6 +1237,11 @@ def box(self, by=None, **kwargs):
12371237
----------
12381238
by : str or sequence
12391239
Column in the DataFrame to group by.
1240+
1241+
.. versionchanged:: 1.4.0
1242+
1243+
Previously, `by` was silently ignored and no grouping was performed.
1244+
12401245
**kwargs
12411246
Additional keywords are documented in
12421247
:meth:`DataFrame.plot`.
@@ -1278,6 +1283,11 @@ def hist(self, by=None, bins=10, **kwargs):
12781283
----------
12791284
by : str or sequence, optional
12801285
Column in the DataFrame to group by.
1286+
1287+
.. versionchanged:: 1.4.0
1288+
1289+
Previously, `by` was silently ignored and no grouping was performed.
1290+
12811291
bins : int, default 10
12821292
Number of histogram bins to be used.
12831293
**kwargs
@@ -1309,6 +1319,16 @@ def hist(self, by=None, bins=10, **kwargs):
13091319
... columns = ['one'])
13101320
>>> df['two'] = df['one'] + np.random.randint(1, 7, 6000)
13111321
>>> ax = df.plot.hist(bins=12, alpha=0.5)
1322+
1323+
A grouped histogram can be generated by providing the parameter `by` (which
1324+
can be a column name, or a list of column names):
1325+
1326+
.. plot::
1327+
:context: close-figs
1328+
1329+
>>> age_list = [8, 10, 12, 14, 72, 74, 76, 78, 20, 25, 30, 35, 60, 85]
1330+
>>> df = pd.DataFrame({"gender": list("MMMMMMMMFFFFFF"), "age": age_list})
1331+
>>> ax = df.plot.hist(column=["age"], by="gender", figsize=(10, 8))
13121332
"""
13131333
return self(kind="hist", by=by, bins=bins, **kwargs)
13141334

pandas/plotting/_matplotlib/boxplot.py

+24-4
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
LinePlot,
1919
MPLPlot,
2020
)
21+
from pandas.plotting._matplotlib.groupby import create_iter_data_given_by
2122
from pandas.plotting._matplotlib.style import get_standard_colors
2223
from pandas.plotting._matplotlib.tools import (
2324
create_subplots,
@@ -135,18 +136,37 @@ def _make_plot(self):
135136
if self.subplots:
136137
self._return_obj = pd.Series(dtype=object)
137138

138-
for i, (label, y) in enumerate(self._iter_data()):
139+
# Re-create iterated data if `by` is assigned by users
140+
data = (
141+
create_iter_data_given_by(self.data, self._kind)
142+
if self.by is not None
143+
else self.data
144+
)
145+
146+
for i, (label, y) in enumerate(self._iter_data(data=data)):
139147
ax = self._get_ax(i)
140148
kwds = self.kwds.copy()
141149

150+
# When by is applied, show title for subplots to know which group it is
151+
# just like df.boxplot, and need to apply T on y to provide right input
152+
if self.by is not None:
153+
y = y.T
154+
ax.set_title(pprint_thing(label))
155+
156+
# When `by` is assigned, the ticklabels will become unique grouped
157+
# values, instead of label which is used as subtitle in this case.
158+
ticklabels = [
159+
pprint_thing(col) for col in self.data.columns.levels[0]
160+
]
161+
else:
162+
ticklabels = [pprint_thing(label)]
163+
142164
ret, bp = self._plot(
143165
ax, y, column_num=i, return_type=self.return_type, **kwds
144166
)
145167
self.maybe_color_bp(bp)
146168
self._return_obj[label] = ret
147-
148-
label = [pprint_thing(label)]
149-
self._set_ticklabels(ax, label)
169+
self._set_ticklabels(ax, ticklabels)
150170
else:
151171
y = self.data.values.T
152172
ax = self._get_ax(0)

pandas/plotting/_matplotlib/core.py

+52-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from matplotlib.artist import Artist
1010
import numpy as np
1111

12+
from pandas._typing import IndexLabel
1213
from pandas.errors import AbstractMethodError
1314
from pandas.util._decorators import cache_readonly
1415

@@ -38,10 +39,12 @@
3839
)
3940

4041
import pandas.core.common as com
42+
from pandas.core.frame import DataFrame
4143

4244
from pandas.io.formats.printing import pprint_thing
4345
from pandas.plotting._matplotlib.compat import mpl_ge_3_0_0
4446
from pandas.plotting._matplotlib.converter import register_pandas_matplotlib_converters
47+
from pandas.plotting._matplotlib.groupby import reconstruct_data_with_by
4548
from pandas.plotting._matplotlib.style import get_standard_colors
4649
from pandas.plotting._matplotlib.timeseries import (
4750
decorate_axes,
@@ -99,7 +102,7 @@ def __init__(
99102
self,
100103
data,
101104
kind=None,
102-
by=None,
105+
by: IndexLabel | None = None,
103106
subplots=False,
104107
sharex=None,
105108
sharey=False,
@@ -124,13 +127,42 @@ def __init__(
124127
table=False,
125128
layout=None,
126129
include_bool=False,
130+
column: IndexLabel | None = None,
127131
**kwds,
128132
):
129133

130134
import matplotlib.pyplot as plt
131135

132136
self.data = data
133-
self.by = by
137+
138+
# if users assign an empty list or tuple, raise `ValueError`
139+
# similar to current `df.box` and `df.hist` APIs.
140+
if by in ([], ()):
141+
raise ValueError("No group keys passed!")
142+
self.by = com.maybe_make_list(by)
143+
144+
# Assign the rest of columns into self.columns if by is explicitly defined
145+
# while column is not, only need `columns` in hist/box plot when it's DF
146+
# TODO: Might deprecate `column` argument in future PR (#28373)
147+
if isinstance(data, DataFrame):
148+
if column:
149+
self.columns = com.maybe_make_list(column)
150+
else:
151+
if self.by is None:
152+
self.columns = [
153+
col for col in data.columns if is_numeric_dtype(data[col])
154+
]
155+
else:
156+
self.columns = [
157+
col
158+
for col in data.columns
159+
if col not in self.by and is_numeric_dtype(data[col])
160+
]
161+
162+
# For `hist` plot, need to get grouped original data before `self.data` is
163+
# updated later
164+
if self.by is not None and self._kind == "hist":
165+
self._grouped = data.groupby(self.by)
134166

135167
self.kind = kind
136168

@@ -139,7 +171,9 @@ def __init__(
139171
self.subplots = subplots
140172

141173
if sharex is None:
142-
if ax is None:
174+
175+
# if by is defined, subplots are used and sharex should be False
176+
if ax is None and by is None:
143177
self.sharex = True
144178
else:
145179
# if we get an axis, the users should do the visibility
@@ -273,8 +307,15 @@ def _iter_data(self, data=None, keep_index=False, fillna=None):
273307

274308
@property
275309
def nseries(self) -> int:
310+
311+
# When `by` is explicitly assigned, grouped data size will be defined, and
312+
# this will determine number of subplots to have, aka `self.nseries`
276313
if self.data.ndim == 1:
277314
return 1
315+
elif self.by is not None and self._kind == "hist":
316+
return len(self._grouped)
317+
elif self.by is not None and self._kind == "box":
318+
return len(self.columns)
278319
else:
279320
return self.data.shape[1]
280321

@@ -420,6 +461,14 @@ def _compute_plot_data(self):
420461
if label is None and data.name is None:
421462
label = "None"
422463
data = data.to_frame(name=label)
464+
elif self._kind in ("hist", "box"):
465+
cols = self.columns if self.by is None else self.columns + self.by
466+
data = data.loc[:, cols]
467+
468+
# GH15079 reconstruct data if by is defined
469+
if self.by is not None:
470+
self.subplots = True
471+
data = reconstruct_data_with_by(self.data, by=self.by, cols=self.columns)
423472

424473
# GH16953, _convert is needed as fallback, for ``Series``
425474
# with ``dtype == object``
+127
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
from __future__ import annotations
2+
3+
import numpy as np
4+
5+
from pandas._typing import (
6+
Dict,
7+
IndexLabel,
8+
)
9+
10+
from pandas.core.dtypes.missing import remove_na_arraylike
11+
12+
from pandas import (
13+
DataFrame,
14+
MultiIndex,
15+
Series,
16+
concat,
17+
)
18+
19+
20+
def create_iter_data_given_by(
    data: DataFrame, kind: str = "hist"
) -> Dict[str, DataFrame | Series]:
    """
    Split grouped plot data into one sub-frame per value of a column level.

    Only meant for ``hist`` and ``box`` plots when `by` has been assigned.
    The input must already have two-level MultiIndex columns (level 0 holds
    the group values, level 1 the plotted column names); a frame with any
    other kind of columns fails the internal assertion rather than being
    passed through.

    Parameters
    ----------
    data : DataFrame
        Reformatted grouped data with MultiIndex columns.
    kind : str, default "hist"
        Plot kind. ``"hist"`` splits on level 0 (one entry per group value);
        any other value splits on level 1 (one entry per plotted column).

    Returns
    -------
    dict
        Maps every entry of ``data.columns.levels[level]`` to the columns of
        `data` whose label at that level equals the entry. The sub-frames
        keep their MultiIndex columns.

    Examples
    --------
    >>> import numpy as np
    >>> tuples = [('h1', 'a'), ('h1', 'b'), ('h2', 'a'), ('h2', 'b')]
    >>> mi = MultiIndex.from_tuples(tuples)
    >>> value = [[1, 3, np.nan, np.nan],
    ...          [3, 4, np.nan, np.nan], [np.nan, np.nan, 5, 6]]
    >>> data = DataFrame(value, columns=mi)
    >>> list(create_iter_data_given_by(data))
    ['h1', 'h2']
    >>> create_iter_data_given_by(data)['h1'].columns.tolist()
    [('h1', 'a'), ('h1', 'b')]
    >>> list(create_iter_data_given_by(data, kind="box"))
    ['a', 'b']
    """
    # `hist` iterates over group values (level 0), which also become the
    # subplot titles; `box` iterates over the plotted column names (level 1).
    level = 0 if kind == "hist" else 1

    # Internal invariant: once `by` is assigned the data has been rebuilt
    # with MultiIndex columns.
    assert isinstance(data.columns, MultiIndex)

    # Computed once; it does not depend on the key being selected.
    level_values = data.columns.get_level_values(level)
    return {
        col: data.loc[:, level_values == col]
        for col in data.columns.levels[level]
    }
72+
73+
74+
def reconstruct_data_with_by(
75+
data: DataFrame, by: IndexLabel, cols: IndexLabel
76+
) -> DataFrame:
77+
"""
78+
Internal function to group data, and reassign multiindex column names onto the
79+
result in order to let grouped data be used in _compute_plot_data method.
80+
81+
Parameters
82+
----------
83+
data : Original DataFrame to plot
84+
by : grouped `by` parameter selected by users
85+
cols : columns of data set (excluding columns used in `by`)
86+
87+
Returns
88+
-------
89+
Output is the reconstructed DataFrame with MultiIndex columns. The first level
90+
of MI is unique values of groups, and second level of MI is the columns
91+
selected by users.
92+
93+
Examples
94+
--------
95+
>>> d = {'h': ['h1', 'h1', 'h2'], 'a': [1, 3, 5], 'b': [3, 4, 6]}
96+
>>> df = DataFrame(d)
97+
>>> reconstruct_data_with_by(df, by='h', cols=['a', 'b'])
98+
h1 h2
99+
a b a b
100+
0 1 3 NaN NaN
101+
1 3 4 NaN NaN
102+
2 NaN NaN 5 6
103+
"""
104+
grouped = data.groupby(by)
105+
106+
data_list = []
107+
for key, group in grouped:
108+
columns = MultiIndex.from_product([[key], cols])
109+
sub_group = group[cols]
110+
sub_group.columns = columns
111+
data_list.append(sub_group)
112+
113+
data = concat(data_list, axis=1)
114+
return data
115+
116+
117+
def reformat_hist_y_given_by(
    y: Series | np.ndarray, by: IndexLabel | None
) -> Series | np.ndarray:
    """
    Strip missing values from histogram input, column-wise when grouped.

    Parameters
    ----------
    y : Series or np.ndarray
        Values to be plotted.
    by : label or list of labels, optional
        Grouping key(s); only checked against None here.

    Returns
    -------
    Series or np.ndarray
        When `by` is set and `y` has more than one dimension, every column
        is passed through ``remove_na_arraylike`` separately and the results
        are stacked back as columns. Otherwise `y` as a whole is passed
        through ``remove_na_arraylike``.
    """
    grouped_multi_dim = by is not None and len(y.shape) > 1
    if not grouped_multi_dim:
        return remove_na_arraylike(y)

    # Iterate over the transpose to visit columns, then transpose back so
    # the cleaned columns are columns again.
    cleaned_columns = [remove_na_arraylike(column) for column in y.T]
    return np.array(cleaned_columns).T

0 commit comments

Comments
 (0)