Open cboulay opened 6 months ago
@consumer
def filtergen(
    axis: Optional[str], coefs: Optional[Tuple[np.ndarray, ...]], coef_type: str
) -> Generator[AxisArray, AxisArray, None]:
    """
    Filter successive AxisArray messages along one axis, carrying the filter
    state (`zi`) from one message to the next.

    Parameters:
        axis (Optional[str]): Name of the axis to filter along. If None, the first
            dimension of the first filtered message is used.
        coefs (Optional[Tuple[np.ndarray, ...]]): Coefficients handed to the
            scipy.signal functions as `*coefs`. A value that is not a tuple (e.g. a
            single sos ndarray) is wrapped into a 1-tuple. If None, every message is
            passed through unchanged.
        coef_type (str): "ba" selects `lfilter` / `lfilter_zi`; "sos" selects
            `sosfilt` / `sosfilt_zi`. Any other value raises KeyError from the
            lookup when the generator body first runs.

    Each message sent in yields a copy of that message (via `replace`) whose
    `data` is the filtered array; in passthrough mode the input message itself
    is yielded.
    """
    # Massage inputs
    if coefs is not None and not isinstance(coefs, tuple):
        # scipy.signal functions called with first arg `*coefs`, but sos coefs are a single ndarray.
        coefs = (coefs,)

    # Init IO. The empty message is what the first `yield` hands back.
    axis_arr_out = AxisArray(np.array([]), dims=[""])

    filt_func = {"ba": scipy.signal.lfilter, "sos": scipy.signal.sosfilt}[coef_type]
    zi_func = {"ba": scipy.signal.lfilter_zi, "sos": scipy.signal.sosfilt_zi}[coef_type]

    # State variables
    axis_idx = None  # resolved once, from the first message that is actually filtered
    zi = None  # filter state carried across messages
    expected_shape = None  # per-sample shape that `zi` was built for

    while True:
        axis_arr_in = yield axis_arr_out

        if coefs is None:
            # passthrough if we do not have a filter design.
            axis_arr_out = axis_arr_in
            continue

        if axis_idx is None:
            axis_name = axis_arr_in.dims[0] if axis is None else axis
            axis_idx = axis_arr_in.get_axis_idx(axis_name)

        dat_in = axis_arr_in.data

        # Re-calculate/reset zi if necessary
        samp_shape = dat_in.shape[:axis_idx] + dat_in.shape[axis_idx + 1 :]
        if zi is None or samp_shape != expected_shape:
            expected_shape = samp_shape
            n_tail = dat_in.ndim - axis_idx - 1
            zi = zi_func(*coefs)
            # Insert singleton dims around the state axis, then tile so that zi
            # matches the data shape everywhere except along the filtered axis.
            zi_expand = (None,) * axis_idx + (slice(None),) + (None,) * n_tail
            n_tile = dat_in.shape[:axis_idx] + (1,) + dat_in.shape[axis_idx + 1 :]
            if coef_type == "sos":
                # sos zi must keep its leading dimension (`order / 2` for low|high; `order` for bpass|bstop)
                zi_expand = (slice(None),) + zi_expand
                n_tile = (1,) + n_tile
            zi = np.tile(zi[zi_expand], n_tile)

        dat_out, zi = filt_func(*coefs, dat_in, axis=axis_idx, zi=zi)
        axis_arr_out = replace(axis_arr_in, data=dat_out)
