Skip to content

Commit cbae226

Browse files
committed
ENH: Groupby.plot enhancement
1 parent 22e0e5d commit cbae226

File tree

4 files changed

+676
-338
lines changed

4 files changed

+676
-338
lines changed

pandas/core/groupby.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -3297,9 +3297,13 @@ def _apply_to_column_groupbys(self, func):
32973297
in self._iterate_column_groupbys()),
32983298
keys=self._selected_obj.columns, axis=1)
32993299

3300-
from pandas.tools.plotting import boxplot_frame_groupby
3300+
3301+
from pandas.tools.plotting import boxplot_frame_groupby, plot_grouped
33013302
DataFrameGroupBy.boxplot = boxplot_frame_groupby
33023303

3304+
SeriesGroupBy.plot = plot_grouped
3305+
DataFrameGroupBy.plot = plot_grouped
3306+
33033307

33043308
class PanelGroupBy(NDFrameGroupBy):
33053309

pandas/tests/test_graphics.py

+99-6
Original file line numberDiff line numberDiff line change
@@ -1648,6 +1648,7 @@ def test_line_lim(self):
16481648
self.assertFalse(hasattr(ax, 'right_ax'))
16491649
xmin, xmax = ax.get_xlim()
16501650
lines = ax.get_lines()
1651+
self.assertTrue(hasattr(ax, 'left_ax'))
16511652
self.assertEqual(xmin, lines[0].get_data()[0][0])
16521653
self.assertEqual(xmax, lines[0].get_data()[0][-1])
16531654

@@ -1672,10 +1673,8 @@ def test_area_lim(self):
16721673
@slow
16731674
def test_bar_colors(self):
16741675
import matplotlib.pyplot as plt
1675-
16761676
default_colors = plt.rcParams.get('axes.color_cycle')
16771677

1678-
16791678
df = DataFrame(randn(5, 5))
16801679
ax = df.plot(kind='bar')
16811680
self._check_colors(ax.patches[::5], facecolors=default_colors[:5])
@@ -3186,7 +3185,6 @@ def test_hexbin_cmap(self):
31863185
@slow
31873186
def test_no_color_bar(self):
31883187
df = self.hexbin_df
3189-
31903188
ax = df.plot(kind='hexbin', x='A', y='B', colorbar=None)
31913189
self.assertIs(ax.collections[0].colorbar, None)
31923190

@@ -3549,9 +3547,8 @@ def test_grouped_plot_fignums(self):
35493547
df = DataFrame({'height': height, 'weight': weight, 'gender': gender})
35503548
gb = df.groupby('gender')
35513549

3552-
res = gb.plot()
3553-
self.assertEqual(len(self.plt.get_fignums()), 2)
3554-
self.assertEqual(len(res), 2)
3550+
with tm.assertRaisesRegexp(ValueError, "To plot DataFrameGroupBy, specify 'suplots=True'"):
3551+
res = gb.plot()
35553552
tm.close()
35563553

35573554
res = gb.boxplot(return_type='axes')
@@ -3930,6 +3927,102 @@ def test_plotting_with_float_index_works(self):
39303927
df.groupby('def')['val'].apply(lambda x: x.plot())
39313928
tm.close()
39323929

