-
-
Notifications
You must be signed in to change notification settings - Fork 18.7k
ENH: speed up wide DataFrame.line plots by using a single LineCollection #61764
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
7e9cbd8
6910da7
8b7b0df
d9ac7a6
a490e24
308f6a6
08c0fa9
1a6f47b
4e26644
3badad1
706fb5e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -99,6 +99,10 @@ | |
Series, | ||
) | ||
|
||
import itertools | ||
|
||
from matplotlib.collections import LineCollection | ||
|
||
|
||
def holds_integer(column: Index) -> bool:
    """
    Return whether ``column`` is inferred to hold integer labels.

    Parameters
    ----------
    column : Index
        Index whose ``inferred_type`` is inspected.

    Returns
    -------
    bool
        True when ``column.inferred_type`` is ``"integer"`` or
        ``"mixed-integer"``, False otherwise.
    """
    return column.inferred_type in {"integer", "mixed-integer"}
|
@@ -1549,66 +1553,115 @@ def __init__(self, data, **kwargs) -> None: | |
self.data = self.data.fillna(value=0) | ||
|
||
def _make_plot(self, fig: Figure) -> None: | ||
""" | ||
Draw a DataFrame line plot. For very wide frames (> 200 columns) that are | ||
*not* time-series and have no stacking or error bars, all columns are | ||
rendered with a single LineCollection for a large speed-up while keeping | ||
public behaviour identical to the original per-column path. | ||
|
||
GH#61764 | ||
""" | ||
# decide once whether we can use the LineCollection fast draw | ||
threshold = 200 | ||
use_collection = ( | ||
not self._is_ts_plot() | ||
and not self.stacked | ||
and not com.any_not_none(*self.errors.values()) | ||
and len(self.data.columns) > threshold | ||
) | ||
|
||
# choose ts-plot helper vs. regular helper | ||
if self._is_ts_plot(): | ||
data = maybe_convert_index(self._get_ax(0), self.data) | ||
|
||
x = data.index # dummy, not used | ||
x = data.index # dummy; _ts_plot ignores it | ||
plotf = self._ts_plot | ||
it = data.items() | ||
else: | ||
x = self._get_xticks() | ||
# error: Incompatible types in assignment (expression has type | ||
# "Callable[[Any, Any, Any, Any, Any, Any, KwArg(Any)], Any]", variable has | ||
# type "Callable[[Any, Any, Any, Any, KwArg(Any)], Any]") | ||
plotf = self._plot # type: ignore[assignment] | ||
# error: Incompatible types in assignment (expression has type | ||
# "Iterator[tuple[Hashable, ndarray[Any, Any]]]", variable has | ||
# type "Iterable[tuple[Hashable, Series]]") | ||
it = self._iter_data(data=self.data) # type: ignore[assignment] | ||
|
||
# shared state | ||
stacking_id = self._get_stacking_id() | ||
is_errorbar = com.any_not_none(*self.errors.values()) | ||
|
||
colors = self._get_colors() | ||
segments: list[np.ndarray] = [] # vertices for LineCollection | ||
|
||
# unified per-column loop | ||
for i, (label, y) in enumerate(it): | ||
ax = self._get_ax(i) | ||
ax = self._get_ax(i if not use_collection else 0) | ||
|
||
kwds = self.kwds.copy() | ||
if self.color is not None: | ||
kwds["color"] = self.color | ||
|
||
style, kwds = self._apply_style_colors( | ||
colors, | ||
kwds, | ||
i, | ||
# error: Argument 4 to "_apply_style_colors" of "MPLPlot" has | ||
# incompatible type "Hashable"; expected "str" | ||
label, # type: ignore[arg-type] | ||
) | ||
kwds.update(self._get_errorbars(label=label, index=i)) | ||
|
||
label_str = self._mark_right_label(pprint_thing(label), index=i) | ||
kwds["label"] = label_str | ||
|
||
if use_collection: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm still not generally fond of having a different code path if some condition is met, especially since the condition requires a magic number. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a reasonable concern. Is there a downside to always using LineCollection? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @mroeschke @jbrockmendel, if we want to completely get rid of the path split and the magic
(No changes for Series plots or other plot types like scatter, area, bar, etc.) ⸻ Advantages: ⸻ Potential Downsides (but manageable): If you're comfortable with this, I can start an implementation. Are there any additional concerns I should keep in mind before coding? Happy to iterate! There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @EvMossan, the unfortunate situation is that there aren't any maintainers with expertise in matplotlib, so the idea of reviewing everything you described is daunting. Is there a minimal version? |
||
# collect vertices; defer drawing | ||
segments.append(np.column_stack((x, y))) | ||
|
||
# tiny proxy only if legend is requested | ||
if self.legend: | ||
proxy = mpl.lines.Line2D( | ||
[], | ||
[], | ||
color=kwds.get("color"), | ||
linewidth=kwds.get( | ||
"linewidth", mpl.rcParams["lines.linewidth"] | ||
), | ||
linestyle=kwds.get("linestyle", "-"), | ||
marker=kwds.get("marker"), | ||
) | ||
self._append_legend_handles_labels(proxy, label_str) | ||
else: | ||
newlines = plotf( | ||
ax, | ||
x, | ||
y, | ||
style=style, | ||
column_num=i, | ||
stacking_id=stacking_id, | ||
is_errorbar=is_errorbar, | ||
**kwds, | ||
) | ||
self._append_legend_handles_labels(newlines[0], label_str) | ||
|
||
errors = self._get_errorbars(label=label, index=i) | ||
kwds = dict(kwds, **errors) | ||
# reset x-limits for true ts plots | ||
if self._is_ts_plot(): | ||
lines = get_all_lines(ax) | ||
left, right = get_xlim(lines) | ||
ax.set_xlim(left, right) | ||
|
||
label = pprint_thing(label) | ||
label = self._mark_right_label(label, index=i) | ||
kwds["label"] = label | ||
|
||
newlines = plotf( | ||
ax, | ||
x, | ||
y, | ||
style=style, | ||
column_num=i, | ||
stacking_id=stacking_id, | ||
is_errorbar=is_errorbar, | ||
**kwds, | ||
# single draw call for fast path | ||
if use_collection and segments: | ||
if self.legend: | ||
lc_colors = [ | ||
cast(mpl.lines.Line2D, h).get_color() # mypy: h is Line2D | ||
for h in self.legend_handles | ||
] | ||
else: | ||
# no legend - repeat default colour cycle | ||
base = mpl.rcParams["axes.prop_cycle"].by_key()["color"] | ||
lc_colors = list(itertools.islice(itertools.cycle(base), len(segments))) | ||
|
||
lc = LineCollection( | ||
segments, | ||
colors=lc_colors, | ||
linewidths=self.kwds.get("linewidth", mpl.rcParams["lines.linewidth"]), | ||
) | ||
self._append_legend_handles_labels(newlines[0], label) | ||
|
||
if self._is_ts_plot(): | ||
# reset of xlim should be used for ts data | ||
# TODO: GH28021, should find a way to change view limit on xaxis | ||
lines = get_all_lines(ax) | ||
left, right = get_xlim(lines) | ||
ax.set_xlim(left, right) | ||
ax0 = self._get_ax(0) | ||
ax0.add_collection(lc) | ||
ax0.margins(0.05) | ||
|
||
# error: Signature of "_plot" incompatible with supertype "MPLPlot" | ||
@classmethod | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
""" | ||
Ensure wide DataFrame.line plots use a single LineCollection | ||
instead of one Line2D per column (GH #61764). | ||
""" | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
import pandas as pd | ||
|
||
# Skip this entire module if matplotlib is not installed | ||
mpl = pytest.importorskip("matplotlib") | ||
plt = pytest.importorskip("matplotlib.pyplot") | ||
from matplotlib.collections import LineCollection | ||
|
||
|
||
def test_linecollection_used_for_wide_dataframe():
    """Check that a 201-column frame is drawn as one LineCollection (GH#61764)."""
    rng = np.random.default_rng(0)
    # 10 rows x 201 columns of cumulative-summed standard-normal draws
    df = pd.DataFrame(rng.standard_normal((10, 201)).cumsum(axis=0))

    ax = df.plot(legend=False)
    try:
        # exactly one LineCollection, and no Line2D artists
        assert sum(isinstance(c, LineCollection) for c in ax.collections) == 1
        assert len(ax.lines) == 0
    finally:
        # close the figure even when an assertion above fails
        plt.close(ax.figure)
Uh oh!
There was an error while loading. Please reload this page.