Skip to content

Commit 6cb50c2

Browse files
committed
ENH: Groupby.plot enhancement
1 parent f5cab4e commit 6cb50c2

File tree

4 files changed: +692 / -360 lines changed

pandas/core/groupby.py

+5 / -1
Original file line number | Diff line number | Diff line change
@@ -3297,9 +3297,13 @@ def _apply_to_column_groupbys(self, func):
32973297
in self._iterate_column_groupbys()),
32983298
keys=self._selected_obj.columns, axis=1)
32993299

3300-
from pandas.tools.plotting import boxplot_frame_groupby
3300+
3301+
from pandas.tools.plotting import boxplot_frame_groupby, plot_grouped
33013302
DataFrameGroupBy.boxplot = boxplot_frame_groupby
33023303

3304+
SeriesGroupBy.plot = plot_grouped
3305+
DataFrameGroupBy.plot = plot_grouped
3306+
33033307

33043308
class PanelGroupBy(NDFrameGroupBy):
33053309

pandas/tests/test_graphics.py

+99 / -6
Original file line number | Diff line number | Diff line change
@@ -1671,6 +1671,7 @@ def test_line_lim(self):
16711671
self.assertFalse(hasattr(ax, 'right_ax'))
16721672
xmin, xmax = ax.get_xlim()
16731673
lines = ax.get_lines()
1674+
self.assertTrue(hasattr(ax, 'left_ax'))
16741675
self.assertEqual(xmin, lines[0].get_data()[0][0])
16751676
self.assertEqual(xmax, lines[0].get_data()[0][-1])
16761677

@@ -1695,10 +1696,8 @@ def test_area_lim(self):
16951696
@slow
16961697
def test_bar_colors(self):
16971698
import matplotlib.pyplot as plt
1698-
16991699
default_colors = plt.rcParams.get('axes.color_cycle')
17001700

1701-
17021701
df = DataFrame(randn(5, 5))
17031702
ax = df.plot(kind='bar')
17041703
self._check_colors(ax.patches[::5], facecolors=default_colors[:5])
@@ -3209,7 +3208,6 @@ def test_hexbin_cmap(self):
32093208
@slow
32103209
def test_no_color_bar(self):
32113210
df = self.hexbin_df
3212-
32133211
ax = df.plot(kind='hexbin', x='A', y='B', colorbar=None)
32143212
self.assertIs(ax.collections[0].colorbar, None)
32153213

@@ -3572,9 +3570,8 @@ def test_grouped_plot_fignums(self):
35723570
df = DataFrame({'height': height, 'weight': weight, 'gender': gender})
35733571
gb = df.groupby('gender')
35743572

3575-
res = gb.plot()
3576-
self.assertEqual(len(self.plt.get_fignums()), 2)
3577-
self.assertEqual(len(res), 2)
3573+
with tm.assertRaisesRegexp(ValueError, "To plot DataFrameGroupBy, specify 'suplots=True'"):
3574+
res = gb.plot()
35783575
tm.close()
35793576

35803577
res = gb.boxplot(return_type='axes')
@@ -3953,6 +3950,102 @@ def test_plotting_with_float_index_works(self):
39533950
df.groupby('def')['val'].apply(lambda x: x.plot())
39543951
tm.close()
39553952