@consumer
def butter(
    axis: Optional[str],
    order: int = 0,
    cuton: Optional[float] = None,
    cutoff: Optional[float] = None,
    coef_type: str = "ba",
) -> Generator[AxisArray, AxisArray, None]:
    """
    Butterworth-filter successive AxisArray messages, designing the filter lazily.

    The band type and critical frequencies come from
    `LegacyButterSettings(...).filter_specs()`. The coefficients themselves are
    designed on the first message received while `order > 0`, because the
    sample rate is read from that message (reciprocal of the target axis'
    `gain`). With `order` left at 0 no filter is ever designed and the inner
    `filtergen` keeps running with `coefs=None`, i.e. as a passthrough.

    Parameters:
        axis (Optional[str]): Name of the axis to filter along; if None, the
            first dimension of the incoming message is used.
        order (int): Filter order handed to `scipy.signal.butter`.
        cuton (Optional[float]): Cuton frequency handed to LegacyButterSettings.
        cutoff (Optional[float]): Cutoff frequency handed to LegacyButterSettings.
        coef_type (str): Passed as `output=` to `scipy.signal.butter` and on to
            `filtergen`.
    """
    # IO
    msg_in = AxisArray(np.array([]), dims=[""])
    msg_out = AxisArray(np.array([]), dims=[""])

    legacy_settings = LegacyButterSettings(order=order, cuton=cuton, cutoff=cutoff)
    band_type, critical_freqs = legacy_settings.filter_specs()

    # No design is possible before the first message reveals the sample rate,
    # so start with a passthrough inner generator.
    designed = None
    inner = filtergen(axis, designed, coef_type)

    while True:
        msg_in = yield msg_out

        needs_design = designed is None and order > 0
        if needs_design:
            sample_rate = 1 / msg_in.axes[axis or msg_in.dims[0]].gain
            designed = scipy.signal.butter(
                order,
                Wn=critical_freqs,
                btype=band_type,
                fs=sample_rate,
                output=coef_type,
            )
            # Swap the passthrough for a real filtering generator.
            inner = filtergen(axis, designed, coef_type)

        msg_out = inner.send(msg_in)
And I have unit tests to go with these.
You're amazing, Chad. I'm not fond of our filter inheritance structure but this refactor will do away with it entirely.
After having used this successfully a few times now, I'm going to deprecate our current filters and add these in as a new implementation. This will allow us to get rid of fs and make the changes we need in a backward compatible way.
By coincidence I was working on this today but I'm happy to move onto the Spectrogram and Bandpower rework.
Here is my `test_butter.py`. Feel free to change the folder layout however you see fit.
# from typing import Optional
import numpy as np
import pytest
import scipy.signal
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.sigproc.butterworthfilter import (
ButterworthFilterSettings as LegacyButterSettings,
)
from ezmsg.sigproc.butter import butter, ButterworthFilterSettings
@pytest.mark.parametrize(
    "cutoff, cuton",
    [
        (30.0, None),  # lowpass
        (None, 30.0),  # highpass
        (45.0, 30.0),  # bandpass
        (30.0, 45.0),  # bandstop
    ],
)
@pytest.mark.parametrize("order", [2, 4, 8])
def test_butterworth_legacy_filter_settings(cutoff: float, cuton: float, order: int):
    """
    Check the `btype` and `Wn` that the legacy butterworth settings derive from cuton/cutoff.

    These are tested on their own because the later tests take them as correct.

    Parameters:
        cutoff (float): The cutoff frequency for the filter. Can be None for highpass filters.
        cuton (float): The cuton frequency for the filter. Can be None for lowpass filters.
            A cuton larger than cutoff is taken to mean bandstop.
        order (int): The order of the filter.
    """
    legacy = LegacyButterSettings(
        axis="time", fs=500, order=order, cuton=cuton, cutoff=cutoff
    )
    btype, Wn = legacy.filter_specs()

    # Work out the expected (btype, Wn) pair first, then compare once.
    if cuton is None:
        want_btype, want_wn = "lowpass", cutoff
    elif cutoff is None:
        want_btype, want_wn = "highpass", cuton
    elif cuton <= cutoff:
        want_btype, want_wn = "bandpass", (cuton, cutoff)
    else:
        want_btype, want_wn = "bandstop", (cutoff, cuton)

    assert btype == want_btype
    assert Wn == want_wn
