Open · gdementen opened this issue 3 years ago
@alixdamman: I am rather in favor of including this as a new kind of plot in Array.plot, as it can give a quick very-high level overview of what your data looks like. I would just rename x_axes and y_axes to simply x and y, to be consistent with the other plotting methods. What do you think?
> @alixdamman: I am rather in favor of including this as a new kind of plot in Array.plot, as it can give a quick very-high level overview of what your data looks like. I would just rename x_axes and y_axes to simply x and y, to be consistent with the other plotting methods. What do you think?
Sounds like a good idea to include this as a new kind of plot in Array.plot. I wonder whether it would also be a good idea to add some kind of "gallery" to the LArray documentation (tutorial)? In our experience, very few (if any) of our users look at the API Reference section.
The code above is very buggy regarding tick labels. I have been toying with it recently. Here is some buggy, work-in-progress code which supports matplotlib zoom:
class MaxNMultipleWithOffsetLocator(MaxNLocator):
    """Tick locator placing at most ``nbins`` ticks on integer positions shifted by ``offset``.

    Ticks are placed at ``int(vmin) + k * step + offset``, where ``step`` is 1 when the
    view interval spans no more than ``nbins`` units, and is otherwise the smallest integer
    step keeping the number of ticks at or below ``nbins``. With ``offset=0.5`` the ticks
    land in the middle of unit-sized cells, as drawn by ``pcolormesh``.

    Parameters
    ----------
    nbins : int or 'auto', optional
        Maximum number of ticks. Passed to ``MaxNLocator``.
    offset : float, optional
        Value added to every integer tick position. Defaults to 0.5.
    **kwargs
        Any other ``MaxNLocator`` parameter.
    """

    def __init__(self, nbins=None, offset=0.5, **kwargs):
        super().__init__(nbins, **kwargs)
        self.offset = offset

    def _max_ticks(self):
        """Return the maximum number of ticks as a strictly positive int."""
        nbins = self._nbins
        if nbins == 'auto':
            # let the axis decide how many ticks fit, like MaxNLocator does for 'auto'
            nbins = self.axis.get_tick_space() if self.axis is not None else 9
        # at least one tick, so that the step computation below never divides by zero
        return max(int(nbins), 1)

    def tick_values(self, vmin, vmax):
        """Return tick locations for the view interval [vmin, vmax]."""
        # an inverted axis (e.g. after ax.invert_yaxis()) reports its view interval
        # reversed; without this swap np.arange below would return no tick at all
        if vmin > vmax:
            vmin, vmax = vmax, vmin
        # no + 1 because we place ticks in the middle of cells
        num_ticks = vmax - vmin
        desired_numticks = min(num_ticks, self._max_ticks())
        start, stop = int(vmin), int(vmax) + 1
        if desired_numticks < num_ticks:
            step = np.ceil(num_ticks / desired_numticks)
            return np.arange(start, stop, step) + self.offset
        return np.arange(start, stop) + self.offset

    def __call__(self):
        """Return the locations of the ticks."""
        vmin, vmax = self.axis.get_view_interval()
        return self.tick_values(vmin, vmax)