3930+
def test_line_groupby(self):
3931+
df = DataFrame(np.random.rand(30, 5), columns=['A', 'B', 'C', 'D', 'E'])
3932+
df['by'] = ['Group {0}'.format(i) for i in [0]*10 + [1]*10 + [2]*10]
3933+
grouped = df.groupby(by='by')
3934+
3935+
# SeriesGroupBy
3936+
sgb = grouped['A']
3937+
ax = _check_plot_works(sgb.plot, colors=['r', 'g', 'b'])
3938+
self._check_legend_labels(ax, labels=['Group 0', 'Group 1', 'Group 2'])
3939+
self._check_colors(ax.get_lines(), linecolors=['r', 'g', 'b'])
3940+
3941+
axes = _check_plot_works(sgb.plot, subplots=True)
3942+
self._check_axes_shape(axes, axes_num=3, layout=(3, 1))
3943+
self._check_legend_labels(axes[0], labels=['Group 0'])
3944+
self._check_legend_labels(axes[1], labels=['Group 1'])
3945+
self._check_legend_labels(axes[2], labels=['Group 2'])
3946+
3947+
# DataFrameGroupBy
3948+
import matplotlib.pyplot as plt
3949+
default_colors = plt.rcParams.get('axes.color_cycle')
3950+
3951+
axes = _check_plot_works(grouped.plot, subplots=True)
3952+
self._check_axes_shape(axes, axes_num=3, layout=(3, 1))
3953+
for ax in axes:
3954+
self._check_legend_labels(ax, labels=['A', 'B', 'C', 'D', 'E'])
3955+
self._check_colors(ax.get_lines(), linecolors=default_colors[:5])
3956+
3957+
axes = _check_plot_works(grouped.plot, subplots=True, axis=1)
3958+
self._check_axes_shape(axes, axes_num=5, layout=(5, 1))
3959+
for ax in axes:
3960+
self._check_legend_labels(ax, labels=['Group 0', 'Group 1', 'Group 2'])
3961+
self._check_colors(ax.get_lines(), linecolors=default_colors[:3])
3962+
3963+
def test_hist_groupby(self):
3964+
df = DataFrame(np.random.rand(30, 5), columns=['A', 'B', 'C', 'D', 'E'])
3965+
df['by'] = ['Group {0}'.format(i) for i in [0]*10 + [1]*10 + [2]*10]
3966+
grouped = df.groupby(by='by')
3967+
3968+
# SeriesGroupBy
3969+
sgb = grouped['A']
3970+
ax = sgb.plot(kind='hist', color=['r', 'g', 'b'])
3971+
self._check_legend_labels(ax, labels=['Group 0', 'Group 1', 'Group 2'])
3972+
self._check_colors(ax.patches[::10], facecolors=['r', 'g', 'b'])
3973+
3974+
import matplotlib.pyplot as plt
3975+
default_colors = plt.rcParams.get('axes.color_cycle')
3976+
axes = sgb.plot(kind='hist', subplots=True)
3977+
self._check_axes_shape(axes, axes_num=3, layout=(3, 1))
3978+
self._check_legend_labels(axes[0], labels=['Group 0'])
3979+
self._check_legend_labels(axes[1], labels=['Group 1'])
3980+
self._check_legend_labels(axes[2], labels=['Group 2'])
3981+
self._check_colors([axes[0].patches[0]], facecolors=default_colors[0])
3982+
self._check_colors([axes[1].patches[0]], facecolors=default_colors[1])
3983+
self._check_colors([axes[2].patches[0]], facecolors=default_colors[2])
3984+
3985+
# DataFrameGroupBy
3986+
axes = grouped.plot(kind='hist', subplots=True)
3987+
self._check_axes_shape(axes, axes_num=3, layout=(3, 1))
3988+
for ax in axes:
3989+
self._check_legend_labels(ax, labels=['A', 'B', 'C', 'D', 'E'])
3990+
self._check_colors(axes[0].patches[::10], facecolors=default_colors[:5])
3991+
self._check_colors(axes[1].patches[::10], facecolors=default_colors[:5])
3992+
self._check_colors(axes[2].patches[::10], facecolors=default_colors[:5])
3993+
3994+
axes = grouped.plot(kind='hist', subplots=True, axis=1)
3995+
self._check_axes_shape(axes, axes_num=5, layout=(5, 1))
3996+
for ax in axes:
3997+
self._check_legend_labels(ax, labels=['Group 0', 'Group 1', 'Group 2'])
3998+
self._check_colors(ax.patches[::10], facecolors=['b', 'g', 'r'])
3999+
4000+
def test_scatter_groupby(self):
4001+
df = DataFrame(np.random.rand(30, 5), columns=['A', 'B', 'C', 'D', 'E'])
4002+
df['by'] = ['Group {0}'.format(i) for i in [0]*10 + [1]*10 + [2]*10]
4003+
grouped = df.groupby(by='by')
4004+
4005+
ax = _check_plot_works(grouped.plot, kind='scatter', x='A', y='B', subplots=False)
4006+
self._check_legend_labels(ax, labels=['Group 0', 'Group 1', 'Group 2'])
4007+
4008+
axes = _check_plot_works(grouped.plot, kind='scatter', x='A', y='B', subplots=True)
4009+
self._check_axes_shape(axes, axes_num=3, layout=(1, 3))
4010+
self._check_legend_labels(axes[0], labels=['Group 0'])
4011+
self._check_legend_labels(axes[1], labels=['Group 1'])
4012+
self._check_legend_labels(axes[2], labels=['Group 2'])
4013+
4014+
def test_hexbin_groupby(self):
4015+
df = DataFrame(np.random.rand(30, 5), columns=['A', 'B', 'C', 'D', 'E'])
4016+
df['by'] = ['Group {0}'.format(i) for i in [0]*10 + [1]*10 + [2]*10]
4017+
grouped = df.groupby(by='by')
4018+
4019+
msg = "To plot DataFrameGroupBy, specify 'suplots=True'"
4020+
with tm.assertRaisesRegexp(ValueError, msg):
4021+
grouped.plot(kind='hexbin', x='A', y='B', subplots=False)
4022+
4023+
axes = _check_plot_works(grouped.plot, kind='hexbin', x='A', y='B', subplots=True)
4024+
self._check_axes_shape(axes, axes_num=3, layout=(1, 3))
4025+
39334026

39344027
def assert_is_valid_plot_return_object(objs):
39354028
import matplotlib.pyplot as plt

0 commit comments

Comments
 (0)