neuroinformatics-unit / movement

Python tools for analysing body movements across space and time
http://movement.neuroinformatics.dev
BSD 3-Clause "New" or "Revised" License

Implement Median and Savitzky-Golay Filters #163

Closed: b-peri closed this PR 1 month ago

b-peri commented 2 months ago

Edited 08/05/2024:

This PR introduces two new smoothing functions to the filtering module:

  1. median_filter(ds, window_length): Smooths pose tracks in the input dataset by applying a median filter along the time dimension. Window length must be specified by the user, and is interpreted as being in the input dataset's time unit (usually seconds). The filter window is centered over the filter origin.
  2. savgol_filter(ds, window_length, polyorder=2, **kwargs): Smooths pose tracks over time using a Savitzky-Golay filter. Again, window length must be specified by the user, and is interpreted as being in the input dataset's time unit. The order of the polynomial used to fit the samples can optionally be specified by the user; if omitted, a default value of 2 is used. Additional keyword arguments (**kwargs) are passed directly to scipy.signal.savgol_filter(), but note that the axis keyword argument cannot be overridden.
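Because both functions take window_length in the dataset's time unit, the value has to become a number of samples before it reaches scipy or xarray. A minimal sketch of that conversion, assuming the frame rate is available as `fps` and that the time unit is either "seconds" or "frames"; `window_in_samples()` is a hypothetical helper, not part of this PR, and it rounds up to an odd length so the window has a single centre sample:

```python
def window_in_samples(window_length: float, time_unit: str, fps: float) -> int:
    """Hypothetical helper: convert a window length to an odd sample count.

    Assumes time_unit is "seconds" (scaled by fps) or "frames" (used as-is).
    """
    if time_unit == "seconds":
        n = round(window_length * fps)
    elif time_unit == "frames":
        n = round(window_length)
    else:
        raise ValueError(f"Unsupported time unit: {time_unit!r}")
    # Bump even lengths to the next odd number so the window is symmetric
    # about its centre sample
    if n % 2 == 0:
        n += 1
    return n


print(window_in_samples(0.2, "seconds", fps=30))  # 0.2 s at 30 fps -> 6 -> 7
```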


References

Closes #55, closes #139


b-peri commented 2 months ago

This PR was a bit tricky and some of the defaults currently implemented are a bit arbitrary. Here's some of the reasoning behind a few of the decisions I've made thus far, as well as some things I'd like a bit of feedback on:

Median Filter

In issue #55, @niksirbi suggested three potential implementations for the median filter:

  1. rolling(time=window_length).median()
  2. scipy.signal.medfilt()
  3. scipy.ndimage.median_filter()

While rolling(time=window_length).median() would be a straightforward implementation, it fixes the filter origin at the very end of the filter window and converts the first (window_length - 1) points of the input array into NaNs. scipy.signal.medfilt(), on the other hand, does not allow us to specify the particular axis along which to apply the filter, nor does it let us set the position of the filter origin with respect to the window. In contrast, scipy.ndimage.median_filter() runs more quickly than scipy.signal.medfilt() (see the notes at the bottom of the page here), does not necessarily convert the first entries of the array into NaNs, and lets us flexibly set the position of the filter origin, the extension mode, and the axis along which to apply the filter.
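The difference between a trailing and a centred rolling window can be seen on a toy signal. The sketch below re-implements a rolling median in NumPy rather than calling xarray; it assumes an odd window (in samples) and mimics xarray's default min_periods, under which any window that is incomplete or contains a NaN yields NaN:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def rolling_median(x: np.ndarray, window: int, center: bool) -> np.ndarray:
    """Rolling median of a 1D array with an odd `window` (in samples).

    Any output whose window is incomplete or contains a NaN is NaN,
    mimicking xarray's rolling() with its default min_periods.
    """
    out = np.full(x.shape, np.nan)
    medians = np.median(sliding_window_view(x, window), axis=-1)
    # Trailing windows are labelled at their last sample, centred ones at
    # their middle sample
    start = window // 2 if center else window - 1
    out[start : start + len(medians)] = medians
    return out


x = np.arange(10, dtype=float)
trailing = rolling_median(x, 5, center=False)  # NaNs at indices 0-3
centred = rolling_median(x, 5, center=True)    # NaNs at indices 0-1 and 8-9
```

With the trailing window all window - 1 missing values pile up at the start; centring splits them as window // 2 at each edge.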

The flexibility of this function, however, also means that we must now decide how many of these parameters to expose to users, and what defaults to use for them, irrespective of whether they are user-facing or not. Here, I'm specifically worried about the mode, cval, and origin arguments, which I am still struggling to understand in general (at least in terms of the merits of one extension mode over another), and whose values I have so far chosen relatively arbitrarily. If anyone has feedback on which defaults would be sensible here, or whether these should be left up to the user, I would be very happy to hear it!
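To make the effect of mode and cval concrete: they only matter where the window hangs off the end of the array. A small sketch with scipy.ndimage.median_filter() on a 1D array, using a window of 3 and the default origin=0, so the window at index 0 needs one value from "index -1":

```python
import numpy as np
from scipy.ndimage import median_filter

x = np.array([10.0, 1.0, 2.0, 3.0, 4.0])

# With size=3 and origin=0, the window at index 0 covers indices -1, 0, 1;
# `mode` decides which value stands in for index -1 (cval is used only by
# mode="constant").
results = {
    mode: median_filter(x, size=3, mode=mode, cval=0.0)
    for mode in ["nearest", "reflect", "mirror", "constant", "wrap"]
}
for mode, y in results.items():
    print(f"{mode:>8}: {y}")
```

At index 0, "nearest" and "reflect" both pad with 10 (median 10), "mirror" pads with 1 and "constant" with 0 (median 1 in both cases), and "wrap" pads with the last value, 4 (median 4). Samples whose window lies fully inside the array are identical across modes.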

Savitzky-Golay Filter

While our choice of implementations was more constrained for the Savitzky-Golay filter, again I am still trying to figure out what some sensible default values might be - specifically for the polyorder, mode, and cval arguments. I'll do some more research about what is most common/useful for our application myself but again, if anyone has particular thoughts on this I'd love to hear them!
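One property that may help when choosing defaults: a Savitzky-Golay filter of order p leaves any polynomial of degree at most p unchanged wherever the window fits inside the array, and mode only decides how the first and last window_length // 2 samples are treated. A sketch using scipy.signal.savgol_filter() on a noiseless quadratic (the signal and window size are arbitrary choices for illustration):

```python
import numpy as np
from scipy.signal import savgol_filter

t = np.arange(20, dtype=float) + 5
x = t**2  # noiseless quadratic

# mode="interp" fits a polyorder-degree polynomial to the first/last window,
# so a quadratic is reproduced exactly everywhere, edges included
interp = savgol_filter(x, window_length=7, polyorder=2, mode="interp")

# mode="nearest" pads with repeated edge values, which bends the edge samples
nearest = savgol_filter(x, window_length=7, polyorder=2, mode="nearest")
```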

Diagnostics

I assume it would be useful to report a number of diagnostics after each of these functions has been run, but I'm not quite sure where to start here. Are there some standard/straightforward diagnostics that I could implement here (perhaps something showing local variability along the timeseries)? Knowing this will also help design tests.
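One possible starting point, sketched with a hypothetical filter_diagnostics() function on plain 1D NumPy arrays: report how many NaNs the filter added and how much frame-to-frame variability it removed. "Jitter" here is an assumed metric (mean absolute difference between consecutive samples), not something defined in this PR:

```python
import numpy as np


def filter_diagnostics(before: np.ndarray, after: np.ndarray) -> dict:
    """Hypothetical summary of what a smoothing step changed in a 1D track."""

    def jitter(x: np.ndarray) -> float:
        # Mean absolute frame-to-frame change, ignoring pairs involving a NaN
        return float(np.nanmean(np.abs(np.diff(x))))

    return {
        "nans_before": int(np.isnan(before).sum()),
        "nans_after": int(np.isnan(after).sum()),
        "jitter_before": jitter(before),
        "jitter_after": jitter(after),
    }


raw = np.array([0.0, 1.0, 0.0, 1.0, np.nan])
smoothed = np.array([0.5, 0.5, 0.5, 0.5, np.nan])
print(filter_diagnostics(raw, smoothed))
```

Both numbers also lend themselves to tests: the NaN count should not grow by more than the window allows, and jitter should not increase after smoothing.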

Separate vs. Single Smoothing Function

I've now implemented the different smoothing methods as separate functions, but alternatively, we could consider wrapping these all in a single, generic smooth() function, where method is passed as an argument (similar to DeepLabCut's implementation). This may mean that we have to deal with quite a few potential arguments in the same function, but given that many of these overlap across methods (e.g. window_length, mode, etc.), I think this might be manageable. I'd appreciate hearing your thoughts on this!
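For comparison, the single-function alternative could look roughly like the sketch below. smooth() is hypothetical: it works on 1D NumPy arrays with window_length already in samples, passes shared arguments to every method, and forwards method-specific ones through **kwargs:

```python
import numpy as np
from scipy.ndimage import median_filter
from scipy.signal import savgol_filter


def smooth(x: np.ndarray, method: str, window_length: int, **kwargs) -> np.ndarray:
    """Hypothetical single entry point that dispatches on `method`.

    1D input only: for N-D arrays an integer `size` in median_filter would
    filter across every axis, not just time.
    """
    if method == "median":
        return median_filter(x, size=window_length, **kwargs)
    if method == "savgol":
        kwargs.setdefault("polyorder", 2)
        return savgol_filter(x, window_length, **kwargs)
    raise ValueError(f"Unknown smoothing method: {method!r}")


spiky = np.array([1.0, 9.0, 1.0, 1.0, 1.0])
print(smooth(spiky, "median", 3))  # the single spike is removed
```

The trade-off is visible even at this size: each new method adds a branch and its own set of keyword arguments to one signature, whereas separate functions keep each signature minimal.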

niksirbi commented 2 months ago

Thanks for the work and the comprehensive writeup @b-peri 🤗 . I'll take a look at this next week and get back to you on the specific points you raised.

niksirbi commented 2 months ago

Some additional thoughts on this, based on today's meeting:

codecov[bot] commented 1 month ago

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.68%. Comparing base (c7c07a6) to head (51cb39d). Report is 1 commits behind head on main.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #163      +/-   ##
==========================================
+ Coverage   99.66%   99.68%   +0.02%
==========================================
  Files          10       11       +1
  Lines         591      637      +46
==========================================
+ Hits          589      635      +46
  Misses          2        2
```


b-peri commented 1 month ago

Thanks very much for the extensive feedback @niksirbi! I've now implemented most of your points. Following on from our discussions here and on Zulip, I just had the following notes:

  1. For both filters, the user must now manually specify the window length. This value is then interpreted dynamically according to the time unit of the input dataset (e.g. if ds.time_unit == 'seconds', window_length will be interpreted in seconds). Because of this flexibility, we've also decided to forgo implementing a rigid default value.
  2. After doing some testing on NaN-handling, we've discovered the following:
    • While scipy.signal.savgol_filter() (the function upon which movement's savgol_filter() is based) is not robust to NaNs, it does behave in a consistent way: if there are any NaNs in the window at a given point, a NaN is returned in the output array. Consequently, existing NaNs in the input array are propagated by the filter in proportion to the size of the window.
    • Conversely, scipy.ndimage.median_filter() produces broadly undefined behavior for NaNs (for a more extensive discussion, see here). As such, I've changed the implementation of median_filter() from scipy.ndimage.median_filter() to the xarray built-in xr.Dataset.rolling(time=window_length, center=True).median(skipna=True). While this also isn't robust to NaNs in the input dataset, it behaves more similarly to scipy.signal.savgol_filter(), i.e. NaNs are propagated whenever a NaN is in the window. A notable difference is that rolling().median() also introduces new NaNs at either edge of the input array. Especially when windows are large, this can be quite destructive, so a different (custom) implementation for both filters could be explored later down the line (e.g. see this solution).
    • I've detailed the behavior of each of the above filters under the Notes section of their respective docstring.
  3. As per your suggestions, savgol_filter() now exposes only the ds, window_length, and polyorder arguments. Any additional keyword arguments are passed directly to scipy.signal.savgol_filter(), but attempting to overwrite the axis keyword causes the function to throw an error. For the polyorder argument, I've implemented a default value of 2, as this seemed to be the most common polyorder used in other pose tracking pipelines (Zeng et al., 2021; Gonzalez et al., 2021; https://github.com/DurhamARC/raga-pose-estimation). There was quite some variability in the literature here though (see, for example, Matsuda et al., 2023; Hebert et al., 2021), so I'm open to revisiting this.
  4. Basic testing for both new functions has now also been implemented.

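The NaN behavior described in point 2 can be reproduced on a toy signal. A sketch assuming a single interior NaN and a 5-sample window; scipy.signal.savgol_filter() is called directly (default mode="interp"), while the centred rolling median is re-implemented in NumPy to mimic rolling(center=True).median() under xarray's default min_periods, rather than calling xarray itself:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
from scipy.signal import savgol_filter

window = 5
x = np.sin(np.linspace(0, 3, 30))
x[15] = np.nan  # a single interior gap

# Savitzky-Golay: every output whose window touches index 15 becomes NaN
sg = savgol_filter(x, window_length=window, polyorder=2)

# Centred rolling median requiring a full window: NaN if the window holds a
# NaN or runs off either edge of the array
med = np.full_like(x, np.nan)
half = window // 2
med[half : len(x) - half] = np.median(sliding_window_view(x, window), axis=-1)

print(np.isnan(sg).sum(), np.isnan(med).sum())  # 5 9
```

Both spread the single NaN to the 5 samples at indices 13-17; the rolling median additionally loses window // 2 samples at each edge.
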
b-peri commented 1 month ago

Hey @niksirbi, thanks again for the very in-depth feedback! I've implemented pretty much all of your suggestions as-is. Some points of note are the following:

Count NaNs

I've solved the duplication issue with count_nans() by implementing a new Helpers class in conftest.py, with count_nans() as a static method. Alongside this, I've added a helpers() fixture, which lets us pass Helpers to individual tests without an explicit import at the top of each test module.

```python
import numpy as np
import pytest


class Helpers:
    """Generic helper methods for ``movement`` testing modules."""

    @staticmethod
    def count_nans(ds):
        """Count NaNs in the x coordinate timeseries of the first keypoint
        of the first individual in the dataset.
        """
        n_nans = np.count_nonzero(
            np.isnan(
                ds.position.isel(individuals=0, keypoints=0, space=0).values
            )
        )
        return n_nans


@pytest.fixture
def helpers():
    """Return the ``Helpers`` class (its methods are static, so no
    instance is needed)."""
    return Helpers
```

Integration Tests

Whenever NaNs occur in continuous stretches (as they often do during periods of occlusion or poor visibility of the keypoint), the whole block of NaNs is effectively propagated only once, by c. (window_length // 2) samples at either end of the stretch. This means that computing the upper bound from the absolute number of NaNs, as we do now, may overshoot it to the point that the test is no longer strict enough to be meaningful.

To deal with this, I've tweaked test_nan_propagation_through_filters() to calculate max_nans_increase based on the number of consecutive stretches of NaNs occurring in the dataset (e.g. max_nans_increase = (window_length - 1) * nan_repeats_after_filt).
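A worked example of this bound, on a toy array rather than a movement dataset. It assumes a centred rolling median (re-implemented in NumPy, as above) and counts NaN stretches with a vectorised count_nan_stretches(), a hypothetical stand-in for the loop-based helper; naive_bound is an assumed count-based bound shown only for comparison:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def count_nan_stretches(x: np.ndarray) -> int:
    """Count runs of consecutive NaNs in a 1D array."""
    is_nan = np.isnan(x)
    # A run starts at a NaN whose predecessor is not NaN (or at index 0)
    starts = is_nan & ~np.concatenate(([False], is_nan[:-1]))
    return int(starts.sum())


def centred_median(x: np.ndarray, window: int) -> np.ndarray:
    """Centred rolling median; NaN if the window is incomplete or holds a NaN."""
    out = np.full(x.shape, np.nan)
    half = window // 2
    out[half : len(x) - half] = np.median(sliding_window_view(x, window), axis=-1)
    return out


window = 5
x = np.arange(40, dtype=float)
x[10:15] = np.nan  # one 5-sample gap (e.g. an occlusion)
x[25] = np.nan     # one isolated dropout

filtered = centred_median(x, window)
nans_increase = int(np.isnan(filtered).sum() - np.isnan(x).sum())
max_nans_increase = (window - 1) * count_nan_stretches(filtered)
naive_bound = (window - 1) * int(np.isnan(x).sum())
print(nans_increase, max_nans_increase, naive_bound)  # 12 16 24
```

The stretch-based bound (16) still holds but is tighter than the count-based one (24). Counting stretches after filtering matters here: the two edge runs introduced by the centred window are stretches too, which is why they are covered by the bound.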

I've also added a new, generic Helpers method to automate counting these stretches of NaNs.

```python
    @staticmethod
    def count_nan_repeats(ds):
        """Count the number of consecutive stretches of NaNs in the x
        coordinate timeseries of the first keypoint of the first individual
        in the dataset.
        """
        x = ds.position.isel(individuals=0, keypoints=0, space=0).values
        repeats = []  # length of each completed NaN stretch
        running_count = 1
        for i in range(len(x)):
            if i != len(x) - 1:
                if np.isnan(x[i]) and np.isnan(x[i + 1]):
                    # Stretch continues into the next sample
                    running_count += 1
                elif np.isnan(x[i]):
                    # Stretch ends at this sample
                    repeats.append(running_count)
                    running_count = 1
                else:
                    running_count = 1
            elif np.isnan(x[i]):
                # Last sample closes a stretch that runs to the end
                repeats.append(running_count)
                running_count = 1
        return len(repeats)
```

This should be all for now! Thanks again!

sonarcloud[bot] commented 1 month ago

Quality Gate passed

Issues
2 New issues
0 Accepted issues

Measures
0 Security Hotspots
No data about Coverage
0.0% Duplication on New Code
