Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for labels along lines #131

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 69 additions & 41 deletions labellines/core.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,29 @@
import warnings
from collections.abc import Iterable
from typing import List, Literal, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.container import ErrorbarContainer
from matplotlib.dates import DateConverter, num2date
from more_itertools import always_iterable

from .line_label import LineLabel
from .line_label import CurvedLineLabel, LineLabel
from .utils import ensure_float, maximum_bipartite_matching


# Label line with line2D label data
def labelLine(
line,
x,
label=None,
align=True,
drop_label=False,
yoffset=0,
yoffset_logspace=False,
outline_color="auto",
outline_width=8,
line: plt.Line2D,
x: float,
curved_text: bool = False,
label: Optional[str] = None,
align: bool = True,
drop_label: bool = False,
yoffset: float = 0,
yoffset_logspace: bool = False,
outline_color: Union[Literal["auto"], None, "str"] = "auto",
outline_width: float = 8,
**kwargs,
):
"""
Expand All @@ -32,6 +35,8 @@ def labelLine(
The line holding the label
x : number
The location in data unit of the label
curved_text : bool, optional
If True, the label will be curved to follow the line.
label : string, optional
The label to set. This is inferred from the line by default
drop_label : bool, optional
Expand All @@ -51,18 +56,32 @@ def labelLine(
Optional arguments passed to ax.text
"""

label = label or line.get_label()

try:
txt = LineLabel(
line,
x,
label=label,
align=align,
yoffset=yoffset,
yoffset_logspace=yoffset_logspace,
outline_color=outline_color,
outline_width=outline_width,
**kwargs,
)
if curved_text:
txt = CurvedLineLabel(
line,
label=label,
axes=line.axes,
yoffset=yoffset,
yoffset_logspace=yoffset_logspace,
outline_color=outline_color,
outline_width=outline_width,
**kwargs,
)
else:
txt = LineLabel(
line,
x,
label=label,
align=align,
yoffset=yoffset,
yoffset_logspace=yoffset_logspace,
outline_color=outline_color,
outline_width=outline_width,
**kwargs,
)
except ValueError as err:
if "does not have a well defined value" in str(err):
warnings.warn(
Expand All @@ -84,14 +103,15 @@ def labelLine(


def labelLines(
lines=None,
align=True,
xvals=None,
drop_label=False,
shrink_factor=0.05,
yoffsets=0,
outline_color="auto",
outline_width=5,
lines: Optional[List[plt.Line2D]] = None,
align: bool = True,
xvals: Union[None, Tuple[float, float], Iterable[float]] = None,
curved_text: bool = False,
drop_label: bool = False,
shrink_factor: float = 0.05,
yoffsets: Union[float, Iterable[float]] = 0,
outline_color: Union[Literal["auto"], None, "str"] = "auto",
outline_width: float = 5,
**kwargs,
):
"""Label all lines with their respective legends.
Expand All @@ -106,6 +126,8 @@ def labelLines(
xvals : (xfirst, xlast) or array of float, optional
The location of the labels. If a tuple, the labels will be
evenly spaced between xfirst and xlast (in the axis units).
curved_text : bool, optional
If True, the labels will be curved to follow the line.
drop_label : bool, optional
If True, the label is consumed by the function so that subsequent
calls to e.g. legend do not use it anymore.
Expand Down Expand Up @@ -157,9 +179,9 @@ def labelLines(
# to generate them.
if xvals is None:
xvals = ax.get_xlim()
xvals_rng = xvals[1] - xvals[0]
xvals_rng = xvals[1] - xvals[0] # type: ignore
shrinkage = xvals_rng * shrink_factor
xvals = (xvals[0] + shrinkage, xvals[1] - shrinkage)
xvals = (xvals[0] + shrinkage, xvals[1] - shrinkage) # type: ignore

if isinstance(xvals, tuple) and len(xvals) == 2:
xmin, xmax = xvals
Expand All @@ -177,7 +199,7 @@ def labelLines(
for i, line in enumerate(all_lines):
xdata = ensure_float(line.get_xdata())
minx, maxx = min(xdata), max(xdata)
for j, xv in enumerate(xvals):
for j, xv in enumerate(xvals): # type: ignore
ok_matrix[i, j] = minx < xv < maxx

# If some xvals do not fall in their corresponding line,
Expand All @@ -189,14 +211,14 @@ def labelLines(
order[order < 0] = np.setdiff1d(np.arange(len(order)), order[order >= 0])

# Now reorder the xvalues
old_xvals = xvals.copy()
xvals[order] = old_xvals
old_xvals = xvals.copy() # type: ignore
xvals[order] = old_xvals # type: ignore
else:
xvals = list(always_iterable(xvals)) # force the creation of a copy

lab_lines, labels = [], []
# Take only the lines which have labels other than the default ones
for i, (line, xv) in enumerate(zip(all_lines, xvals)):
for i, (line, xv) in enumerate(zip(all_lines, xvals)): # type: ignore
label = all_labels[all_lines.index(line)]
lab_lines.append(line)
labels.append(label)
Expand All @@ -215,18 +237,24 @@ def labelLines(
stacklevel=1,
)
new_xv = min(xdata) + (max(xdata) - min(xdata)) * 0.9
xvals[i] = new_xv
xvals[i] = new_xv # type: ignore

# Convert float values back to datetime in case of datetime axis
if isinstance(ax.xaxis.converter, DateConverter):
xvals = [num2date(x).replace(tzinfo=ax.xaxis.get_units()) for x in xvals]
tz = ax.xaxis.get_units()
xvals = [num2date(x).replace(tzinfo=tz) for x in xvals] # type: ignore

txts = []
try:

if not isinstance(yoffsets, Iterable):
yoffsets = [float(yoffsets)] * len(all_lines)
except TypeError:
pass
for line, x, yoffset, label in zip(lab_lines, xvals, yoffsets, labels):

for line, x, yoffset, label in zip(
lab_lines,
xvals, # type: ignore
yoffsets,
labels,
):
txts.append(
labelLine(
line,
Expand Down
Loading