larray-project / larray

N-dimensional labelled arrays in Python
https://larray.readthedocs.io/
GNU General Public License v3.0
8 stars 6 forks source link

Include Sankey diagrams as a plotting option #1031

Open gdementen opened 1 year ago

gdementen commented 1 year ago

I made this larray-only solution for fun, based on the code I wrote for amg, which went through larray -> pandas DataFrames -> pySankey2.

There are a few TODOs left before it can be included in larray:

Wish list:


# Inspired by https://github.com/vgalisson/pySankey and its parent repositories

import matplotlib
import matplotlib.pyplot as plt

import numpy as np
import larray as la

def curve(start, stop, nsteps: int = 100, slope_percent: float = 100.0):
    """
    Return ``nsteps`` values going smoothly from ``start`` to ``stop``.

    The values are obtained by smoothing a step function (``start`` on the first
    half, ``stop`` on the second half) with two successive moving averages.

    Parameters
    ----------
    start: float
        starting value
    stop: float
        stop value
    nsteps: int, optional
        Number of steps to go from start to stop. Must be >= 2. Defaults to 100.
    slope_percent: int|float, optional
        Percentage of steps used to go from start to stop. Other steps will have a
        constant value. Must be >= 0. A percentage too small to span more than one
        step gives an abrupt transition. Defaults to 100.0

    Returns
    -------
    numpy.ndarray
        1D float array of length ``nsteps``.

    Raises
    ------
    ValueError
        If ``nsteps < 2`` or ``slope_percent < 0``.
    """
    if nsteps < 2:
        raise ValueError(f"nsteps must be >= 2 (got {nsteps})")
    if slope_percent < 0:
        raise ValueError(f"slope_percent must be >= 0 (got {slope_percent})")
    n_convolve = 2
    # A kernel of size 1 leaves the step untouched (no smoothing). Without the clamp,
    # small slope percentages give a kernel of size 0 and the division below fails.
    kernel_size = max(1, int(nsteps * slope_percent / 100 / 2))
    # increase array size because each 'valid' convolve strips kernel_size - 1 points
    size = nsteps + (kernel_size - 1) * n_convolve

    arr = np.empty(size)
    half = size // 2
    arr[:half] = start
    arr[half:] = stop
    if size % 2 == 1:
        # odd size: put the midpoint value in the central cell so the step is symmetric
        arr[half] = (start + stop) / 2

    kernel = np.full(kernel_size, 1 / kernel_size)
    for _ in range(n_convolve):
        arr = np.convolve(arr, kernel, mode='valid')
    return arr

def plot_curved_band(ax, x, left_bottom, left_top, right_bottom, right_top,
                     nsteps=100, slope_percent=100, **kwargs):
    """
    Fill, on ``ax``, the band between two smooth curves.

    The lower edge goes from ``left_bottom`` to ``right_bottom`` and the upper edge
    from ``left_top`` to ``right_top``; both are built with ``curve`` using the same
    ``nsteps`` and ``slope_percent``. ``x`` holds the horizontal positions and
    should therefore have ``nsteps`` elements. Any extra keyword argument is
    forwarded to ``ax.fill_between``.
    """
    curve_options = dict(nsteps=nsteps, slope_percent=slope_percent)
    lower_edge = curve(left_bottom, right_bottom, **curve_options)
    upper_edge = curve(left_top, right_top, **curve_options)
    ax.fill_between(x=x, y1=lower_edge, y2=upper_edge, **kwargs)

def plot_boxes(ax, box_width, boxsep, colors, left, right,
               dist_to_box_bottom, dist_to_box_left, weights, box_kws, text_kws):
    """
    Draw one column of stacked boxes, with their labels, on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    box_width : float
        Width of the boxes (only used to position the labels).
    boxsep : float
        Vertical gap between two consecutive boxes.
    colors : Array
        Face color of each box, indexed by label.
    left, right : float
        Horizontal extent of the boxes.
    dist_to_box_bottom : float
        Vertical position of each label, as a fraction of its box height.
    dist_to_box_left : float
        Horizontal offset of each label from ``left``, as a fraction of ``box_width``.
    weights : Array
        1D array of box heights. Only strictly positive weights get a box.
    box_kws : dict
        Extra keyword arguments for ``ax.fill_between``.
    text_kws : dict
        Extra keyword arguments for ``ax.text``.
    """
    assert weights.ndim == 1
    positive = weights[weights > 0]
    label_axis = positive.axes[0]
    # Stack boxes in label order: the cumulative sum of (height + gap) is prepended
    # with a 0 and shifted by one position, so each label gets the total height of
    # the boxes below it.
    stacked = (positive + boxsep).cumsum(label_axis)
    stacked = stacked.prepend(label_axis, 0, label="dummy")
    box_bottoms = stacked.shift(label_axis)

    text_x = left + box_width * dist_to_box_left
    for lbl in label_axis:
        face = colors[lbl]
        y0 = box_bottoms[lbl]
        height = positive[lbl]

        ax.fill_between(x=[left, right], y1=y0, y2=y0 + height, facecolor=face, **box_kws)
        ax.text(text_x,
                y0 + height * dist_to_box_bottom,
                lbl.eval(),
                {'ha': 'right', 'va': 'center'},
                **text_kws)

def _box_bottoms(column_weights, boxsep):
    """
    Return the bottom position of each box in a column of stacked boxes.

    Boxes are stacked in label order and separated by ``boxsep``. Only strictly
    positive weights get a box, so other labels are absent from the result.
    ``column_weights`` is expected to be 1D with an axis named 'label'.
    """
    column_weights = column_weights[column_weights > 0]
    # the cumulative sum of (height + gap) is prepended with a 0 and shifted by one
    # position, so each label gets the total height of the boxes below it
    return (column_weights + boxsep).cumsum('label') \
                                    .prepend('label', 0, label="dummy") \
                                    .shift('label')


def sankey(weights, box_width=2, strip_width=10, nsteps=20, boxsep=0.1, text_kws=None,
           dist_to_box_left=-0.15, dist_to_box_bottom=0.5, colors=None,
           cmap='tab10', band_kws=None, box_kws=None, figsize=(10, 10), src_axis=None, dst_axis=None, step_axis=None,
           strip_shrink=0.06, slope_percent=100):
    """
    Plot a Sankey diagram of the flows in ``weights`` on a new figure.

    Parameters
    ----------
    weights : Array
        Flows from sources to destinations. Either 2D (source, destination) or 3D
        (step, source, destination). Only strictly positive flows are drawn.
        For several steps, the destinations of a step are drawn at the same place
        as the sources of the next step. The diagram is therefore only consistent
        if what flows into a label at one step flows out of it at the next step.
    box_width : float, optional
        Width of the boxes. Defaults to 2.
    strip_width : float, optional
        Horizontal distance between two consecutive columns of boxes. Defaults to 10.
    nsteps : int, optional
        Number of points used to draw each strip. Defaults to 20.
    boxsep : float, optional
        Vertical gap between two boxes of the same column. Defaults to 0.1.
    text_kws : dict, optional
        Extra keyword arguments for the box labels (``ax.text``). ``fontsize``
        defaults to 18. The dict passed is not modified.
    dist_to_box_left : float, optional
        Horizontal offset of the labels from the left of their box, as a fraction
        of ``box_width``. Defaults to -0.15.
    dist_to_box_bottom : float, optional
        Vertical position of the labels, as a fraction of their box height.
        Defaults to 0.5.
    colors : Array, optional
        Colors indexed by label. Defaults to colors taken from ``cmap``.
    cmap : str or matplotlib.colors.Colormap, optional
        Colormap used when ``colors`` is not given. Defaults to 'tab10'.
    band_kws : dict, optional
        Extra keyword arguments for the strips (``ax.fill_between``). ``alpha``
        defaults to 0.4. The dict passed is not modified.
    box_kws : dict, optional
        Extra keyword arguments for the boxes (``ax.fill_between``). ``alpha``
        defaults to 0.8. The dict passed is not modified.
    figsize : tuple, optional
        Size of the created figure. Defaults to (10, 10).
    src_axis, dst_axis, step_axis : axis reference, optional
        Axes of ``weights`` holding the sources, destinations and steps.
        By default, the step axis is the first axis of a 3D array, the source
        axis is the first remaining axis and the destination axis is the next one.
    strip_shrink : float, optional
        Vertical margin left between the strips attached to a box and the box
        edges (half at the bottom, half at the top). Defaults to 0.06.
    slope_percent : float, optional
        Percentage of the strip points used to go from the left height to the
        right height (see ``curve``). Defaults to 100.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the diagram was drawn on.

    Raises
    ------
    TypeError
        If ``colors`` is None and ``cmap`` is neither a string nor a Colormap.
    """
    if step_axis is None:
        if weights.ndim == 3:
            step_axis = weights.axes[0]
        else:
            # single step: add a one-label step axis so the code below only handles 3D arrays
            step_axis = la.Axis("step=step1")
            weights = weights.expand(step_axis)
    else:
        step_axis = weights.axes[step_axis]
    if src_axis is None:
        src_axis = (weights.axes - step_axis)[0]
    else:
        src_axis = weights.axes[src_axis]
    if dst_axis is None:
        dst_axis = (weights.axes - step_axis - src_axis)[0]
    else:
        dst_axis = weights.axes[dst_axis]

    # Fill the defaults in new dicts: the dicts given by the caller must not be modified.
    band_kws = {'alpha': 0.4, **(band_kws or {})}
    box_kws = {'alpha': 0.8, **(box_kws or {})}
    text_kws = {'fontsize': 18, **(text_kws or {})}

    if colors is None:
        all_labels = src_axis.union(dst_axis).labels
        if isinstance(cmap, str):
            # matplotlib.cm.get_cmap was deprecated in matplotlib 3.7 and removed in 3.9
            cmap = plt.get_cmap(cmap)
        if not isinstance(cmap, matplotlib.colors.Colormap):
            raise TypeError("cmap must be a colormap name or a matplotlib.colors.Colormap instance, "
                            f"got {type(cmap).__name__}")
        colors = la.stack({label: cmap(i) for i, label in enumerate(all_labels)}, 'label')

    # one column of boxes per step (its sources) + one for the destinations of the last step
    box = la.Axis(len(step_axis) + 1, "box")
    box_left = la.sequence(box, inc=box_width + strip_width)
    box_right = box_left + box_width

    left_weight = weights.sum(dst_axis).rename(src_axis, 'label')
    right_weight = weights.sum(src_axis).rename(dst_axis, 'label')
    if len(step_axis) > 1:
        # The boxes of an intermediate column are drawn from the sources of the next
        # step (in src_axis order). The strips arriving at that column are stacked from
        # the destinations of the current step. Stack the destinations in src_axis order
        # too (destination-only labels last), so that strips and boxes line up even when
        # src_axis and dst_axis list their labels in a different order.
        src_labels = list(src_axis.labels)
        src_label_set = set(src_labels)
        dst_label_set = set(dst_axis.labels)
        dst_order = [label for label in src_labels if label in dst_label_set] + \
                    [label for label in dst_axis.labels if label not in src_label_set]
        right_weight = right_weight[right_weight.axes['label'][dst_order]]

    fig = plt.figure(figsize=figsize)
    ax = fig.subplots(1)

    strip_left = box_right
    # the strips of step i end where the boxes of column i + 1 start
    strip_right = box_left.shift(box_left.axes[0], n=-1)
    for i, step in enumerate(step_axis):
        step_left_weight = left_weight[step]
        # boxes
        plot_boxes(ax, box_width, boxsep, colors, box_left.i[i], box_right.i[i],
                   dist_to_box_bottom, dist_to_box_left, step_left_weight, box_kws, text_kws)

        step_box_left_bottom = _box_bottoms(step_left_weight, boxsep)
        step_box_right_bottom = _box_bottoms(right_weight[step], boxsep)

        x = np.linspace(strip_left.i[i], strip_right.i[i], nsteps)

        # strips
        # number of sources flowing into each destination during this step
        num_strips_per_dest = (weights[step] > 0).sum(src_axis)
        # bottom of the next strip to attach to the right-hand box of each destination
        strip_right_y0_per_label = step_box_right_bottom + strip_shrink / 2
        for src_label in src_axis:
            if not step_left_weight[src_label] > 0:
                # plot_boxes drew no box for this source, so there is nothing to attach
                # strips to (and no entry in step_box_left_bottom)
                continue
            color = colors[src_label]
            weights_for_source = weights[step, src_label]
            num_strips_for_source = (weights_for_source > 0).sum()
            left_shrink = strip_shrink / num_strips_for_source
            # strips are stacked upward from the bottom of the source box
            strip_left_y0 = step_box_left_bottom[src_label] + strip_shrink / 2
            for dst_label in dst_axis:
                weight = weights_for_source[dst_label]
                if weight > 0:
                    right_shrink = strip_shrink / num_strips_per_dest[dst_label]
                    strip_left_y1 = strip_left_y0 + weight - left_shrink
                    strip_right_y0 = strip_right_y0_per_label[dst_label]
                    strip_right_y1 = strip_right_y0 + weight - right_shrink
                    plot_curved_band(ax, x,
                                     strip_left_y0, strip_left_y1,
                                     strip_right_y0, strip_right_y1,
                                     color=color, nsteps=nsteps, slope_percent=slope_percent,
                                     **band_kws)

                    strip_left_y0 = strip_left_y1
                    strip_right_y0_per_label[dst_label] = strip_right_y1

    # last column: boxes for the destinations of the last step
    last_step = step_axis.i[-1]
    plot_boxes(ax, box_width, boxsep, colors, box_left.i[-1], box_right.i[-1],
               dist_to_box_bottom, dist_to_box_left, right_weight[last_step], box_kws, text_kws)

    ax.axis('off')
    return ax

# Example data: 2 steps, sources a b c, destinations a b c d. The flows are
# consistent: what flows into a label in step1 flows out of it in step2.
# works
weight_arr = la.from_string(r"""
 step  src\dst  a  b  c  d
step1        a  1  1  0  0
step1        b  1  0  1  0
step1        c  1  1  0  0
step2        a  0  1  1  1
step2        b  1  0  0  1
step2        c  1  0  0  0
""")

# fails (bad labels): same numbers as above, but the destination columns are
# relabelled (a c b d). What flows into b and c in step1 therefore no longer
# matches what flows out of them in step2.
# weight_arr = la.from_string(r"""
#  step  src\dst  a  c  b  d
# step1        a  1  1  0  0
# step1        b  1  0  1  0
# step1        c  1  1  0  0
# step2        a  0  1  1  1
# step2        b  1  0  0  1
# step2        c  1  0  0  0
# """)

# "everything is odd": the src and dst axes list their labels in different orders
# weight_arr = la.from_string(r"""
# step  src\dst  c  e  b  d
# step1       b  1  1  0  0
# step1       c  0  1  1  0
# step1       e  0  1  1  0
# step2       b  0  1  0  1
# step2       c  0  1  0  0
# step2       e  1  0  1  1
# """)

if __name__ == "__main__":
    # plain version:
    # sankey(weight_arr)
    # let's go funky
    sankey(weight_arr, band_kws={'linestyles': 'dashed', 'hatch': '/', 'edgecolor': 'black'})
    plt.show()
alixdamman commented 1 year ago

I think including this kind of plot in LArray would be a plus, and it would be appreciated by at least one team in our office. To the TODO list I would add "show examples in the Plot section of the Tutorial".

You can remove the question tag

gdementen commented 1 year ago

I have started working on this on the train. No promise that I will not get bored before finishing it, though.