Open
Description
I made this larray-only solution for fun, based on the code I wrote for amg, which used larray -> pandas dataframes -> pysankey2.
There are a few TODOs left before it can be included in larray:
- fix problems when labels are not given in the same order in src vs dst
- check if the API is general enough (for example, holoviews.sankey uses 3 columns: from, to, weight and automatically places nodes horizontally (probably does a topological sort))
- document plot_curved_band (just description + arguments)
- document plot_boxes (just description + arguments)
- document sankey (including examples)
- test with an array with different axes names (it should work but I never tested)
- order labels from top to bottom
- optionally display "step" labels ("step1", "step2" in the example below)
- add support for an `ax` argument
- actually integrate it in larray
- add examples in the tutorial
Wish list:
- support to skip an intermediate step entirely
- support for different weights on left vs right
- allow giving explicit label order (in multistep, a user could want to order labels differently for each step)
- try to reproduce pysankey, alluvial and ggalluvial graphs and holoviews sankey points (needs the preceding wish list points, I think)
# Inspired by https://github.com/vgalisson/pySankey and its parent repositories
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import larray as la
def curve(start, stop, nsteps: int = 100, slope_percent: float = 100.0):
    """
    Return a smooth S-shaped transition from start to stop.

    A step function (first half at `start`, second half at `stop`) is smoothed
    by convolving it twice with a uniform (moving average) kernel.

    Parameters
    ----------
    start: float
        starting value
    stop: float
        stop value
    nsteps: int, optional
        Number of points in the result. Must be >= 2. Defaults to 100.
    slope_percent: int|float, optional
        Percentage of steps used to go from start to stop. Other steps will have a
        constant value. Must be >= 0. When it is too small to yield a smoothing
        kernel of at least one point, the result is an unsmoothed step.
        Defaults to 100.0

    Returns
    -------
    numpy.ndarray
        1D array of length `nsteps`.

    Raises
    ------
    ValueError
        If `nsteps` < 2 or `slope_percent` < 0.
    """
    if nsteps < 2:
        raise ValueError("steps must be >= 2")
    if slope_percent < 0:
        raise ValueError("slope_percent must be >= 0")
    n_convolve = 2
    # The kernel must hold at least one point: a size of 0 would make the
    # kernel weight (1 / kernel_size) a division by zero.
    kernel_size = max(1, int(nsteps * slope_percent / 100 / 2))
    # increase array size because each convolve "strips" kernel_size - 1 points
    size = nsteps + (kernel_size - 1) * n_convolve
    arr = np.empty(size)
    arr[:size // 2] = start
    if size % 2 == 0:
        arr[size // 2:] = stop
    else:
        # odd size: the middle point sits halfway between start and stop
        arr[size // 2:] = (start + stop) / 2
        arr[(size // 2) + 1:] = stop
    w = np.full(kernel_size, 1 / kernel_size)
    for _ in range(n_convolve):
        arr = np.convolve(arr, w, mode='valid')
    return arr
def plot_curved_band(ax, x, left_bottom, left_top, right_bottom, right_top,
                     nsteps=100, slope_percent=100, **kwargs):
    """
    Fill the area between two smooth curves on `ax`.

    The lower edge goes from `left_bottom` to `right_bottom` and the upper edge
    from `left_top` to `right_top`; both edges are generated by `curve`.

    Parameters
    ----------
    ax: matplotlib axes
        Axes to draw on (its `fill_between` method is used).
    x: array-like
        x coordinates of the band. It should hold `nsteps` values so that it
        lines up with the generated edges.
    left_bottom, left_top: float
        y values of the lower and upper edges at the left end of the band.
    right_bottom, right_top: float
        y values of the lower and upper edges at the right end of the band.
    nsteps: int, optional
        Number of points of each edge. Defaults to 100.
    slope_percent: int|float, optional
        Forwarded to `curve`. Defaults to 100.
    **kwargs
        Extra keyword arguments forwarded to `ax.fill_between`.
    """
    curve_options = {'nsteps': nsteps, 'slope_percent': slope_percent}
    lower_edge = curve(left_bottom, right_bottom, **curve_options)
    upper_edge = curve(left_top, right_top, **curve_options)
    ax.fill_between(x=x, y1=lower_edge, y2=upper_edge, **kwargs)
def plot_boxes(ax, box_width, boxsep, colors, left, right,
               dist_to_box_bottom, dist_to_box_left, weights, box_kws, text_kws):
    """
    Draw one column of stacked, labelled boxes on `ax`.

    One box is drawn per label with a strictly positive weight. Boxes are
    stacked along y, each one starting `boxsep` after the end of the previous
    one, and the height of a box is the weight of its label.

    Parameters
    ----------
    ax: matplotlib axes
        Axes to draw on.
    box_width: float
        Width used to scale the horizontal label offset.
    boxsep: float
        Vertical gap between two consecutive boxes.
    colors: mapping or array indexed by label
        Face color of each box.
    left, right: float
        x coordinates of the left and right edges of the boxes.
    dist_to_box_bottom: float
        Vertical position of the label text, as a fraction of the box height
        measured from the box bottom.
    dist_to_box_left: float
        Horizontal position of the label text, as a multiple of `box_width`
        measured from `left`.
    weights: 1D array
        Weight of each label (must have exactly one axis).
    box_kws: dict
        Extra keyword arguments for the `ax.fill_between` call of each box.
    text_kws: dict
        Extra keyword arguments for the `ax.text` call of each label.
    """
    assert weights.ndim == 1
    positive_weights = weights[weights > 0]
    label_axis = positive_weights.axes[0]
    # cumulative tops (gap included) shifted by one label give the bottom of each box
    cumulative_tops = (positive_weights + boxsep).cumsum(label_axis)
    padded_tops = cumulative_tops.prepend(label_axis, 0, label="dummy")
    box_bottoms = padded_tops.shift(label_axis)
    text_x = left + box_width * dist_to_box_left
    for box_label in label_axis:
        face_color = colors[box_label]
        y_bottom = box_bottoms[box_label]
        height = positive_weights[box_label]
        y_top = y_bottom + height
        # boxes
        ax.fill_between(x=[left, right], y1=y_bottom, y2=y_top, facecolor=face_color, **box_kws)
        # box labels
        text_y = y_bottom + height * dist_to_box_bottom
        ax.text(text_x, text_y, box_label.eval(), {'ha': 'right', 'va': 'center'}, **text_kws)
def sankey(weights, box_width=2, strip_width=10, nsteps=20, boxsep=0.1, text_kws=None,
           dist_to_box_left=-0.15, dist_to_box_bottom=0.5, colors=None,
           cmap='tab10', band_kws=None, box_kws=None, figsize=(10, 10), src_axis=None, dst_axis=None, step_axis=None,
           strip_shrink=0.06, ax=None):
    """
    Draw a (possibly multi-step) Sankey diagram from an array of weights.

    For each step, a column of boxes is drawn for the source labels and curved
    strips link each source box to the destination boxes of the next column.
    A final column of boxes is drawn for the destinations of the last step.

    Parameters
    ----------
    weights: larray Array
        Flow weights with a source axis, a destination axis and, optionally, a
        step axis. When `step_axis` is not given and the array does not have 3
        dimensions, a single step labelled "step1" is added.
    box_width: float, optional
        Width of the boxes. Defaults to 2.
    strip_width: float, optional
        Horizontal distance between two columns of boxes (length of the strips).
        Defaults to 10.
    nsteps: int, optional
        Number of points used to draw each strip. Defaults to 20.
    boxsep: float, optional
        Vertical gap between two boxes of the same column. Defaults to 0.1.
    text_kws: dict, optional
        Extra keyword arguments for the box labels. 'fontsize' defaults to 18.
        The given dict is not modified.
    dist_to_box_left: float, optional
        Horizontal position of the labels, as a multiple of `box_width` measured
        from the left edge of the box. Defaults to -0.15.
    dist_to_box_bottom: float, optional
        Vertical position of the labels, as a fraction of the box height
        measured from the box bottom. Defaults to 0.5.
    colors: mapping or array indexed by label, optional
        Color of each label. Defaults to colors taken from `cmap`, in the order
        of the union of the source and destination labels.
    cmap: str or matplotlib Colormap, optional
        Colormap used when `colors` is not given. Defaults to 'tab10'.
    band_kws: dict, optional
        Extra keyword arguments for the strips. 'alpha' defaults to 0.4.
        The given dict is not modified.
    box_kws: dict, optional
        Extra keyword arguments for the boxes. 'alpha' defaults to 0.8.
        The given dict is not modified.
    figsize: tuple, optional
        Size of the figure created when `ax` is not given. Defaults to (10, 10).
    src_axis, dst_axis, step_axis: axis reference, optional
        Axes of `weights` to use as source, destination and step axes (anything
        accepted by `weights.axes[...]`). By default, the step axis is the first
        axis of a 3D array, the source axis is the first non-step axis and the
        destination axis is the first remaining axis.
    strip_shrink: float, optional
        Total height removed from the strips attached to one box, so that they
        do not cover the box entirely. Defaults to 0.06.
    ax: matplotlib axes, optional
        Axes to draw on. Defaults to None (create a new figure of size `figsize`).

    Raises
    ------
    TypeError
        If `colors` is not given and `cmap` is neither a str nor a Colormap.
    """
    if step_axis is None:
        if weights.ndim == 3:
            step_axis = weights.axes[0]
        else:
            step_axis = la.Axis("step=step1")
            weights = weights.expand(step_axis)
    else:
        step_axis = weights.axes[step_axis]
    if src_axis is None:
        src_axis = (weights.axes - step_axis)[0]
    else:
        src_axis = weights.axes[src_axis]
    if dst_axis is None:
        dst_axis = (weights.axes - step_axis - src_axis)[0]
    else:
        dst_axis = weights.axes[dst_axis]

    # build new dicts so that the defaults never leak into the caller's dicts
    band_kws = {'alpha': 0.4, **(band_kws or {})}
    box_kws = {'alpha': 0.8, **(box_kws or {})}
    text_kws = {'fontsize': 18, **(text_kws or {})}

    if colors is None:
        all_labels = src_axis.union(dst_axis).labels
        if isinstance(cmap, str):
            # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is still available
            cmap = plt.get_cmap(cmap)
        if not isinstance(cmap, matplotlib.colors.Colormap):
            raise TypeError(f"cmap must be a str or a matplotlib Colormap, not {type(cmap).__name__}")
        colors = la.stack({label: cmap(i) for i, label in enumerate(all_labels)}, 'label')

    # one column of boxes per step + a final one for the destinations of the last step
    box = la.Axis(len(step_axis) + 1, "box")
    box_left = la.sequence(box, inc=box_width + strip_width)
    box_right = box_left + box_width
    left_weight = weights.sum(dst_axis).rename(src_axis, 'label')
    right_weight = weights.sum(src_axis).rename(dst_axis, 'label')
    if ax is None:
        fig = plt.figure(figsize=figsize)
        ax = fig.subplots(1)
    # strips go from the right edge of a column to the left edge of the next one
    strip_left = box_right
    strip_right = box_left.shift(box_left.axes[0], n=-1)
    for i, step in enumerate(step_axis):
        plot_boxes(ax, box_width, boxsep, colors, box_left.i[i], box_right.i[i],
                   dist_to_box_bottom, dist_to_box_left, left_weight[step], box_kws, text_kws)
        # bottom of each source box (same computation as in plot_boxes)
        step_left_weight = left_weight[step]
        step_left_weight = step_left_weight[step_left_weight > 0]
        step_box_left_bottom = (step_left_weight + boxsep).cumsum('label') \
                                                          .prepend("label", 0, label="dummy") \
                                                          .shift("label")
        x = np.linspace(strip_left.i[i], strip_right.i[i], nsteps)
        # bottom of each destination box
        step_right_weight = right_weight[step]
        step_right_weight = step_right_weight[step_right_weight > 0]
        step_box_right_bottom = (step_right_weight + boxsep).cumsum('label') \
                                                            .prepend("label", 0, label="dummy") \
                                                            .shift("label")
        # strips
        # y position where the next strip arriving at each destination box starts
        strip_right_y0_per_label = step_box_right_bottom + strip_shrink / 2
        for src_label in src_axis:
            weights_for_source = weights[step, src_label]
            num_strips_for_source = (weights_for_source > 0).sum()
            if num_strips_for_source == 0:
                # No strip leaves this source at this step. Labels without a
                # positive weight are filtered out of step_box_left_bottom,
                # so there is no box bottom to look up for it.
                continue
            color = colors[src_label]
            src_label_box_bottom = step_box_left_bottom[src_label]
            strip_left_y0 = src_label_box_bottom + strip_shrink / 2
            for dst_label in dst_axis:
                weight = weights_for_source[dst_label]
                if weight > 0:
                    weights_for_dest = weights[step, dst_label]
                    num_strips_for_dest = (weights_for_dest > 0).sum()
                    # the shrink is shared equally by all strips attached to a box
                    strip_left_y1 = strip_left_y0 + weight - strip_shrink / num_strips_for_source
                    strip_right_y0 = strip_right_y0_per_label[dst_label]
                    strip_right_y1 = strip_right_y0 + weight - strip_shrink / num_strips_for_dest
                    plot_curved_band(ax, x,
                                     strip_left_y0, strip_left_y1,
                                     strip_right_y0, strip_right_y1,
                                     color=color, nsteps=nsteps, **band_kws)
                    # the next strip starts where this one ends, on both sides
                    strip_left_y0 += weight - strip_shrink / num_strips_for_source
                    strip_right_y0_per_label[dst_label] += weight - strip_shrink / num_strips_for_dest
    # boxes
    step = step_axis.i[-1]
    plot_boxes(ax, box_width, boxsep, colors, box_left.i[-1], box_right.i[-1],
               dist_to_box_bottom, dist_to_box_left, right_weight[step], box_kws, text_kws)
    ax.axis('off')
# Demo data: weights indexed by step, source label (rows) and destination
# label (columns). This variant is reported to work.
# works
weight_arr = la.from_string(r"""
step src\dst a b c d
step1 a 1 1 0 0
step1 b 1 0 1 0
step1 c 1 1 0 0
step2 a 0 1 1 1
step2 b 1 0 0 1
step2 c 1 0 0 0
""")
# Known failing variant: same values, but the destination labels are listed as
# "a c b d" instead of "a b c d" (presumably the src vs dst label order problem).
# fails (bad labels)
# weight_arr = la.from_string(r"""
# step src\dst a c b d
# step1 a 1 1 0 0
# step1 b 1 0 1 0
# step1 c 1 1 0 0
# step2 a 0 1 1 1
# step2 b 1 0 0 1
# step2 c 1 0 0 0
# """)
# Known failing variant: source labels (b, c, e) and destination labels
# (c, e, b, d) differ and are not in the same order.
# fails (everything is odd)
# weight_arr = la.from_string(r"""
# step src\dst c e b d
# step1 b 1 1 0 0
# step1 c 0 1 1 0
# step1 e 0 1 1 0
# step2 b 0 1 0 1
# step2 c 0 1 0 0
# step2 e 1 0 1 1
# """)
# NOTE(review): `weights` is not defined at module level in the visible code;
# presumably `weight_arr` was meant -- confirm before uncommenting.
# sankey(weights)
# Plot the demo data with custom strip styling (dashed outline, hatching and
# black edges), then display the figure.
# lets go funky
sankey(weight_arr, band_kws={'linestyles': 'dashed', 'hatch': '/', 'edgecolor': 'black'})
plt.show()