Skip to content

Commit f7b72e1

Browse files
committed
ENH: Groupby.plot enhancement
1 parent f9f88b2 commit f7b72e1

File tree

4 files changed

+665
-259
lines changed

4 files changed

+665
-259
lines changed

pandas/core/groupby.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -3273,9 +3273,13 @@ def _apply_to_column_groupbys(self, func):
32733273
in self._iterate_column_groupbys()),
32743274
keys=self._selected_obj.columns, axis=1)
32753275

3276-
from pandas.tools.plotting import boxplot_frame_groupby
3276+
3277+
from pandas.tools.plotting import boxplot_frame_groupby, plot_grouped
32773278
DataFrameGroupBy.boxplot = boxplot_frame_groupby
32783279

3280+
SeriesGroupBy.plot = plot_grouped
3281+
DataFrameGroupBy.plot = plot_grouped
3282+
32793283

32803284
class PanelGroupBy(NDFrameGroupBy):
32813285

pandas/tests/test_graphics.py

+94-4
Original file line numberDiff line numberDiff line change
@@ -1635,6 +1635,7 @@ def test_line_lim(self):
16351635
self.assertFalse(hasattr(ax, 'right_ax'))
16361636
xmin, xmax = ax.get_xlim()
16371637
lines = ax.get_lines()
1638+
self.assertTrue(hasattr(ax, 'left_ax'))
16381639
self.assertEqual(xmin, lines[0].get_data()[0][0])
16391640
self.assertEqual(xmax, lines[0].get_data()[0][-1])
16401641

@@ -3072,7 +3073,6 @@ def test_hexbin_cmap(self):
30723073
@slow
30733074
def test_no_color_bar(self):
30743075
df = self.hexbin_df
3075-
30763076
ax = df.plot(kind='hexbin', x='A', y='B', colorbar=None)
30773077
self.assertIs(ax.collections[0].colorbar, None)
30783078

@@ -3435,9 +3435,8 @@ def test_grouped_plot_fignums(self):
34353435
df = DataFrame({'height': height, 'weight': weight, 'gender': gender})
34363436
gb = df.groupby('gender')
34373437

3438-
res = gb.plot()
3439-
self.assertEqual(len(self.plt.get_fignums()), 2)
3440-
self.assertEqual(len(res), 2)
3438+
with tm.assertRaisesRegexp(ValueError, "To plot DataFrameGroupBy, specify 'suplots=True'"):
3439+
res = gb.plot()
34413440
tm.close()
34423441

34433442
res = gb.boxplot(return_type='axes')
@@ -3816,6 +3815,97 @@ def test_plotting_with_float_index_works(self):
38163815
df.groupby('def')['val'].apply(lambda x: x.plot())
38173816
tm.close()
38183817

3818+
def test_line_groupby(self):
    """Check line plots drawn through ``SeriesGroupBy.plot`` and
    ``DataFrameGroupBy.plot``.

    A 30x5 random frame is split into three equally sized groups and the
    resulting axes are checked for shape, legend labels and line colors.
    """
    column_names = ['A', 'B', 'C', 'D', 'E']
    group_names = ['Group 0', 'Group 1', 'Group 2']

    frame = DataFrame(np.random.rand(30, 5), columns=column_names)
    # Ten consecutive rows per group: 'Group 0', then 'Group 1', 'Group 2'.
    frame['by'] = ['Group {0}'.format(num)
                   for num in range(3) for _ in range(10)]
    by_group = frame.groupby(by='by')

    # --- SeriesGroupBy: all groups on one axes, explicit colors ---
    series_groups = by_group['A']
    single_ax = _check_plot_works(series_groups.plot, colors=list('rgb'))
    self._check_legend_labels(single_ax, labels=group_names)
    self._check_colors(single_ax.get_lines(), linecolors=list('rgb'))

    # --- SeriesGroupBy: one subplot per group ---
    per_group_axes = _check_plot_works(series_groups.plot, subplots=True)
    self._check_axes_shape(per_group_axes, axes_num=3, layout=(3, 1))
    for position, name in enumerate(group_names):
        self._check_legend_labels(per_group_axes[position], labels=[name])

    # --- DataFrameGroupBy: one subplot per group, one line per column ---
    per_group_axes = _check_plot_works(by_group.plot, subplots=True)
    self._check_axes_shape(per_group_axes, axes_num=3, layout=(3, 1))
    for group_ax in per_group_axes:
        self._check_legend_labels(group_ax, labels=column_names)
        self._check_colors(group_ax.get_lines(), linecolors=list('kkkkk'))

    # --- DataFrameGroupBy, axis=1: one subplot per column, one line per
    # group ---
    per_column_axes = _check_plot_works(by_group.plot, subplots=True,
                                        axis=1)
    self._check_axes_shape(per_column_axes, axes_num=5, layout=(5, 1))
    for column_ax in per_column_axes:
        self._check_legend_labels(column_ax, labels=group_names)
        self._check_colors(column_ax.get_lines(), linecolors=list('kkk'))
