Skip to content

Commit 2533a61

Browse files
authored
Merge pull request #210 from scottshambaugh/xvals_datetime
Fix xvals for datetime axis
2 parents 2b14be6 + 04b81b0 commit 2533a61

File tree

3 files changed

+20
-1
lines changed

3 files changed

+20
-1
lines changed

Diff for: labellines/baseline/test_dateaxis_advanced.png

201 Bytes
Loading

Diff for: labellines/core.py

+19
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from datetime import timedelta
44
import matplotlib.pyplot as plt
55
import numpy as np
6+
from datetime import datetime
67
from matplotlib.container import ErrorbarContainer
78
from matplotlib.dates import (
89
_SwitchableDateConverter,
@@ -204,18 +205,34 @@ def labelLines(
204205
if isinstance(xvals, tuple) and len(xvals) == 2:
205206
xmin, xmax = xvals
206207
xscale = ax.get_xscale()
208+
209+
# Convert datetime objects to numeric values for linspace/geomspace
210+
x_is_datetime = isinstance(xmin, datetime) or isinstance(xmax, datetime)
211+
if x_is_datetime:
212+
if not isinstance(xmin, datetime) or not isinstance(xmax, datetime):
213+
raise ValueError(
214+
f"Cannot mix datetime and numeric values in xvals: {xvals}"
215+
)
216+
xmin = plt.matplotlib.dates.date2num(xmin)
217+
xmax = plt.matplotlib.dates.date2num(xmax)
218+
207219
if xscale == "log":
208220
xvals = np.geomspace(xmin, xmax, len(all_lines) + 2)[1:-1]
209221
else:
210222
xvals = np.linspace(xmin, xmax, len(all_lines) + 2)[1:-1]
211223

224+
# Convert numeric values back to datetime objects
225+
if x_is_datetime:
226+
xvals = plt.matplotlib.dates.num2date(xvals)
227+
212228
# Build matrix line -> xvalue
213229
ok_matrix = np.zeros((len(all_lines), len(all_lines)), dtype=bool)
214230

215231
for i, line in enumerate(all_lines):
216232
xdata, _ = normalize_xydata(line)
217233
minx, maxx = min(xdata), max(xdata)
218234
for j, xv in enumerate(xvals): # type: ignore
235+
xv = line.convert_xunits(xv)
219236
ok_matrix[i, j] = minx < xv < maxx
220237

221238
# If some xvals do not fall in their corresponding line,
@@ -242,6 +259,8 @@ def labelLines(
242259
# Move xlabel if it is outside valid range
243260
xdata, _ = normalize_xydata(line)
244261
xmin, xmax = min(xdata), max(xdata)
262+
xv = line.convert_xunits(xv)
263+
245264
if not (xmin <= xv <= xmax):
246265
warnings.warn(
247266
(

Diff for: labellines/test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def test_dateaxis_advanced(setup_mpl):
164164
ax.xaxis.set_major_locator(DayLocator())
165165
ax.xaxis.set_major_formatter(DateFormatter("%Y-%m-%d"))
166166

167-
labelLines(ax.get_lines())
167+
labelLines(ax.get_lines(), xvals=(dates[0], dates[-1]))
168168
return plt.gcf()
169169

170170

0 commit comments

Comments
 (0)