ezmsg-org / ezmsg-sigproc

Timeseries signal processing implementations for ezmsg
MIT License

sigproc.filter -- `fs` not needed, hangs when there is no design, and other issues #6

Open cboulay opened 6 months ago

cboulay commented 6 months ago

I'm working on a generator-based refactor of filter and butterworth filter. I'll attach the WIP functions in a reply to this top-post.

In trying to redo the existing classes using the generator function, I've come across a couple of things that I don't understand and/or that I think could be modified to simplify and generalize the code.

First, the class FilterSettingsBase includes the field fs. This isn't used as a "setting" at this base class level. While the child class ButterworthFilterSettings needs fs to design the filter (because cuton and cutoff are provided in Hz), the parent does not use fs except to check that it is not None before calling update_filter. This is misplaced because a child of this base class might want to initialize a filter with already-known coefficients, in which case fs is irrelevant at design time. Indeed, FilterSettings has the field filt to store the coefficients.
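To make this concrete, here is a minimal sketch of a hierarchy in which only the Butterworth settings need `fs`, and only at design time. The class names (`BaseSettings`, `CoefSettings`, `ButterSettings`) and the `design(fs)` method are hypothetical, not the ezmsg-sigproc API; the btype/Wn rules mirror the legacy `filter_specs()` behaviour that the tests later in this thread check.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
import scipy.signal

Coefs = Tuple[np.ndarray, np.ndarray]


@dataclass
class BaseSettings:
    # Hypothetical base: only what every filter needs; no `fs` here.
    axis: Optional[str] = None


@dataclass
class CoefSettings(BaseSettings):
    # Coefficients are supplied up front, so the sample rate never matters.
    b: Optional[np.ndarray] = None
    a: Optional[np.ndarray] = None

    def design(self, fs: Optional[float] = None) -> Optional[Coefs]:
        if self.b is None or self.a is None:
            return None
        return self.b, self.a


@dataclass
class ButterSettings(BaseSettings):
    order: int = 0
    cuton: Optional[float] = None
    cutoff: Optional[float] = None

    def design(self, fs: Optional[float] = None) -> Optional[Coefs]:
        # cuton/cutoff are in Hz, so only this subclass needs fs (e.g., read from the first message).
        if self.order <= 0 or fs is None or (self.cuton is None and self.cutoff is None):
            return None
        if self.cuton is None:
            btype, wn = "lowpass", self.cutoff
        elif self.cutoff is None:
            btype, wn = "highpass", self.cuton
        elif self.cuton <= self.cutoff:
            btype, wn = "bandpass", (self.cuton, self.cutoff)
        else:
            btype, wn = "bandstop", (self.cutoff, self.cuton)
        return scipy.signal.butter(self.order, wn, btype=btype, fs=fs)
```

With this split, a settings object that already holds coefficients never has to carry a meaningless sample rate.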

Second, this block stalls out the pipeline if the filter is not designed:

        # Ensure filter is defined
        # TODO: Maybe have me be a passthrough filter until coefficients are received
        if self.STATE.filt is None:
            self.STATE.filt_set.clear()
            ez.logger.info("Awaiting filter coefficients...")
            await self.STATE.filt_set.wait()
            ez.logger.info("Filter coefficients received.")

I agree with the TODO that it should be a passthrough if no filter has been designed yet. IMO, it should be required that all the details necessary to design the filter are provided at init and, possibly, with the first non-empty input. If this requirement is enforced, then I expect it will simplify some of the other workflows.
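A minimal sketch of the passthrough behaviour proposed above, written as a primed generator over plain 1-D NumPy chunks rather than AxisArray messages. The function name is hypothetical, and zero initial conditions are used for simplicity; the point is only that a missing design returns the input unchanged instead of blocking on an event.

```python
from typing import Generator, Optional, Tuple

import numpy as np
import scipy.signal


def passthrough_until_designed(
    coefs: Optional[Tuple[np.ndarray, np.ndarray]],
) -> Generator[np.ndarray, np.ndarray, None]:
    """Hypothetical sketch: filter 1-D chunks with (b, a), or pass them through if there is no design."""
    zi = None
    out = np.array([])
    while True:
        chunk = yield out
        if coefs is None:
            out = chunk  # no design yet: pass the data through instead of waiting
            continue
        if zi is None:
            # Zero initial state of length max(len(b), len(a)) - 1, as lfilter expects.
            zi = np.zeros(max(len(coefs[0]), len(coefs[1])) - 1)
        out, zi = scipy.signal.lfilter(coefs[0], coefs[1], chunk, zi=zi)
```

Prime it with `next(gen)` and then call `gen.send(chunk)` for each chunk; the filter state `zi` is carried between calls, so chunked output matches a single call over the concatenated data.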

Third, Filter's design_filter implementation simply raises NotImplementedError. In my experience, having a method that raises that error is enough for linters to say "this is an abstract base class". However, Decimate instantiates a Filter() directly (with b, a set manually from a cheby1 design if factor > 1).

I think it would be better if Filter were truly abstract and instead there were a SimpleFilter child class that overrode design_filter with return None, and Decimate used SimpleFilter instead.
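A minimal sketch of the suggested split using Python's `abc` module. The names `Filter`, `SimpleFilter` and `design_filter` come from the post, but the constructor and the `coefs` attribute are assumptions for illustration, not the current ezmsg-sigproc code. The cheby1 parameters (order 8, 0.05 dB ripple, cutoff 0.8/factor) are the ones `scipy.signal.decimate` uses for its default IIR filter, used here only as an example design.

```python
from abc import ABC, abstractmethod
from typing import Optional, Tuple

import numpy as np
import scipy.signal

Coefs = Tuple[np.ndarray, np.ndarray]


class Filter(ABC):
    """Hypothetical abstract base: subclasses must say how coefficients are designed."""

    def __init__(self, coefs: Optional[Coefs] = None):
        # Use explicit coefficients if given, otherwise ask the subclass to design them.
        self.coefs = coefs if coefs is not None else self.design_filter()

    @abstractmethod
    def design_filter(self) -> Optional[Coefs]:
        ...


class SimpleFilter(Filter):
    """Coefficients are provided by the caller; there is nothing to design."""

    def design_filter(self) -> Optional[Coefs]:
        return None


def make_decimation_filter(factor: int) -> SimpleFilter:
    # What a Decimate-like unit could do instead of instantiating the base class.
    return SimpleFilter(scipy.signal.cheby1(8, 0.05, 0.8 / factor))
```

Because `design_filter` is marked abstract, `Filter()` itself now raises `TypeError`, which is exactly the guarantee the post is asking for.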

Fourth, I prefer "sos" filters, but FilterCoefficients assumes "ba". What if FilterCoefficients' only field were coefs, with type Union[Tuple[np.ndarray], BACoeffs], where BACoeffs is the current version of FilterCoefficients? I think I'm missing a Tuple in there somewhere, but if done right, it should be possible to use FilterCoefficients for either "ba" or "sos".
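A minimal sketch of the single-field idea, assuming `BACoeffs` keeps the current `b`/`a` fields and that an "sos" design is stored as scipy's `(n_sections, 6)` array. Both classes and the `as_args`/`filter_with` helpers are hypothetical; the key observation is that `lfilter(b, a, x)` and `sosfilt(sos, x)` both take the coefficients first, so one tuple form works for either call.

```python
from dataclasses import dataclass
from typing import Tuple, Union

import numpy as np
import scipy.signal


@dataclass
class BACoeffs:
    # Stand-in for the current "ba" FilterCoefficients.
    b: np.ndarray
    a: np.ndarray


@dataclass
class FilterCoefficients:
    # Hypothetical single field: BACoeffs for "ba", or an (n_sections, 6) array for "sos".
    coefs: Union[BACoeffs, np.ndarray]

    @property
    def coef_type(self) -> str:
        return "ba" if isinstance(self.coefs, BACoeffs) else "sos"

    def as_args(self) -> Tuple[np.ndarray, ...]:
        # Always return a tuple so callers can splat it with `*` before the data.
        if isinstance(self.coefs, BACoeffs):
            return (self.coefs.b, self.coefs.a)
        return (self.coefs,)


def filter_with(fc: FilterCoefficients, x: np.ndarray) -> np.ndarray:
    func = scipy.signal.lfilter if fc.coef_type == "ba" else scipy.signal.sosfilt
    return func(*fc.as_args(), x)
```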

I think it's reasonable to do this in baby steps. I think there's a path to refactoring the classes before bringing in the generator functions. Unfortunately, I need the generator functions now to demonstrate offline and online harmony, but I can keep those in my project's namespace package for now.

cboulay commented 6 months ago
```python
from dataclasses import replace
from typing import Generator, Optional, Tuple

import numpy as np
import scipy.signal
from ezmsg.util.generator import consumer
from ezmsg.util.messages.axisarray import AxisArray

from ezmsg.sigproc.butterworthfilter import (
    ButterworthFilterSettings as LegacyButterSettings,
)


@consumer
def filtergen(
    axis: Optional[str], coefs: Optional[Tuple[np.ndarray, ...]], coef_type: str
) -> Generator[AxisArray, AxisArray, None]:
    # Massage inputs
    if coefs is not None and not isinstance(coefs, tuple):
        # scipy.signal functions called with first arg `*coefs`, but sos coefs are a single ndarray.
        coefs = (coefs,)

    # Init IO
    axis_arr_in = AxisArray(np.array([]), dims=[""])
    axis_arr_out = AxisArray(np.array([]), dims=[""])

    filt_func = {"ba": scipy.signal.lfilter, "sos": scipy.signal.sosfilt}[coef_type]
    zi_func = {"ba": scipy.signal.lfilter_zi, "sos": scipy.signal.sosfilt_zi}[coef_type]

    # State variables
    axis_idx = None
    zi = None
    expected_shape = None

    while True:
        axis_arr_in = yield axis_arr_out

        if coefs is None:
            # passthrough if we do not have a filter design.
            axis_arr_out = axis_arr_in
            continue

        if axis_idx is None:
            axis_name = axis_arr_in.dims[0] if axis is None else axis
            axis_idx = axis_arr_in.get_axis_idx(axis_name)

        dat_in = axis_arr_in.data

        # Re-calculate/reset zi if the shape of the non-filtered dims changes
        samp_shape = dat_in.shape[:axis_idx] + dat_in.shape[axis_idx + 1 :]
        if zi is None or samp_shape != expected_shape:
            expected_shape = samp_shape
            n_tail = dat_in.ndim - axis_idx - 1
            zi = zi_func(*coefs)
            # Place zi's state dimension at axis_idx and tile it across all other dims
            zi_expand = (None,) * axis_idx + (slice(None),) + (None,) * n_tail
            n_tile = dat_in.shape[:axis_idx] + (1,) + dat_in.shape[axis_idx + 1 :]
            if coef_type == "sos":
                # sos zi must keep its leading n_sections dimension
                # (`ceil(order / 2)` for low|high; `order` for bpass|bstop)
                zi_expand = (slice(None),) + zi_expand
                n_tile = (1,) + n_tile
            zi = np.tile(zi[zi_expand], n_tile)

        dat_out, zi = filt_func(*coefs, dat_in, axis=axis_idx, zi=zi)
        axis_arr_out = replace(axis_arr_in, data=dat_out)


@consumer
def butter(
    axis: Optional[str],
    order: int = 0,
    cuton: Optional[float] = None,
    cutoff: Optional[float] = None,
    coef_type: str = "ba",
) -> Generator[AxisArray, AxisArray, None]:
    # IO
    axis_arr_in = AxisArray(np.array([]), dims=[""])
    axis_arr_out = AxisArray(np.array([]), dims=[""])

    # Reuse the legacy settings class only to derive btype and Wn from cuton/cutoff
    btype, cutoffs = LegacyButterSettings(
        order=order, cuton=cuton, cutoff=cutoff
    ).filter_specs()

    # We cannot calculate coefs yet because we do not know input sample rate
    coefs = None
    filter_gen = filtergen(axis, coefs, coef_type)  # Passthrough.

    while True:
        axis_arr_in = yield axis_arr_out
        if coefs is None and order > 0:
            # Sample rate from the target axis of the first message (gain = seconds per sample)
            fs = 1 / axis_arr_in.axes[axis or axis_arr_in.dims[0]].gain
            coefs = scipy.signal.butter(
                order, Wn=cutoffs, btype=btype, fs=fs, output=coef_type
            )
            filter_gen = filtergen(axis, coefs, coef_type)

        axis_arr_out = filter_gen.send(axis_arr_in)
```

And I have unit tests to go with these.
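The least obvious part of `filtergen` is how it reshapes scipy's steady-state `zi` for an arbitrary filter axis. The sketch below pulls that logic into standalone functions (`tiled_zi` and `filter_in_chunks` are hypothetical helper names, independent of ezmsg) so it can be checked directly: filtering a 3-D array in chunks along axis 1 while carrying the returned state should reproduce one call over the whole array.

```python
import numpy as np
import scipy.signal


def tiled_zi(coefs, coef_type: str, shape: tuple, axis: int) -> np.ndarray:
    """Broadcast scipy's steady-state zi to data of `shape` filtered along `axis`."""
    if coef_type == "ba":
        zi = scipy.signal.lfilter_zi(*coefs)  # shape (n_state,)
    else:
        zi = scipy.signal.sosfilt_zi(coefs)  # shape (n_sections, 2)
    n_tail = len(shape) - axis - 1
    # Put the state dim where the filtered axis is; size-1 everywhere else.
    expand = (None,) * axis + (slice(None),) + (None,) * n_tail
    reps = tuple(shape[:axis]) + (1,) + tuple(shape[axis + 1 :])
    if coef_type == "sos":
        # sosfilt wants (n_sections, ..., 2, ...), so keep the section dim in front.
        expand = (slice(None),) + expand
        reps = (1,) + reps
    return np.tile(zi[expand], reps)


def filter_in_chunks(coefs, coef_type: str, x: np.ndarray, axis: int, n_chunks: int) -> np.ndarray:
    """Filter `x` chunk by chunk, feeding each chunk's final state into the next."""
    func = scipy.signal.lfilter if coef_type == "ba" else scipy.signal.sosfilt
    args = tuple(coefs) if coef_type == "ba" else (coefs,)
    zi = tiled_zi(coefs, coef_type, x.shape, axis)
    out = []
    for chunk in np.array_split(x, n_chunks, axis=axis):
        y, zi = func(*args, chunk, axis=axis, zi=zi)
        out.append(y)
    return np.concatenate(out, axis=axis)
```

Only the non-filtered dimensions of `shape` matter to `tiled_zi`, which is why `filtergen` only recomputes `zi` when those dimensions change.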

griffinmilsap commented 5 months ago

You're amazing, Chad. I'm not fond of our filter inheritance structure but this refactor will do away with it entirely.

griffinmilsap commented 5 months ago

After having used this successfully a few times now, I'm going to deprecate our current filters and add these in as a new implementation. This will allow us to get rid of fs and make the changes we need in a backward-compatible way.
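One common way to keep an old class importable while steering users to a new implementation is to emit a `DeprecationWarning` from its constructor. A minimal sketch; `LegacyButterworthFilter` is a hypothetical placeholder name, not an actual ezmsg-sigproc class, and this is not necessarily how the project performed its deprecation.

```python
import warnings


class LegacyButterworthFilter:
    """Hypothetical placeholder for an old unit kept only for backward compatibility."""

    def __init__(self, *args, **kwargs):
        # stacklevel=2 points the warning at the caller's line, not at this constructor.
        warnings.warn(
            "LegacyButterworthFilter is deprecated; use the generator-based butter() instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.args, self.kwargs = args, kwargs
```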

cboulay commented 5 months ago

By coincidence, I was working on this today, but I'm happy to move on to the Spectrogram and Bandpower rework. Here is my test_butter.py. Feel free to change the folder layout however you see fit.

```python
from typing import Optional

import numpy as np
import pytest
import scipy.signal
from ezmsg.util.messages.axisarray import AxisArray

from ezmsg.sigproc.butterworthfilter import (
    ButterworthFilterSettings as LegacyButterSettings,
)
from ezmsg.sigproc.butter import butter, ButterworthFilterSettings


@pytest.mark.parametrize(
    "cutoff, cuton",
    [
        (30.0, None),  # lowpass
        (None, 30.0),  # highpass
        (45.0, 30.0),  # bandpass
        (30.0, 45.0),  # bandstop
    ],
)
@pytest.mark.parametrize("order", [2, 4, 8])
def test_butterworth_legacy_filter_settings(
    cutoff: Optional[float], cuton: Optional[float], order: int
):
    """
    Test the butterworth legacy filter settings generation of btype and Wn.
    We test them explicitly because we assume they are correct when used in our later settings.

    Parameters:
        cutoff (float): The cutoff frequency for the filter. Can be None for highpass filters.
        cuton (float): The cuton frequency for the filter. Can be None for lowpass filters.
            If cuton is larger than cutoff we assume bandstop.
        order (int): The order of the filter.
    """
    settings_obj = LegacyButterSettings(
        axis="time", fs=500, order=order, cuton=cuton, cutoff=cutoff
    )
    btype, Wn = settings_obj.filter_specs()
    if cuton is None:
        assert btype == "lowpass"
        assert Wn == cutoff
    elif cutoff is None:
        assert btype == "highpass"
        assert Wn == cuton
    elif cuton <= cutoff:
        assert btype == "bandpass"
        assert Wn == (cuton, cutoff)
    else:
        assert btype == "bandstop"
        assert Wn == (cutoff, cuton)


@pytest.mark.parametrize(
    "cutoff, cuton",
    [
        (30.0, None),  # lowpass
        (None, 30.0),  # highpass
        (45.0, 30.0),  # bandpass
        (30.0, 45.0),  # bandstop
    ],
)
@pytest.mark.parametrize("order", [0, 2, 5, 8])  # 0 = passthrough
# All fs entries must be greater than 2x the largest of cutoff | cuton
@pytest.mark.parametrize("fs", [200.0])
@pytest.mark.parametrize("n_chans", [3])
@pytest.mark.parametrize("n_dims, time_ax", [(1, 0), (3, 0), (3, 1), (3, 2)])
@pytest.mark.parametrize("coef_type", ["ba", "sos"])
def test_butterworth(
    cutoff: Optional[float],
    cuton: Optional[float],
    order: int,
    fs: float,
    n_chans: int,
    n_dims: int,
    time_ax: int,
    coef_type: str,
):
    dur = 2.0
    n_freqs = 5
    n_splits = 4

    n_times = int(dur * fs)
    if n_dims == 1:
        dat_shape = [n_times]
        dat_dims = ["time"]
        other_axes = {}
    else:
        dat_shape = [n_freqs, n_chans]
        dat_shape.insert(time_ax, n_times)
        dat_dims = ["freq", "ch"]
        dat_dims.insert(time_ax, "time")
        other_axes = {"freq": AxisArray.Axis(unit="Hz"), "ch": AxisArray.Axis()}
    in_dat = np.arange(np.prod(dat_shape), dtype=float).reshape(*dat_shape)

    # Calculate Expected Result
    btype, Wn = LegacyButterSettings(
        axis="time", fs=fs, order=order, cuton=cuton, cutoff=cutoff
    ).filter_specs()
    coefs = scipy.signal.butter(order, Wn, btype=btype, output=coef_type, fs=fs)
    tmp_dat = np.moveaxis(in_dat, time_ax, -1)
    if coef_type == "ba":
        if order == 0:
            # butter does not return correct coefs under these conditions; set manually.
            coefs = (np.array([1.0, 0.0]),) * 2
        zi = scipy.signal.lfilter_zi(*coefs)
        if n_dims == 3:
            zi = np.tile(zi[None, None, :], (n_freqs, n_chans, 1))
        out_dat, _ = scipy.signal.lfilter(*coefs, tmp_dat, zi=zi)
    elif coef_type == "sos":
        zi = scipy.signal.sosfilt_zi(coefs)
        if n_dims == 3:
            zi = np.tile(zi[:, None, None, :], (1, n_freqs, n_chans, 1))
        out_dat, _ = scipy.signal.sosfilt(coefs, tmp_dat, zi=zi)
    out_dat = np.moveaxis(out_dat, -1, time_ax)

    # Split the data into multiple messages
    n_seen = 0
    messages = []
    for split_dat in np.array_split(in_dat, n_splits, axis=time_ax):
        _time_axis = AxisArray.Axis.TimeAxis(fs=fs, offset=n_seen / fs)
        messages.append(
            AxisArray(split_dat, dims=dat_dims, axes={**other_axes, "time": _time_axis})
        )
        n_seen += split_dat.shape[time_ax]

    # Test axis_name `None` when target axis idx is 0.
    axis_name = "time" if time_ax != 0 else None
    gen = butter(
        axis=axis_name,
        order=order,
        cuton=cuton,
        cutoff=cutoff,
        coef_type=coef_type,
    )

    result = np.concatenate([gen.send(msg).data for msg in messages], axis=time_ax)
    assert np.allclose(result, out_dat)
```
griffinmilsap commented 5 months ago

Oh, nice! Had you made significant progress toward this change already? I was just about to start trying to work it in, but if you've already collected your thoughts, I won't get in the way.

cboulay commented 5 months ago

Not significant, no. And to be honest, I was struggling a bit because I don't know what other apps you have out there that depend on the existing API. I'd much prefer if you took it over.

Here's one other snippet that might be helpful. I had it at the top of the filter generator, expecting to reuse it, but I'm not actually using it anywhere, so take it or leave it in the reimplementation as you see fit.

```python
import typing

import numpy as np
import numpy.typing as npt

from ezmsg.sigproc.filter import FilterCoefficients


def _normalize_coefs(
    coefs: typing.Optional[
        typing.Union[FilterCoefficients, typing.Tuple[npt.NDArray, npt.NDArray], npt.NDArray]
    ],
) -> typing.Tuple[str, typing.Optional[typing.Tuple[npt.NDArray, ...]]]:
    coef_type = "ba"
    if coefs is not None:
        # scipy.signal functions are called with first arg `*coefs`.
        # Make sure we have a tuple of coefficients.
        if isinstance(coefs, np.ndarray):  # npt.NDArray is a generic alias; isinstance needs np.ndarray
            coef_type = "sos"
            coefs = (coefs,)  # sos funcs just want a single ndarray.
        elif isinstance(coefs, FilterCoefficients):
            coefs = (coefs.b, coefs.a)  # read the instance's fields, not the class attributes
    return coef_type, coefs
```
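Because the snippet above depends on the ezmsg-sigproc class, here is a self-contained sketch of the same normalization against a local stand-in `FilterCoefficients` (an assumption: only `b` and `a` fields). It checks that all three accepted input forms drive `lfilter`/`sosfilt` to the same output; `normalize_coefs` and `apply` are hypothetical names.

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import scipy.signal


@dataclass
class FilterCoefficients:
    # Local stand-in with the same b/a fields as the legacy class.
    b: np.ndarray
    a: np.ndarray


def normalize_coefs(
    coefs: Optional[Union[FilterCoefficients, Tuple[np.ndarray, np.ndarray], np.ndarray]],
) -> Tuple[str, Optional[Tuple[np.ndarray, ...]]]:
    if coefs is None:
        return "ba", None
    if isinstance(coefs, np.ndarray):
        return "sos", (coefs,)  # a bare array is an sos matrix
    if isinstance(coefs, FilterCoefficients):
        return "ba", (coefs.b, coefs.a)
    return "ba", tuple(coefs)  # already a (b, a) pair


def apply(coefs, x: np.ndarray) -> np.ndarray:
    coef_type, args = normalize_coefs(coefs)
    func = {"ba": scipy.signal.lfilter, "sos": scipy.signal.sosfilt}[coef_type]
    return func(*args, x)
```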
griffinmilsap commented 5 months ago

Awesome, I'll massage this into sigproc

griffinmilsap commented 4 months ago

For lack of time, I've simply dropped these functions into ezmsg.sigproc.filter, but I'll be reworking/deprecating the old API soon(ish)