3953+
def test_line_groupby(self):
3954+
df = DataFrame(np.random.rand(30, 5), columns=['A', 'B', 'C', 'D', 'E'])
3955+
df['by'] = ['Group {0}'.format(i) for i in [0]*10 + [1]*10 + [2]*10]
3956+
grouped = df.groupby(by='by')
3957+
3958+
# SeriesGroupBy
3959+
sgb = grouped['A']
3960+
ax = _check_plot_works(sgb.plot, colors=['r', 'g', 'b'])
3961+
self._check_legend_labels(ax, labels=['Group 0', 'Group 1', 'Group 2'])
3962+
self._check_colors(ax.get_lines(), linecolors=['r', 'g', 'b'])
3963+
3964+
axes = _check_plot_works(sgb.plot, subplots=True)
3965+
self._check_axes_shape(axes, axes_num=3, layout=(3, 1))
3966+
self._check_legend_labels(axes[0], labels=['Group 0'])
3967+
self._check_legend_labels(axes[1], labels=['Group 1'])
3968+
self._check_legend_labels(axes[2], labels=['Group 2'])
3969+
3970+
# DataFrameGroupBy
3971+
import matplotlib.pyplot as plt
3972+
default_colors = plt.rcParams.get('axes.color_cycle')
3973+
3974+
axes = _check_plot_works(grouped.plot, subplots=True)
3975+
self._check_axes_shape(axes, axes_num=3, layout=(3, 1))
3976+
for ax in axes:
3977+
self._check_legend_labels(ax, labels=['A', 'B', 'C', 'D', 'E'])
3978+
self._check_colors(ax.get_lines(), linecolors=default_colors[:5])
3979+
3980+
axes = _check_plot_works(grouped.plot, subplots=True, axis=1)
3981+
self._check_axes_shape(axes, axes_num=5, layout=(5, 1))
3982+
for ax in axes:
3983+
self._check_legend_labels(ax, labels=['Group 0', 'Group 1', 'Group 2'])
3984+
self._check_colors(ax.get_lines(), linecolors=default_colors[:3])
3985+
3986+
def test_hist_groupby(self):
3987+
df = DataFrame(np.random.rand(30, 5), columns=['A', 'B', 'C', 'D', 'E'])
3988+
df['by'] = ['Group {0}'.format(i) for i in [0]*10 + [1]*10 + [2]*10]
3989+
grouped = df.groupby(by='by')
3990+
3991+
# SeriesGroupBy
3992+
sgb = grouped['A']
3993+
ax = sgb.plot(kind='hist', color=['r', 'g', 'b'])
3994+
self._check_legend_labels(ax, labels=['Group 0', 'Group 1', 'Group 2'])
3995+
self._check_colors(ax.patches[::10], facecolors=['r', 'g', 'b'])
3996+
3997+
import matplotlib.pyplot as plt
3998+
default_colors = plt.rcParams.get('axes.color_cycle')
3999+
axes = sgb.plot(kind='hist', subplots=True)
4000+
self._check_axes_shape(axes, axes_num=3, layout=(3, 1))
4001+
self._check_legend_labels(axes[0], labels=['Group 0'])
4002+
self._check_legend_labels(axes[1], labels=['Group 1'])
4003+
self._check_legend_labels(axes[2], labels=['Group 2'])
4004+
self._check_colors([axes[0].patches[0]], facecolors=default_colors[0])
4005+
self._check_colors([axes[1].patches[0]], facecolors=default_colors[1])
4006+
self._check_colors([axes[2].patches[0]], facecolors=default_colors[2])
4007+
4008+
# DataFrameGroupBy
4009+
axes = grouped.plot(kind='hist', subplots=True)
4010+
self._check_axes_shape(axes, axes_num=3, layout=(3, 1))
4011+
for ax in axes:
4012+
self._check_legend_labels(ax, labels=['A', 'B', 'C', 'D', 'E'])
4013+
self._check_colors(axes[0].patches[::10], facecolors=default_colors[:5])
4014+
self._check_colors(axes[1].patches[::10], facecolors=default_colors[:5])
4015+
self._check_colors(axes[2].patches[::10], facecolors=default_colors[:5])
4016+
4017+
axes = grouped.plot(kind='hist', subplots=True, axis=1)
4018+
self._check_axes_shape(axes, axes_num=5, layout=(5, 1))
4019+
for ax in axes:
4020+
self._check_legend_labels(ax, labels=['Group 0', 'Group 1', 'Group 2'])
4021+
self._check_colors(ax.patches[::10], facecolors=['b', 'g', 'r'])
4022+
4023+
def test_scatter_groupby(self):
4024+
df = DataFrame(np.random.rand(30, 5), columns=['A', 'B', 'C', 'D', 'E'])
4025+
df['by'] = ['Group {0}'.format(i) for i in [0]*10 + [1]*10 + [2]*10]
4026+
grouped = df.groupby(by='by')
4027+
4028+
ax = _check_plot_works(grouped.plot, kind='scatter', x='A', y='B', subplots=False)
4029+
self._check_legend_labels(ax, labels=['Group 0', 'Group 1', 'Group 2'])
4030+
4031+
axes = _check_plot_works(grouped.plot, kind='scatter', x='A', y='B', subplots=True)
4032+
self._check_axes_shape(axes, axes_num=3, layout=(1, 3))
4033+
self._check_legend_labels(axes[0], labels=['Group 0'])
4034+
self._check_legend_labels(axes[1], labels=['Group 1'])
4035+
self._check_legend_labels(axes[2], labels=['Group 2'])
4036+
4037+
def test_hexbin_groupby(self):
4038+
df = DataFrame(np.random.rand(30, 5), columns=['A', 'B', 'C', 'D', 'E'])
4039+
df['by'] = ['Group {0}'.format(i) for i in [0]*10 + [1]*10 + [2]*10]
4040+
grouped = df.groupby(by='by')
4041+
4042+
msg = "To plot DataFrameGroupBy, specify 'suplots=True'"
4043+
with tm.assertRaisesRegexp(ValueError, msg):
4044+
grouped.plot(kind='hexbin', x='A', y='B', subplots=False)
4045+
4046+
axes = _check_plot_works(grouped.plot, kind='hexbin', x='A', y='B', subplots=True)
4047+
self._check_axes_shape(axes, axes_num=3, layout=(1, 3))
4048+
39564049

39574050
def assert_is_valid_plot_return_object(objs):
39584051
import matplotlib.pyplot as plt

0 commit comments

Comments
 (0)