Closed — b-peri closed this 1 month ago
This PR was a bit tricky and some of the defaults currently implemented are a bit arbitrary. Here's some of the reasoning behind a few of the decisions I've made thus far, as well as some things I'd like a bit of feedback on:
In issue #55, @niksirbi suggested three potential implementations for the median filter:

1. `rolling(time=window_length).median()`
2. `scipy.signal.medfilt()`
3. `scipy.ndimage.median_filter()`
While `rolling(time=window_length).median()` would be a straightforward implementation, it fixes the filter origin at the very end of the filter window and converts the first `(window_length - 1)` points of the input array into NaNs. Conversely, `scipy.signal.medfilt()` does not allow us to specify the axis along which to apply the filter, nor does it let us control the position of the filter origin with respect to the window. In contrast, `scipy.ndimage.median_filter()` runs more quickly than `scipy.signal.medfilt()` (see notes at the bottom of the page here), does not necessarily convert the first entries of the array into NaNs, and allows us to flexibly set the position of the filter origin, the extension mode, and the axis along which to apply the filter.
The flexibility of this function, however, also means that we must now decide how many of these parameters to expose to users, and what defaults to use for them, whether or not they are user-facing. Here, I'm specifically worried about the `mode`, `cval`, and `origin` arguments, which I am struggling a bit to understand more generally (at least in terms of the merits of one extension mode vs. another), and have for now selected relatively arbitrarily. If anyone has feedback on what defaults would be sensible here, or on whether these should be left up to the user, I would be very happy to hear it!
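For concreteness, here is a small sketch (toy data only, not code from this PR) of how the `mode`, `cval`, and `origin` arguments change what `scipy.ndimage.median_filter()` does at the array edges:

```python
import numpy as np
from scipy.ndimage import median_filter

# Toy "pose track": 10 time points x 2 spatial dimensions
x = np.arange(10, dtype=float)
pos = np.stack([x, x[::-1]], axis=1)

# Filter along the time axis only (size 1 along space), centred window of 3.
# mode="nearest" extends the array edges with their closest value.
smoothed = median_filter(pos, size=(3, 1), mode="nearest")

# mode="constant" pads with cval instead; here the last output along x
# becomes median(8, 9, 0) == 8 rather than staying close to 9.
smoothed_const = median_filter(pos, size=(3, 1), mode="constant", cval=0.0)

# origin shifts the window off-centre relative to the output point
# (see the scipy docs for the sign convention).
shifted = median_filter(pos, size=(3, 1), origin=(1, 0))
```

The upshot is that `mode`/`cval` only matter within `window_length // 2` points of either edge, while `origin` affects every output point.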
While our choice of implementations was more constrained for the Savitzky-Golay filter, I am again still trying to figure out sensible default values - specifically for the `polyorder`, `mode`, and `cval` arguments. I'll do some more research on what is most common/useful for our application, but if anyone has particular thoughts on this, I'd love to hear them!
I assume it would be useful to report a number of diagnostics after each of these functions has been run, but I'm not quite sure where to start here. Are there some standard/straightforward diagnostics that I could implement here (perhaps something showing local variability along the timeseries)? Knowing this will also help design tests.
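As a starting point, one simple pair of diagnostics might be the NaN fraction and some measure of local variability (e.g. the mean absolute first difference) before vs. after filtering. A hypothetical sketch, not part of this PR:

```python
import numpy as np

def smoothing_diagnostics(before, after):
    """Hypothetical diagnostics for a 1D timeseries before/after smoothing:
    the fraction of NaNs, and the mean absolute first difference
    (a crude measure of local variability along time).
    """
    def nan_fraction(x):
        return float(np.mean(np.isnan(x)))

    def local_variability(x):
        return float(np.nanmean(np.abs(np.diff(x))))

    return {
        "nan_frac": (nan_fraction(before), nan_fraction(after)),
        "variability": (local_variability(before), local_variability(after)),
    }

# A zigzag signal has high local variability; a flat (fully smoothed) one has none.
noisy = np.array([0.0, 1.0, 0.0, 1.0, 0.0, 1.0])
flat = np.full_like(noisy, 0.5)
report = smoothing_diagnostics(noisy, flat)
```

Reporting both numbers side by side would also expose the NaN-propagation behaviour of the filters discussed below.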
I've now implemented the different smoothing methods as separate functions, but alternatively we could consider wrapping them all in a single, generic `smooth()` function, where the method is passed as an argument (similar to DeepLabCut's implementation). This may mean we have to handle quite a few potential arguments in the same function, but given that many of these overlap across methods (e.g. `window_length`, `mode`, etc.), I think this should be manageable. I'd appreciate hearing your thoughts on this!
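For what it's worth, a minimal sketch of what such a dispatching `smooth()` could look like (function names and signatures here are hypothetical stand-ins, not the PR's API):

```python
import numpy as np

# Hypothetical stand-ins for the individual smoothing functions
def _median_smooth(x, window_length, **kwargs):
    return x  # placeholder

def _savgol_smooth(x, window_length, polyorder=2, **kwargs):
    return x  # placeholder

_METHODS = {"median": _median_smooth, "savgol": _savgol_smooth}

def smooth(x, method="median", window_length=3, **kwargs):
    """Dispatch to one of the registered smoothing methods.

    Shared arguments (e.g. window_length) are named explicitly;
    method-specific ones travel in **kwargs.
    """
    try:
        func = _METHODS[method]
    except KeyError:
        raise ValueError(
            f"Unknown method '{method}'; expected one of {sorted(_METHODS)}"
        ) from None
    return func(x, window_length, **kwargs)
```

A registry dict like `_METHODS` keeps the dispatcher itself method-agnostic, so adding a new filter is one entry rather than another `if/elif` branch.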
Thanks for the work and the comprehensive writeup @b-peri 🤗 . I'll take a look at this next week and get back to you on the specific points you raised.
Some additional thoughts on this, based on today's meeting:
We decided to stick with separate functions and shelve the generic `smooth()` function idea for now. Once the module has grown a bit, we can think of ways to make the interface easier (e.g. by providing pipeline functionalities).

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 99.68%. Comparing base (`c7c07a6`) to head (`51cb39d`). Report is 1 commit behind head on main.
Thanks very much for the extensive feedback @niksirbi! I've now implemented most of your points. Following on from our discussions here and on Zulip, I just had the following notes:
- `window_length` is now interpreted in the time unit of the input dataset (i.e. if `ds.time_unit == 'seconds'`, `window_length` will be interpreted in seconds). Because of this flexibility, we've also decided to forgo implementing a rigid default value.
- While `scipy.signal.savgol_filter()` (the function upon which Movement's `savgol_filter()` is based) is not robust to NaNs, it does behave in a consistent way: if there are any NaNs in the window at a given point, a NaN is returned to the output array. Consequently, existing NaNs in the input array are propagated by the filter in a way that is proportional to the size of the window.
- `scipy.ndimage.median_filter()` produces broadly undefined behavior for NaNs (for a more extensive discussion, see here). As such, I've changed the implementation of `median_filter()` from `scipy.ndimage.median_filter()` to the xarray built-in `xr.Dataset.rolling(time=window_length, center=True).median(skipna=True)`. While this also isn't robust to NaNs in the input dataset, it behaves more similarly to `scipy.signal.savgol_filter()` - i.e. NaNs are propagated whenever a NaN is in the window. A notable difference is that `rolling().median()` also introduces new NaNs at either edge of the input array. Especially when window sizes are large, this can be quite destructive, so perhaps a different (custom) implementation for both filters should be explored later down the line (e.g. see this solution).
- The NaN behavior of both filters is described in the `Notes` section of their respective docstrings.
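The NaN propagation described above can be seen on a toy array (illustration only, calling SciPy directly rather than Movement's wrapper):

```python
import numpy as np
from scipy.signal import savgol_filter

x = np.arange(20, dtype=float)
x[10] = np.nan

out = savgol_filter(x, window_length=5, polyorder=2)

# The single NaN contaminates every output whose window touches index 10,
# i.e. indices 8..12 for a centred 5-point window...
assert all(np.isnan(out[i]) for i in range(8, 13))
# ...while points whose window excludes it are unaffected.
assert not np.isnan(out[7]) and not np.isnan(out[13])
```

So one NaN becomes `window_length` NaNs, which is exactly the "proportional to the size of the window" behaviour noted above.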
- `savgol_filter()` now exposes only the `ds`, `window_length`, and `polyorder` arguments. Any additional keyword arguments are passed directly to `scipy.signal.savgol_filter()`, but attempting to overwrite the `axis` keyword causes the function to throw an error. For the `polyorder` argument, I've implemented a default value of `2`, as this seemed to be the most common polyorder used in other pose tracking pipelines (Zeng et al., 2021; Gonzalez et al., 2021; https://github.com/DurhamARC/raga-pose-estimation). There was quite some variability in the literature here, though (see, for example, Matsuda et al., 2023; Hebert et al., 2021), so I'm open to revisiting this.

Hey @niksirbi, thanks again for the very in-depth feedback! I've implemented pretty much all of your suggestions as-is. Some points of note are the following:
I've solved the duplication issue with `count_nans()` by implementing a new `Helpers` class in `conftest.py`, under which `count_nans()` is a static method. Alongside this, I've added a `helpers()` fixture, which enables us to pass `Helpers` to individual tests without requiring an explicit import at the top of the script.
```python
import numpy as np
import pytest


class Helpers:
    """Generic helper methods for ``movement`` testing modules."""

    @staticmethod
    def count_nans(ds):
        """Count NaNs in the x coordinate timeseries of the first keypoint
        of the first individual in the dataset.
        """
        n_nans = np.count_nonzero(
            np.isnan(
                ds.position.isel(individuals=0, keypoints=0, space=0).values
            )
        )
        return n_nans


@pytest.fixture
def helpers():
    """Return an instance of the ``Helpers`` class."""
    return Helpers
```
Whenever NaNs occur in continuous stretches (as they often do during periods of occlusion or poor visibility of the keypoint), the whole block of NaNs is essentially only propagated once, by c. `(window_length // 2)` on either end of the stretch. This means that computing the upper bound from the absolute number of NaNs - as we do now - may overshoot it to the extent that the test is no longer strict enough to be meaningful.
To deal with this, I've tweaked `test_nan_propagation_through_filters()` to calculate `max_nans_increase` based on the number of consecutive stretches of NaNs occurring in the input dataset (e.g. `max_nans_increase = (window_length - 1) * nan_repeats_after_filt`).
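As a sanity check on that bound, a quick toy example (using a plain pandas rolling median rather than Movement's filter) shows a single 2-NaN stretch growing by `window_length - 1` NaNs, plus the NaNs the centred window introduces at the array edges:

```python
import numpy as np
import pandas as pd

window = 3
s = pd.Series([1.0, 2.0, np.nan, np.nan, 3.0, 4.0, 5.0, 6.0])
out = s.rolling(window, center=True).median()

# The stretch at indices 2-3 grows to indices 1-4: an increase of
# (window - 1) == 2 NaNs for the single stretch...
assert out.isna().sum() == 6  # 2 original + 2 stretch growth + 2 edges
# ...and the first/last points become NaN because the centred window
# extends past the array edges (min_periods defaults to the window size).
assert np.isnan(out.iloc[0]) and np.isnan(out.iloc[-1])
```

With two separate stretches, the growth term doubles while the edge term stays fixed, which is why counting stretches rather than individual NaNs gives a much tighter bound.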
I've also added a new, generic `Helpers` method to automate counting these stretches of NaNs.
```python
@staticmethod
def count_nan_repeats(ds):
    """Count the number of NaN repeats in the x coordinate timeseries
    of the first keypoint of the first individual in the dataset.
    """
    x = ds.position.isel(individuals=0, keypoints=0, space=0).values
    repeats = []
    running_count = 1
    for i in range(len(x)):
        if i != len(x) - 1:
            if np.isnan(x[i]) and np.isnan(x[i + 1]):
                running_count += 1
            elif np.isnan(x[i]):
                repeats.append(running_count)
                running_count = 1
            else:
                running_count = 1
        elif np.isnan(x[i]):
            repeats.append(running_count)
            running_count = 1
    return len(repeats)
```
This should be all for now! Thanks again!
Issues: 2 New issues, 0 Accepted issues
Measures: 0 Security Hotspots, No data about Coverage, 0.0% Duplication on New Code
Edited 08/05/2024:

What is this PR

This PR introduces two new smoothing functions to the `filtering` module:

- `median_filter(ds, window_length)`: Smooths pose tracks in the input dataset by applying a median filter along the time dimension. The window length must be specified by the user and is interpreted in the input dataset's time unit (usually seconds). The filter window is centered over the filter origin.
- `savgol_filter(ds, window_length, polyorder, **kwargs)`: Smooths pose tracks over time using a Savitzky-Golay filter. Again, the window length must be specified by the user and is interpreted in the input dataset's time unit. The order of the polynomial used to fit the samples can optionally be specified by the user; if omitted, a default value of `2` is used. Additional keyword arguments (`**kwargs`) are passed directly to `scipy.signal.savgol_filter()`, but note that the `axis` kwarg may not be overwritten.
References
Closes #55, closes #139
Checklist: