larray-project / larray

N-dimensional labelled arrays in Python
https://larray.readthedocs.io/
GNU General Public License v3.0
8 stars 6 forks source link

include heatmap visualization in Array.plot #930

Open gdementen opened 3 years ago

gdementen commented 3 years ago

Here is some code I wrote for the IO team to draw heatmaps from larrays. I am unsure whether or not it should be included in larray.

def heatmap(arr: Array, y_axes=None, x_axes=None, numhaxes=1, axes_names=True, ax=None, **kwargs):
    """Plot an ND array as a heatmap.

    By default it uses the last array axis as the X axis and other array axes as Y axis (like the viewer table).
    Only the first axis in each "direction" will have its name and labels shown.

    Parameters
    ----------
    arr : Array
        data to display.
    y_axes : int, str, Axis, tuple or AxisCollection, optional
        axis or axes to use on the Y axis. Defaults to all array axes except the last `numhaxes` ones.
    x_axes : int, str, Axis, tuple or AxisCollection, optional
        axis or axes to use on the X axis. Defaults to all array axes except `y_axes`.
    numhaxes : int, optional
        if x_axes and y_axes are not specified, use the last numhaxes as X axes. Defaults to 1.
    axes_names : bool, optional
        whether or not to show axes names. Defaults to True
    ax : matplotlib axes object, optional
        axes to draw on. When None (the default), a new figure and axes are created with plt.subplots().
    **kwargs
        any extra keyword argument is passed to pcolormesh. Likely of interest are cmap, vmin, vmax, antialiased
        or shading.

    Returns
    -------
    matplotlib axes object
        the axes the heatmap was drawn on (`ax` when it was given).
    """
    if arr.ndim < 2:
        arr = arr.expand(Axis([''], ''))
    if y_axes is None:
        if x_axes is not None:
            y_axes = arr.axes - x_axes
        else:
            y_axes = arr.axes[:-numhaxes]
    else:
        if isinstance(y_axes, str):
            y_axes = [y_axes]
        y_axes = arr.axes[y_axes]
    if x_axes is None:
        x_axes = arr.axes - y_axes
    else:
        if isinstance(x_axes, str):
            x_axes = [x_axes]
        x_axes = arr.axes[x_axes]
    arr = arr.transpose(y_axes + x_axes).combine_axes([y_axes, x_axes])
    # block size is the size of the other (non first) combined axes, i.e. the number of
    # consecutive cells sharing the same label of the first axis. It is coerced to int so
    # that it can safely be used in integer divisions and to compute label indices.
    x_block_size = int(x_axes[1:].size)
    y_block_size = int(y_axes[1:].size)
    if ax is None:
        fig, ax = plt.subplots()
    ax.pcolormesh(arr.data, **kwargs)

    def block_ticks(auto_ticks, labels, block_size):
        """Return (tick positions, tick labels) with every tick centered on a block."""
        num_labels = len(labels)
        if num_labels >= len(auto_ticks):
            # Too many labels to show them all: keep at most one label per automatic tick.
            # Automatic ticks are expressed in cells, so they are converted to the index of
            # the block they fall in. Ticks outside the data are dropped (a negative tick
            # would otherwise wrap around to the last labels).
            num_cells = num_labels * block_size
            block_indices = sorted({int(tick) // block_size for tick in auto_ticks if 0 <= tick < num_cells})
        else:
            block_indices = range(num_labels)
        positions = [index * block_size + block_size / 2 for index in block_indices]
        return positions, [labels[index] for index in block_indices]

    # place major ticks in the middle of blocks so that labels are centered
    xticks, xticklabels = block_ticks(ax.get_xticks(), x_axes[0].labels, x_block_size)
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticklabels, rotation=0)

    yticks, yticklabels = block_ticks(ax.get_yticks(), y_axes[0].labels, y_block_size)
    ax.set_yticks(yticks)
    ax.set_yticklabels(yticklabels, rotation=90, va='center')

    # first row at the top and X labels above the plot
    ax.invert_yaxis()
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position('top')

    # enable grid lines for minor ticks on axes when we have several "levels" for that axis
    if len(x_axes) > 1:
        # place minor ticks for grid lines between each block on the main axis
        ax.set_xticks(np.arange(x_block_size, x_axes.size, x_block_size), minor=True)
        ax.grid(True, axis='x', which='minor')
        # hide ticks on x axis
        ax.tick_params(axis='x', which='both', bottom=False, top=False)

    if len(y_axes) > 1:
        ax.set_yticks(np.arange(y_block_size, y_axes.size, y_block_size), minor=True)
        ax.grid(True, axis='y', which='minor')
        # hide ticks on y axis
        ax.tick_params(axis='y', which='both', left=False, right=False)

    # set axes names
    if axes_names:
        ax.set_xlabel(x_axes[0].name)
        ax.set_ylabel(y_axes[0].name)

    return ax
gdementen commented 3 years ago

@alixdamman: I am rather in favor of including this as a new kind of plot in Array.plot, as it can give a quick very-high level overview of what your data looks like. I would just rename x_axes and y_axes to simply x and y, to be consistent with the other plotting methods. What do you think?

alixdamman commented 1 year ago

@alixdamman: I am rather in favor of including this as a new kind of plot in Array.plot, as it can give a quick very-high level overview of what your data looks like. I would just rename x_axes and y_axes to simply x and y, to be consistent with the other plotting methods. What do you think?

Sounds like a good idea to include this as a new kind of plot in Array.plot. I wonder whether it would also be a good idea to add some kind of "gallery" to the LArray documentation (tutorial)? In our experience, very few (if any) of our users look at the API Reference section.

gdementen commented 1 month ago

The above code is very buggy regarding tick labels. I have been toying with it these days. Here is some buggy/work in progress code which supports matplotlib zoom:

class MaxNMultipleWithOffsetLocator(MaxNLocator):
    """Locator placing ticks at evenly spaced integer positions shifted by a constant offset.

    The number of ticks is limited by the `nbins` value stored by MaxNLocator: when the
    view interval spans more units than that, only one position every `step` units gets
    a tick.

    Parameters
    ----------
    nbins : int, optional
        maximum number of ticks wanted. Passed as is to MaxNLocator.
    offset : float, optional
        value added to each integer position (e.g. 0.5 to place ticks in the middle of
        unit-sized cells). Defaults to 0.5.
    **kwargs
        any extra keyword argument is passed to MaxNLocator.
    """
    def __init__(self, nbins=None, offset=0.5, **kwargs):
        super().__init__(nbins, **kwargs)
        self.offset = offset

    def tick_values(self, vmin, vmax):
        """Return the tick locations for the (vmin, vmax) interval.

        The bounds may be given in any order: the result is the same, in increasing
        order, whether or not the axis is inverted.
        """
        # an inverted axis gives its bounds as (high, low), which would otherwise
        # produce an empty range and thus no tick at all
        if vmax < vmin:
            vmin, vmax = vmax, vmin
        # NOTE(review): assumes MaxNLocator stored an integer in _nbins; a non numeric
        #               value such as 'auto' is not handled here -- TODO confirm
        max_desired_ticks = self._nbins
        # no + 1 because we place ticks in the middle
        num_ticks = vmax - vmin
        step = 1
        # the lower bound check avoids a division by zero when no tick is wanted
        if 0 < max_desired_ticks < num_ticks:
            step = int(np.ceil(num_ticks / max_desired_ticks))
        return np.arange(int(vmin), int(vmax) + 1, step) + self.offset

    def __call__(self):
        """Return the locations of the ticks."""
        vmin, vmax = self.axis.get_view_interval()
        return self.tick_values(vmin, vmax)

def heatmap(arr: Array, y_axes=None, x_axes=None, numhaxes=1, axes_names=True, ax=None, **kwargs):
    """Plot an ND array as a heatmap.

    By default it uses the last array axis as the X axis and other array axes as Y axis (like the viewer table).
    Only the first axis in each "direction" will have its labels shown but the names of all axes are displayed
    (one per line). Tick labels are computed on the fly from the tick positions instead of being fixed in advance.

    Parameters
    ----------
    arr : Array
        data to display.
    y_axes : int, str, Axis, tuple or AxisCollection, optional
        axis or axes to use on the Y axis. Defaults to all array axes except the last `numhaxes` ones.
    x_axes : int, str, Axis, tuple or AxisCollection, optional
        axis or axes to use on the X axis. Defaults to all array axes except `y_axes`.
    numhaxes : int, optional
        if x_axes and y_axes are not specified, use the last numhaxes as X axes. Defaults to 1.
    axes_names : bool, optional
        whether or not to show axes names. Defaults to True
    ax : matplotlib axes object, optional
        axes to draw on. When None (the default), a new figure and axes are created with plt.subplots().
    **kwargs
        any extra keyword argument is passed to pcolormesh. Likely of interest are cmap, vmin, vmax, antialiased
        or shading.

    Returns
    -------
    matplotlib axes object
        the axes the heatmap was drawn on (`ax` when it was given).
    """
    if arr.ndim < 2:
        arr = arr.expand(Axis([''], ''))
    if y_axes is None:
        if x_axes is not None:
            y_axes = arr.axes - x_axes
        else:
            y_axes = arr.axes[:-numhaxes]
    else:
        if isinstance(y_axes, str):
            y_axes = [y_axes]
        y_axes = arr.axes[y_axes]
    if x_axes is None:
        x_axes = arr.axes - y_axes
    else:
        if isinstance(x_axes, str):
            x_axes = [x_axes]
        x_axes = arr.axes[x_axes]
    arr = arr.transpose(y_axes + x_axes).combine_axes([y_axes, x_axes])
    # block size is the size of the other (non first) combined axes
    x_block_size = int(x_axes[1:].size)
    y_block_size = int(y_axes[1:].size)
    if ax is None:
        fig, ax = plt.subplots()
    ax.pcolormesh(arr.data, **kwargs)

    xlabels = x_axes[0].labels
    ylabels = y_axes[0].labels

    def make_tick_formatter(labels, block_size):
        """Return a tick formatter giving the label of the block each tick falls in."""
        def format_tick(tick_val, tick_pos):
            # floor division (rather than truncation) so that negative positions map to
            # a negative index and are rejected instead of being given the first label
            label_index = int(tick_val // block_size)
            if 0 <= label_index < len(labels):
                return str(labels[label_index])
            # ticks outside the data (e.g. after panning) get no label
            return ''
        return format_tick

    # A FuncFormatter is created automatically.
    ax.xaxis.set_major_formatter(make_tick_formatter(xlabels, x_block_size))
    ax.yaxis.set_major_formatter(make_tick_formatter(ylabels, y_block_size))

    # place major ticks in the middle of blocks so that labels are centered
    max_major_ticks = 10
    x_locator = MaxNMultipleWithOffsetLocator(min(max_major_ticks, len(xlabels)), offset=x_block_size / 2)
    y_locator = MaxNMultipleWithOffsetLocator(min(max_major_ticks, len(ylabels)), offset=y_block_size / 2)
    ax.xaxis.set_major_locator(x_locator)
    ax.yaxis.set_major_locator(y_locator)

    # TODO: inverting the Y axis (ax.invert_yaxis()) makes y labels disappear, we should fix this
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position('top')

    # enable grid lines for minor ticks on axes when we have several "levels" for that axis
    if len(x_axes) > 1:
        # place minor ticks for grid lines between each block on the main axis
        ax.set_xticks(np.arange(x_block_size, x_axes.size, x_block_size), minor=True)
        ax.grid(True, axis='x', which='minor')
        # hide ticks on x axis
        ax.tick_params(axis='x', which='both', bottom=False, top=False)

    if len(y_axes) > 1:
        ax.set_yticks(np.arange(y_block_size, y_axes.size, y_block_size), minor=True)
        ax.grid(True, axis='y', which='minor')
        # hide ticks on y axis
        ax.tick_params(axis='y', which='both', left=False, right=False)

    # set axes names (one per line)
    if axes_names:
        # NOTE(review): assumes unnamed axes report None as their name; those are skipped
        #               because str.join only accepts strings -- TODO confirm
        ax.set_xlabel('\n'.join(str(name) for name in x_axes.names if name is not None))
        ax.set_ylabel('\n'.join(str(name) for name in y_axes.names if name is not None))

    return ax