3847+
3848+
def test_hist_groupby(self):
    """Check histograms drawn through ``SeriesGroupBy.plot`` and
    ``DataFrameGroupBy.plot`` with ``kind='hist'``.

    A 30x5 random frame is split into three equally sized groups and the
    resulting axes are checked for shape, legend labels and face colors.
    """
    column_names = ['A', 'B', 'C', 'D', 'E']
    group_names = ['Group 0', 'Group 1', 'Group 2']

    frame = DataFrame(np.random.rand(30, 5), columns=column_names)
    # Ten consecutive rows per group: 'Group 0', then 'Group 1', 'Group 2'.
    frame['by'] = ['Group {0}'.format(num)
                   for num in range(3) for _ in range(10)]
    by_group = frame.groupby(by='by')

    # --- SeriesGroupBy: all groups on one axes, explicit colors ---
    series_groups = by_group['A']
    single_ax = series_groups.plot(kind='hist', colors=list('rgb'))
    self._check_legend_labels(single_ax, labels=group_names)
    # Every tenth patch is sampled to get one patch per group.
    self._check_colors(single_ax.patches[::10], facecolors=list('rgb'))

    # --- SeriesGroupBy: one subplot per group ---
    per_group_axes = series_groups.plot(kind='hist', subplots=True)
    self._check_axes_shape(per_group_axes, axes_num=3, layout=(3, 1))
    for position, name in enumerate(group_names):
        self._check_legend_labels(per_group_axes[position], labels=[name])
    for position in range(3):
        first_patch = per_group_axes[position].patches[0]
        self._check_colors([first_patch], facecolors=['b'])

    # --- DataFrameGroupBy: one subplot per group ---
    per_group_axes = by_group.plot(kind='hist', subplots=True)
    self._check_axes_shape(per_group_axes, axes_num=3, layout=(3, 1))
    for group_ax in per_group_axes:
        self._check_legend_labels(group_ax, labels=column_names)
    # NOTE: face colors are not asserted for this case.

    # --- DataFrameGroupBy, axis=1: one subplot per column ---
    per_column_axes = by_group.plot(kind='hist', subplots=True, axis=1)
    self._check_axes_shape(per_column_axes, axes_num=5, layout=(5, 1))
    for column_ax in per_column_axes:
        self._check_legend_labels(column_ax, labels=group_names)
        self._check_colors(column_ax.patches[::10],
                           facecolors=list('bgr'))
3882+
3883+
def test_scatter_groupby(self):
    """Check scatter plots drawn through ``DataFrameGroupBy.plot``.

    Covers both the shared-axes case (``subplots=False``) and the
    one-subplot-per-group case (``subplots=True``).
    """
    group_names = ['Group 0', 'Group 1', 'Group 2']

    frame = DataFrame(np.random.rand(30, 5),
                      columns=['A', 'B', 'C', 'D', 'E'])
    # Ten consecutive rows per group: 'Group 0', then 'Group 1', 'Group 2'.
    frame['by'] = ['Group {0}'.format(num)
                   for num in range(3) for _ in range(10)]
    by_group = frame.groupby(by='by')

    # All groups drawn on a single axes.
    shared_ax = _check_plot_works(by_group.plot, kind='scatter',
                                  x='A', y='B', subplots=False)
    self._check_legend_labels(shared_ax, labels=group_names)

    # One subplot per group, laid out in a single row.
    per_group_axes = _check_plot_works(by_group.plot, kind='scatter',
                                       x='A', y='B', subplots=True)
    self._check_axes_shape(per_group_axes, axes_num=3, layout=(1, 3))
    for position, name in enumerate(group_names):
        self._check_legend_labels(per_group_axes[position], labels=[name])
3896+
3897+
def test_hexbin_groupby(self):
    """Check hexbin plots drawn through ``DataFrameGroupBy.plot``.

    ``subplots=False`` is expected to raise ``ValueError``;
    ``subplots=True`` is expected to give one subplot per group.
    """
    frame = DataFrame(np.random.rand(30, 5),
                      columns=['A', 'B', 'C', 'D', 'E'])
    # Ten consecutive rows per group: 'Group 0', then 'Group 1', 'Group 2'.
    frame['by'] = ['Group {0}'.format(num)
                   for num in range(3) for _ in range(10)]
    by_group = frame.groupby(by='by')

    # NOTE(review): 'suplots' looks like a typo for 'subplots'. The pattern
    # has to match the message raised by the plotting code, so confirm that
    # message before changing the spelling here.
    with tm.assertRaisesRegexp(
            ValueError, "To plot DataFrameGroupBy, specify 'suplots=True'"):
        by_group.plot(kind='hexbin', x='A', y='B', subplots=False)

    # One subplot per group, laid out in a single row.
    per_group_axes = _check_plot_works(by_group.plot, kind='hexbin',
                                       x='A', y='B', subplots=True)
    self._check_axes_shape(per_group_axes, axes_num=3, layout=(1, 3))
3908+
38193909

38203910
def assert_is_valid_plot_return_object(objs):
38213911
import matplotlib.pyplot as plt

0 commit comments

Comments
 (0)