ezmsg-org / ezmsg

Pure-Python DAG-based high-performance SHM-backed pub-sub and multi-processing pattern
https://ezmsg.readthedocs.io/en/latest/
MIT License

resample_axes functionality for AxisArray #150

Closed mangrick closed 3 weeks ago

mangrick commented 1 month ago

I added a function that makes it more convenient to resample the axes of an AxisArray object. This is particularly helpful when a message class inherits from AxisArray but the unit only interacts with it through the AxisArray interface. Having a resampling function in AxisArray removes a lot of repetitive code for reconfiguring the affected axes in each unit. Example:

async def process(self, msg: AxisArray) -> AsyncGenerator:
    # ...
    yield self.OUTPUT, replace(msg, data=data).resample_axes(Time=500)

griffinmilsap commented 4 weeks ago

Hey @mangrick!

First off, thanks for your contribution! We're excited to see an influx of processing units, and a resample unit would be very welcome, especially an efficient one that handles non-integer resampling (as an integer upsample followed by an integer downsample).
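Non-integer resampling as "an integer upsample followed by an integer downsample" can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not ezmsg code: rational_factors and resample_rational are hypothetical names, the anti-aliasing filter is a Hamming-windowed sinc chosen for simplicity, and the zero-stuffing is done explicitly rather than with an efficient polyphase structure.

```python
from fractions import Fraction

import numpy as np


def rational_factors(fs_in: float, fs_out: float, max_den: int = 1000) -> tuple[int, int]:
    """Return (up, down) in lowest terms such that fs_in * up / down ~= fs_out."""
    ratio = Fraction(fs_out / fs_in).limit_denominator(max_den)
    return ratio.numerator, ratio.denominator


def resample_rational(x: np.ndarray, up: int, down: int, half_taps: int = 10) -> np.ndarray:
    """Resample the last axis of x by up/down: zero-stuff, low-pass filter, decimate."""
    n = x.shape[-1]
    # 1) Integer upsample: place the input on every up-th sample, zeros in between.
    stuffed = np.zeros(x.shape[:-1] + (n * up,))
    stuffed[..., ::up] = x
    # 2) Low-pass with a Hamming-windowed sinc. The cutoff is the lower of the input and
    #    output Nyquist frequencies, expressed relative to the upsampled Nyquist frequency.
    cutoff = 1.0 / max(up, down)
    half = half_taps * max(up, down)
    t = np.arange(-half, half + 1)
    # The extra gain of `up` compensates for the energy lost to zero-stuffing.
    h = up * cutoff * np.sinc(cutoff * t) * np.hamming(t.size)
    # Full convolution, then take the centered slice so the filter adds no delay.
    filtered = np.apply_along_axis(lambda v: np.convolve(v, h)[half : half + v.size], -1, stuffed)
    # 3) Integer downsample: keep every down-th sample.
    return filtered[..., ::down]


up, down = rational_factors(200, 500)  # 200 Hz -> 500 Hz
y = resample_rational(np.ones((3, 200)), up, down)
print(up, down, y.shape)  # 5 2 (3, 500)
```

Resampling the data is only half of the job: a unit built this way would also need to set the time axis gain to 1/500 s, which is the metadata bookkeeping discussed in this thread. For production use, scipy.signal.resample_poly implements the same up/filter/down scheme with a polyphase filter.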

As for this PR, I think the simpler way to handle this is with dataclasses.replace. This particular implementation overwrites (or ignores) the units of the existing axis, and the method name is a little misleading because it doesn't actually resample the data in the message. If it did resample the data, it would probably belong in the sigproc module rather than here. I do admit that AxisArrays are currently a little painful to work with, especially when an operation changes an existing axis or creates a new one. Thankfully, the sigproc module is starting to round out the basic set of manipulations one might want to perform on a time axis.

Here's how you could accomplish the same thing using replace:

import numpy as np
from dataclasses import replace
from ezmsg.util.messages.axisarray import AxisArray

aa = AxisArray(
    data=np.arange(2 * 3 * 4).reshape((2, 3, 4)),
    dims=['a', 'b', 'time'],
    axes={
        'time': AxisArray.Axis.TimeAxis(fs=200, offset=20)
    }
)

# This assumes 'time' is the last dimension; AxisArray.transpose can make that consistently the case.
# For a more general implementation, see ezmsg.sigproc.downsample.
q = 2  # integer downsampling factor (no anti-aliasing filter is applied here)
aa_down = replace(
    aa,
    data=aa.data[..., ::q],  # keep every q-th sample along the last (time) axis
    axes={
        # The sample period (gain) grows by q; unit and offset are left unchanged.
        k: replace(v, gain=v.gain * q) if k == 'time' else v
        for k, v in aa.axes.items()
    }
)

print(f'{aa.shape=}, {aa.axes["time"]}')
print(f'{aa_down.shape=}, {aa_down.axes["time"]}')

Output:

aa.shape=(2, 3, 4), AxisArray.Axis(unit='s', gain=0.005, offset=20)
aa_down.shape=(2, 3, 2), AxisArray.Axis(unit='s', gain=0.01, offset=20)
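The same replace pattern extends to a time axis at any position in dims. The sketch below uses hypothetical frozen dataclasses, Axis and Array, as stand-ins for AxisArray.Axis and AxisArray so that it runs without ezmsg installed; it is not the ezmsg.sigproc.downsample implementation. Like the example above, it decimates without an anti-aliasing filter and carries the unit and offset over unchanged.

```python
from dataclasses import dataclass, field, replace

import numpy as np


# Hypothetical stand-ins for AxisArray.Axis and AxisArray (NOT ezmsg's classes),
# reduced to the fields this example needs.
@dataclass(frozen=True)
class Axis:
    unit: str = ""
    gain: float = 1.0  # for a time axis: seconds per sample, i.e. 1 / fs
    offset: float = 0.0


@dataclass(frozen=True)
class Array:
    data: np.ndarray
    dims: list
    axes: dict = field(default_factory=dict)


def downsample_dim(arr: Array, dim: str, q: int) -> Array:
    """Keep every q-th sample along the named dimension, wherever it sits in dims.

    The matching axis keeps its unit and offset; only its gain is scaled by q.
    No anti-aliasing filter is applied.
    """
    idx = arr.dims.index(dim)
    # Step by q along `dim` only; keep every element of all other dimensions.
    slices = tuple(
        slice(None, None, q) if i == idx else slice(None) for i in range(arr.data.ndim)
    )
    axes = {k: replace(v, gain=v.gain * q) if k == dim else v for k, v in arr.axes.items()}
    return replace(arr, data=arr.data[slices], axes=axes)


# Usage: 'time' is the middle dimension here, so data[..., ::q] would slice the wrong axis.
aa = Array(
    data=np.arange(2 * 4 * 3).reshape((2, 4, 3)),
    dims=["a", "time", "b"],
    axes={"time": Axis(unit="s", gain=1 / 200, offset=20.0)},
)
aa_down = downsample_dim(aa, "time", 2)
print(aa_down.data.shape, aa_down.axes["time"])
```

Looking the dimension up by name and building the slice tuple from it avoids the transpose step mentioned in the comment of the original example.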
griffinmilsap commented 3 weeks ago

I'm going to close this out; please reach out if you think I missed the point of this PR with my recommendation! Thank you for submitting a well-structured PR with tests and docstrings though!