Skip to content

Commit 558fee9

Browse files
author
Niru Nahesh
committed
Preserve lazy color iteration in ridgeline
1 parent fa8ba84 commit 558fee9

2 files changed

Lines changed: 24 additions & 11 deletions

File tree

src/jetplot/plots.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -338,31 +338,34 @@ def ridgeline(
338338
339339
Args:
340340
t: Grid used when evaluating the kernel density estimate.
341-
xs: Iterable of 1-D samples. Accepts generators but consumes them eagerly.
342-
colors: Iterable of colors, one for each series in ``xs``.
341+
xs: Iterable of 1-D samples. Accepts generators and consumes them once.
342+
colors: Iterable of colors. Must provide at least as many entries as ``xs``.
343343
edgecolor: Line color used for the outline.
344344
ymax: Upper y-limit for each subplot.
345345
346346
Raises:
347-
ValueError: If ``xs`` is empty or the number of colors does not match.
347+
ValueError: If ``xs`` is empty or ``colors`` provides too few values.
348348
"""
349349
fig = kwargs["fig"]
350350
xs_list = list(xs)
351-
color_list = list(colors)
351+
colors_iter = iter(colors)
352352

353353
if not xs_list:
354354
raise ValueError("xs must contain at least one series.")
355-
if len(xs_list) != len(color_list):
356-
raise ValueError("xs and colors must have the same length.")
357355

358356
axs = []
359357

360-
for k, (x, c) in enumerate(zip(xs_list, color_list, strict=True)):
358+
for k, x in enumerate(xs_list):
359+
try:
360+
palette_color = next(colors_iter)
361+
except StopIteration as exc:
362+
raise ValueError("colors must provide at least as many items as xs.") from exc
363+
361364
ax = fig.add_subplot(len(xs_list), 1, k + 1)
362365
y = gaussian_kde(x).evaluate(t)
363-
ax.fill_between(t, y, color=c, clip_on=False)
366+
ax.fill_between(t, y, color=palette_color, clip_on=False)
364367
ax.plot(t, y, color=edgecolor, clip_on=False)
365-
ax.axhline(0.0, lw=2, color=c, clip_on=False)
368+
ax.axhline(0.0, lw=2, color=palette_color, clip_on=False)
366369

367370
ax.set_xlim(t[0], t[-1])
368371
ax.set_xticks([])

tests/test_plots.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,15 +110,25 @@ def test_ridgeline_accepts_generators():
110110

111111
def test_ridgeline_mismatched_lengths_raise():
112112
t = np.linspace(-3, 3, 10)
113-
xs = [np.linspace(0, 1, 5)]
114-
colors = (color for color in plots.neutral[:2])
113+
xs = [np.linspace(0, 1, 5), np.linspace(0, 2, 5)]
114+
colors = (color for color in plots.neutral[:1])
115115

116116
with pytest.raises(ValueError):
117117
plots.ridgeline(t, xs=xs, colors=colors)
118118

119119
plt.close("all")
120120

121121

122+
def test_ridgeline_allows_extra_colors():
123+
rng = np.random.default_rng(2)
124+
t = np.linspace(-3, 3, 25)
125+
xs = [rng.standard_normal(100) for _ in range(3)]
126+
127+
fig, axs = plots.ridgeline(t, xs=xs, colors=plots.neutral)
128+
assert len(axs) == 3
129+
plt.close(fig)
130+
131+
122132
def test_ellipse_returns_patch():
123133
rng = np.random.default_rng(1)
124134
x = rng.standard_normal(200)

0 commit comments

Comments
 (0)