@pytest.mark.parametrize(
    "cutoff, cuton",
    [
        (30.0, None),  # lowpass
        (None, 30.0),  # highpass
        (45.0, 30.0),  # bandpass
        (30.0, 45.0),  # bandstop
    ],
)
@pytest.mark.parametrize("order", [0, 2, 5, 8])  # 0 = passthrough
# All fs entries must be greater than 2x the largest of cutoff | cuton
@pytest.mark.parametrize("fs", [200.0])
@pytest.mark.parametrize("n_chans", [3])
@pytest.mark.parametrize("n_dims, time_ax", [(1, 0), (3, 0), (3, 1), (3, 2)])
@pytest.mark.parametrize("coef_type", ["ba", "sos"])
def test_butterworth(
    cutoff: float,
    cuton: float,
    order: int,
    fs: float,
    n_chans: int,
    n_dims: int,
    time_ax: int,
    coef_type: str,
):
    """
    Compare the `butter` generator, fed a signal split over several messages,
    against a single scipy.signal filtering pass over the whole signal.
    """
    dur = 2.0
    n_freqs = 5
    n_splits = 4
    n_times = int(dur * fs)

    # Build the input array, with the time axis at position `time_ax`.
    if n_dims == 1:
        dat_shape = [n_times]
        dat_dims = ["time"]
        other_axes = {}
    else:
        dat_shape = [n_freqs, n_chans]
        dat_shape.insert(time_ax, n_times)
        dat_dims = ["freq", "ch"]
        dat_dims.insert(time_ax, "time")
        other_axes = {"freq": AxisArray.Axis(unit="Hz"), "ch": AxisArray.Axis()}
    in_dat = np.arange(np.prod(dat_shape), dtype=float).reshape(*dat_shape)

    # Calculate Expected Result
    legacy = LegacyButterSettings(
        axis="time", fs=500, order=order, cuton=cuton, cutoff=cutoff
    )
    btype, Wn = legacy.filter_specs()
    coefs = scipy.signal.butter(order, Wn, btype=btype, output=coef_type, fs=fs)
    # scipy filters along the last axis by default, so move time there for the reference run.
    time_last = np.moveaxis(in_dat, time_ax, -1)
    if coef_type == "ba":
        if order == 0:
            # butter does not return correct coefs under these conditions; Set manually.
            coefs = (np.array([1.0, 0.0]),) * 2
        zi = scipy.signal.lfilter_zi(*coefs)
        if n_dims == 3:
            zi = np.tile(zi[None, None, :], (n_freqs, n_chans, 1))
        out_dat, _ = scipy.signal.lfilter(*coefs, time_last, zi=zi)
    elif coef_type == "sos":
        zi = scipy.signal.sosfilt_zi(coefs)
        if n_dims == 3:
            zi = np.tile(zi[:, None, None, :], (1, n_freqs, n_chans, 1))
        out_dat, _ = scipy.signal.sosfilt(coefs, time_last, zi=zi)
    out_dat = np.moveaxis(out_dat, -1, time_ax)

    # Split the data into multiple messages
    n_seen = 0
    messages = []
    for chunk in np.array_split(in_dat, n_splits, axis=time_ax):
        chunk_time_axis = AxisArray.Axis.TimeAxis(fs=fs, offset=n_seen / fs)
        chunk_axes = {**other_axes, "time": chunk_time_axis}
        messages.append(AxisArray(chunk, dims=dat_dims, axes=chunk_axes))
        n_seen += chunk.shape[time_ax]

    # Test axis_name `None` when target axis idx is 0.
    axis_name = "time" if time_ax != 0 else None
    gen = butter(
        axis=axis_name,
        order=order,
        cuton=cuton,
        cutoff=cutoff,
        coef_type=coef_type,
    )

    filtered_chunks = [gen.send(msg).data for msg in messages]
    result = np.concatenate(filtered_chunks, axis=time_ax)
    assert np.allclose(result, out_dat)