def heatmap(arr: Array, y_axes=None, x_axes=None, numhaxes=1, axes_names=True, ax=None, **kwargs):
    """plot an ND array as a heatmap.

    By default it uses the last array axis as the X axis and other array axes as Y axis (like the viewer table).
    Only the first axis in each "direction" will have its labels shown as tick labels; the other axes in that
    direction are delimited by grid lines. All axes names are shown (one per line) when `axes_names` is True.

    Parameters
    ----------
    arr : Array
        data to display.
    y_axes : int, str, Axis, tuple or AxisCollection, optional
        axis or axes to use on the Y axis. Defaults to all array axes except the last `numhaxes` ones.
    x_axes : int, str, Axis, tuple or AxisCollection, optional
        axis or axes to use on the X axis. Defaults to all array axes except `y_axes`.
    numhaxes : int, optional
        if x_axes and y_axes are not specified, use the last numhaxes as X axes. Defaults to 1.
    axes_names : bool, optional
        whether or not to show axes names. Defaults to True
    ax : matplotlib axes object, optional
        axes to draw on. A new figure and axes are created if not given.
    **kwargs
        any extra keyword argument is passed to pcolormesh. Likely of interest are cmap, vmin, vmax, antialiased or shading.

    Returns
    -------
    matplotlib.AxesSubplot
    """
    if arr.ndim < 2:
        arr = arr.expand(Axis([''], ''))

    def _to_axis_collection(axes):
        # a single axis reference (name, position or Axis) is wrapped in a list so that indexing
        # arr.axes yields an AxisCollection (needed for [1:], .size and [0] below), not a bare Axis
        if isinstance(axes, (str, int, Axis)):
            axes = [axes]
        return arr.axes[axes]

    if y_axes is not None:
        y_axes = _to_axis_collection(y_axes)
    if x_axes is not None:
        x_axes = _to_axis_collection(x_axes)
    if y_axes is None:
        y_axes = arr.axes - x_axes if x_axes is not None else arr.axes[:-numhaxes]
    if x_axes is None:
        x_axes = arr.axes - y_axes

    arr = arr.transpose(y_axes + x_axes).combine_axes([y_axes, x_axes])

    # block size is the size of the other (non first) combined axes
    x_block_size = int(x_axes[1:].size)
    y_block_size = int(y_axes[1:].size)

    if ax is None:
        _, ax = plt.subplots()
    ax.pcolormesh(arr.data, **kwargs)

    xlabels = x_axes[0].labels
    ylabels = y_axes[0].labels

    def _block_label_formatter(labels, block_size):
        """Build a tick formatter mapping a cell position to the label of its block."""
        def format_tick(tick_val, tick_pos):
            # floor (not truncation) so that positions left of/below the data map to -1
            label_index = int(np.floor(tick_val)) // block_size
            # negative indices are rejected too: they would wrap around to the last labels
            if 0 <= label_index < len(labels):
                return str(labels[label_index])
            return ''
        return format_tick

    # A FuncFormatter is created automatically.
    ax.xaxis.set_major_formatter(_block_label_formatter(xlabels, x_block_size))
    ax.yaxis.set_major_formatter(_block_label_formatter(ylabels, y_block_size))

    # place major ticks in the middle of blocks so that labels are centered
    max_ticks = 10
    ax.xaxis.set_major_locator(
        MaxNMultipleWithOffsetLocator(min(max_ticks, len(xlabels)), offset=x_block_size / 2))
    ax.yaxis.set_major_locator(
        MaxNMultipleWithOffsetLocator(min(max_ticks, len(ylabels)), offset=y_block_size / 2))

    # TODO: this makes y labels disappear, we should fix this
    # ax.invert_yaxis()
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position('top')

    # enable grid lines for minor ticks on axes when we have several "levels" for that axis
    if len(x_axes) > 1:
        # place minor ticks for grid lines between each block on the main axis
        ax.set_xticks(np.arange(x_block_size, x_axes.size, x_block_size), minor=True)
        ax.grid(True, axis='x', which='minor')
    # hide ticks on x axis
    ax.tick_params(axis='x', which='both', bottom=False, top=False)

    if len(y_axes) > 1:
        ax.set_yticks(np.arange(y_block_size, y_axes.size, y_block_size), minor=True)
        ax.grid(True, axis='y', which='minor')
    # hide ticks on y axis
    ax.tick_params(axis='y', which='both', left=False, right=False)

    # set axes names (anonymous axes have a None name, which str.join cannot handle)
    if axes_names:
        ax.set_xlabel('\n'.join(name if name is not None else '' for name in x_axes.names))
        ax.set_ylabel('\n'.join(name if name is not None else '' for name in y_axes.names))
    return ax
Here is some code I wrote for the IO team to draw heatmaps from larrays. I am unsure whether it should be included in larray.