neuroinformatics-unit / movement

Python tools for analysing body movements across space and time
http://movement.neuroinformatics.dev
BSD 3-Clause "New" or "Revised" License

Implement LightningPoses losses for outlier detection #145

Open niksirbi opened 7 months ago

niksirbi commented 7 months ago

Context

We want to implement more functionalities for evaluating pose estimation model performance and for detecting outliers - i.e. erroneous/implausible pose predictions.

LightningPose currently uses a set of 3 spatiotemporal constraints as unsupervised losses during model training - i.e. the network is penalised if it violates these constraints. These 3 constraints are:

- temporal smoothness: keypoints should not jump implausibly far between consecutive frames
- single-view pose plausibility: predicted poses should lie close to a low-dimensional subspace found via PCA
- multi-view consistency: predictions of the same keypoint across camera views should be consistent, again assessed via PCA

Many more details about these can be found in the paper. As far as I can tell, the losses are implemented in this module.
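To make the single-view PCA heuristic concrete, here is a minimal NumPy-only sketch of the idea: fit PCA (via SVD) on flattened poses, reproject each pose onto the low-dimensional subspace, and measure the per-keypoint pixel distance between the original and reprojected coordinates. The data, the number of frames/keypoints, and the number of retained components are all arbitrary illustrative choices, not LightningPose's actual settings.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: 500 frames, 6 keypoints in 2D, flattened to (time, 2 * n_keypoints).
# (Purely illustrative; LightningPose fits its PCA on labelled poses.)
n_frames, n_kpts = 500, 6
poses = rng.normal(size=(n_frames, 2 * n_kpts))

# Fit PCA via SVD on mean-centred poses and keep the top components.
mean = poses.mean(axis=0)
centred = poses - mean
_, _, vt = np.linalg.svd(centred, full_matrices=False)
components = vt[:4]  # number of components kept is a free choice

# Reproject each pose onto the low-dimensional subspace and back.
reprojected = centred @ components.T @ components + mean

# Per-keypoint reprojection error: Euclidean distance between the
# original and reprojected (x, y) of each keypoint, per frame.
diff = (poses - reprojected).reshape(n_frames, n_kpts, 2)
pca_error = np.linalg.norm(diff, axis=-1)  # shape: (n_frames, n_kpts)
```

Frames/keypoints with a large `pca_error` are candidates for flagging as implausible poses.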

What we want

We'd like to use these same constraints as post hoc outlier detection heuristics within movement, in addition to the confidence threshold approach we already use in filtering.py. In fact, most of these heuristics had already been used by others as outlier detection approaches (see references in the LightningPose paper), before being implemented as network losses by LightningPose.

Discussions

This idea came out of a chat that @sfmig, @lochhh and I had with @themattinthehatt (LightningPose co-author and dev). We had initially documented this as a topic on Zulip.

Folks from the Allen Institute for Neural Dynamics are also interested in this, as mentioned by @Di-Wang-AIND on Zulip:

For the evaluation part, we are more interested in performing quality control on the predictions output by a single framework. If the prediction is unsatisfactory, the quality of the training data may need to improve (i.e., clean data, relabel frames) or more labeled frames are needed. Therefore, the ability to detect outliers or locate the section of video with poor prediction would be great to have. The evaluation metrics introduced by Lightning-pose, like pixel error, temporal_norm and pca_singleview_error, may help to automatically filter out frames with poor predictions.

### Tasks
- [ ] https://github.com/neuroinformatics-unit/movement/issues/151
- [ ] Detect outliers using multi-view consistency heuristic
- [ ] https://github.com/neuroinformatics-unit/movement/issues/152
sfmig commented 4 months ago

Currently we see this mostly as a tool for outlier detection, but we may want to additionally use it as a quality metric for predicted poses in the absence of ground truth.

hummuscience commented 2 weeks ago

LightningPose automatically outputs a "PCA error" (the pixel error of the reprojection for each keypoint) and a "temporal error", which is the normalised temporal difference of all keypoints across consecutive frames.

Would it make sense to allow the user to load those as an option?

Or would you rather have an in-house implementation?

sfmig commented 2 weeks ago

Thanks for pointing this out @hummuscience!

We still need to play with LightningPose a bit more, but just to confirm: does this mean that these error metrics (PCA and temporal error) can be exported as part of the output file, for every predicted keypoint?

If so, what you suggest could be a very good idea. We could offer the user the option to import these as two additional data arrays in the loaded movement dataset.

In any case, both implementations seem compatible. We could allow users to import these metrics as data arrays, if they are exported to the LightningPose output file, but we could also have functionality for users to compute these metrics from predicted keypoints (computed with any pose estimation framework we support). The second one is what we were initially considering.

Is this a functionality that you would be interested in having for your analysis? I'm curious whether people are using these LightningPose error metrics in downstream analyses.

niksirbi commented 2 weeks ago

Agreed with @sfmig here.

Eventually we do want to implement the computation of these metrics here, so that people who've done pose estimation with tools other than LightningPose can also measure them. But having the ability to load them from LightningPose output files could be a good intermediate step, because it will force us to think about how to represent these as DataArray objects.

hummuscience commented 2 weeks ago

Currently, the scripts to predict new videos and for training the models (https://github.com/paninski-lab/lightning-pose/blob/main/scripts/predict_new_vids.py, https://github.com/paninski-lab/lightning-pose/blob/main/scripts/train_hydra.py) can output a pca_error (if keypoints are chosen in the config file for the PCA) and a temporal_error file (if a temporal loss is chosen).

Both are arrays with time in rows and keypoints in columns.

The way I am currently adding them to a movement dataset is like so (I am new to xarrays/movement, so maybe there is a better way):

```python
ds = load_poses.from_lp_file('../predictions.csv', fps=60)

temporal_error = pd.read_csv('../temporal_norm.csv', header=[0], index_col=[0])
pca_error = pd.read_csv('../pca_singleview_error.csv', header=[0], index_col=[0])

ds['temporal_error'] = xr.DataArray(
    temporal_error.values.reshape(
        tuple(ds.dims[d] for d in ['time', 'individuals', 'keypoints'])
    ),
    dims=['time', 'individuals', 'keypoints'],
)
ds['pca_error'] = xr.DataArray(
    pca_error.values.reshape(
        tuple(ds.dims[d] for d in ['time', 'individuals', 'keypoints'])
    ),
    dims=['time', 'individuals', 'keypoints'],
)
```

And yes, I am interested in this analysis myself as it can give more information on the quality of a prediction in addition to the confidence. Here is an example of the temporal error vs. PCA error (and the confidence in color) of a single keypoint of a tracked mouse.

[Screenshot: temporal error vs. PCA error for a single keypoint of a tracked mouse, coloured by confidence]

If one went with a simple confidence cutoff (like 0.95; the LightningPose authors recommend >0.9), one would still keep a lot of frames that have a large temporal or PCA error (in orange here).

[Screenshot: same plot, highlighting in orange the frames that pass the confidence cutoff but have large temporal or PCA error]
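The combined filtering described above (confidence cutoff plus error thresholds) could be sketched roughly like this. The arrays, threshold values, and variable names are all hypothetical stand-ins; in practice the confidence and error arrays would come from a loaded movement dataset.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(1)

# Hypothetical per-frame confidence and temporal error for one keypoint.
confidence = xr.DataArray(rng.uniform(0.5, 1.0, 100), dims="time")
temporal_error = xr.DataArray(rng.exponential(2.0, 100), dims="time")

# Flag frames that pass the confidence cutoff but still show a large
# temporal error, i.e. the kind of frames a confidence-only filter keeps.
conf_cutoff, err_cutoff = 0.95, 5.0  # thresholds are illustrative
suspicious = (confidence > conf_cutoff) & (temporal_error > err_cutoff)
print(int(suspicious.sum()), "suspicious frames")
```

The same boolean mask could then be passed to the existing filtering utilities to drop or interpolate the flagged frames.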

hummuscience commented 2 weeks ago

If I understand things correctly, calculating the temporal loss/error should be straightforward with the current tools available in movement (https://github.com/paninski-lab/lightning-pose/blob/84538bf76e1e2529006ab81fc03b86d9801334e6/lightning_pose/losses/losses.py#L360).

The PCA error would need more implementation work, I guess.
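A minimal sketch of that temporal norm, computed directly on a movement-style `position` array with xarray: take the frame-to-frame difference along `time` and the Euclidean norm over the `space` dimension. The toy data and dimension sizes are assumptions for illustration; only the dimension names mirror movement's dataset layout.

```python
import numpy as np
import xarray as xr

# Toy position array with movement-like dims (illustrative data only).
rng = np.random.default_rng(0)
position = xr.DataArray(
    rng.normal(size=(100, 2, 1, 3)),
    dims=("time", "space", "individuals", "keypoints"),
)

# Temporal norm: Euclidean norm of the frame-to-frame displacement,
# per keypoint. The first frame has no predecessor, so it is dropped.
displacement = position.diff("time")
temporal_norm = np.sqrt((displacement**2).sum("space"))
```

The result has one value per frame, individual and keypoint, matching the shape of the `temporal_error` array exported by LightningPose.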

niksirbi commented 2 weeks ago

Thanks a lot for all the info @hummuscience, this is all very helpful!

I double-checked, and your way of creating the DataArrays is exactly right. I also agree that having these metrics available in movement is great for quality control and outlier detection.

Regarding ways of moving forward:

Thanks again for your input and I'm glad you're finding movement useful!

niksirbi commented 2 weeks ago

FYI, the only tweak I'd make to your implementation is to use the more future-proof "sizes" as opposed to "dims":

```python
error_shape = tuple(ds.sizes[d] for d in ["time", "individuals", "keypoints"])

ds["temporal_error"] = xr.DataArray(
    temporal_error.values.reshape(error_shape),
    dims=("time", "individuals", "keypoints"),
)

ds["pca_error"] = xr.DataArray(
    pca_error.values.reshape(error_shape),
    dims=("time", "individuals", "keypoints"),
)
```

The above snippet is from my first draft for a new example, to be found here.