Oh, nice! had you made significant progress toward this change already? I was just about to start trying to work it in, but if you have collected thoughts already I won't get in the way
Not significant, no. And to be honest I was struggling a bit because I don't know what other apps out there you have that depend on the existing API. I'd much prefer if you took it over.
Here's one other snippet that might be helpful. I was using it at the top of the `filter` generator with the expectation that I would reuse it. I'm not currently using this snippet anywhere, so take it or leave it in the reimplementation as you see fit.
def _normalize_coefs(
    coefs: typing.Union[FilterCoefficients, typing.Tuple[npt.NDArray, npt.NDArray], npt.NDArray]
) -> typing.Tuple[str, typing.Optional[typing.Tuple[npt.NDArray, ...]]]:
    """
    Work out the coefficient type and return the coefficients as a tuple
    suitable for calling scipy.signal functions with `*coefs`.

    Parameters:
        coefs: One of
            - None: returned as-is with type "ba".
            - a single ndarray: treated as "sos" and wrapped into a 1-tuple.
            - a FilterCoefficients instance: unpacked into `(b, a)`, type "ba".
            - anything else (e.g. a `(b, a)` tuple): returned as-is, type "ba".

    Returns:
        `(coef_type, coefs)` where `coef_type` is "ba" or "sos".
    """
    # Function-scope import: only `npt` was in scope here, and `npt.NDArray` is a
    # parameterized generic alias, which `isinstance` rejects with TypeError.
    import numpy as np

    coef_type = "ba"
    if coefs is None:
        return coef_type, None

    # scipy.signal functions called with first arg `*coefs`.
    # Make sure we have a tuple of coefficients.
    if isinstance(coefs, np.ndarray):
        # sos funcs just want a single ndarray.
        return "sos", (coefs,)
    if isinstance(coefs, FilterCoefficients):
        # Read b/a from the instance passed in, not from the class.
        # NOTE(review): assumes FilterCoefficients exposes `b` and `a` as instance
        # attributes, as the original class-level access implies - confirm.
        return coef_type, (coefs.b, coefs.a)
    return coef_type, coefs
Awesome, I'll massage this into sigproc
For lack of time, I've simply dropped these functions into ezmsg.sigproc.filter, but I'll be reworking/deprecating the old API soon(ish)
I'm working on a generator-based refactor of filter and butterworth filter. I'll attach the WIP functions in a reply to this top-post.
In trying to redo the existing classes using the generator function, I've come across a couple things that I don't understand and/or I think can be modified to simplify and generalize.
First, the class `FilterSettingsBase` includes the field `fs`. This isn't used as a "setting" at this base class level. While the child class `ButterworthFilterSettings` needs `fs` to design the filter (because cuton and cutoff are provided in Hz), the parent does not use `fs` except to check that it is not None before calling `update_filter`. This is misplaced because a child of this base class might want to initiate a filter with the coefficients already known, in which case `fs` is irrelevant at design time. Indeed, `FilterSettings` has the field `filt` to store the coefficients. Can we remove `fs` from `FilterSettingsBase` and add it to `ButterworthFilterSettings`?

Second, this block stalls out the pipeline if the filter is not designed:
I agree with the `TODO` that it should be a passthrough if there is no filter yet designed. IMO, it should be required that all the details necessary to design the filter be provided on init, and possibly with the first non-empty input. If this requirement is enforced, then I expect it'll probably simplify some of the other workflows. […] `INPUT_FILTER` connection is used, so I can see why we might not want it designed at init?

Third,
`Filter`'s `design_filter` implementation simply raises `NotImplementedError`. In my experience, having a method that raises that error is enough for linters to say "this is an abstract base class". However, `Decimate` instantiates a `Filter()` (manual b, a from a cheby1 design if factor > 1). I think it would be better if `Filter` were truly abstract and instead there were a `SimpleFilter` child class that overrode `design_filter` with `return None`, and `Decimate` used `SimpleFilter` instead.

Fourth,
I prefer "sos" filters. `FilterCoefficients` assumes "ba". What if `FilterCoefficients`' only field was `coefs`, with type `Union[Tuple[np.ndarray], BACoeffs]`, where `BACoeffs` was the current version of `FilterCoefficients`? I think I'm missing a `Tuple` in there somewhere. But if done right, then it should be possible to use `FilterCoefficients` for either "ba" or "sos".

I think it's reasonable to do this in baby steps. I think there's a path to refactoring the classes before bringing in the generator functions. Unfortunately, I need the generator functions now to demonstrate offline and online harmony, but I can keep those in my project's namespace package for now.