talmolab / sleap

A deep learning framework for multi-animal pose tracking.
https://sleap.ai

Allow returning PAF graph during low level inference #1329

Closed: calebweinreb closed this 1 year ago

calebweinreb commented 1 year ago

This PR adds an option to return the parsed part affinity field (PAF) graph during low-level inference. Here's a minimal example:

import sleap

model_directory = "kp_moseq_bottomup_1230512_153003.multi_instance.n=358/"
video_name = "Session Labeled 2023-05-05T10 37 02_De3ob3_2M_2.mp4"

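predictor = sleap.load_model(model_directory)
# Enable the option added in this PR on the bottom-up inference layer.
predictor.inference_model.bottomup_layer.return_paf_graph = True

video = sleap.load_video(video_name)
# Run low-level inference on the first 10 frames.
predictions = predictor.inference_model.predict(video[:10])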

for k in ['peaks','peak_vals','peak_channel_inds','edge_inds','edge_peak_inds','line_scores']:
    print(predictions[k].shape, k)

# RETURNS
# (10, 16, 2) peaks
# (10, 16) peak_vals
# (10, 16) peak_channel_inds
# (10, 28) edge_inds
# (10, 28, 2) edge_peak_inds
# (10, 28) line_scores

I added an entry for return_paf_graph to the relevant docstrings. The docs don't currently describe the contents of every returned key, but I think the key names make them fairly clear.
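Since the per-key semantics aren't spelled out, here is a sketch of one way to read the PAF graph for a single frame. It is built on assumptions, not on the PR: it assumes edge_peak_inds[i] holds the (source, destination) indices into that frame's peaks array, that edge_inds[i] is the skeleton edge type and line_scores[i] its PAF line score, and that peak_channel_inds gives each peak's node type. The helper candidate_connections and the min_score threshold are hypothetical, and the arrays are made-up toy data rather than real model output.

```python
import numpy as np


def candidate_connections(peaks, peak_channel_inds, edge_inds, edge_peak_inds,
                          line_scores, min_score=0.0):
    """List one frame's candidate connections (hypothetical helper).

    Assumption: edge_peak_inds[i] = (source, destination) indices into this
    frame's `peaks` array; this is not confirmed by the PR itself.
    """
    connections = []
    for edge_type, (src, dst), score in zip(edge_inds, edge_peak_inds, line_scores):
        # Skip non-finite or low-scoring candidates before using the indices.
        if not np.isfinite(score) or score < min_score:
            continue
        src, dst = int(src), int(dst)
        connections.append({
            "edge": int(edge_type),
            "src": src,
            "dst": dst,
            "src_node": int(peak_channel_inds[src]),  # node type of source peak
            "dst_node": int(peak_channel_inds[dst]),  # node type of destination peak
            "src_xy": tuple(peaks[src]),
            "dst_xy": tuple(peaks[dst]),
            "score": float(score),
        })
    return connections


# Toy frame: two animals, a two-node skeleton with a single edge type (0).
peaks = np.array([[10.0, 20.0], [12.0, 40.0], [50.0, 20.0], [52.0, 41.0]])
peak_channel_inds = np.array([0, 1, 0, 1])
edge_inds = np.array([0, 0, 0, 0])
edge_peak_inds = np.array([[0, 1], [0, 3], [2, 1], [2, 3]])
line_scores = np.array([0.9, 0.1, 0.2, 0.8])

example = candidate_connections(peaks, peak_channel_inds, edge_inds,
                                edge_peak_inds, line_scores, min_score=0.5)
for conn in example:
    print(conn)
```

With real output you would pass per-frame slices such as predictions["peaks"][0]; the PR does not say whether these arrays contain padding that needs trimming first, so check that before relying on the indices.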

codecov[bot] commented 1 year ago

Codecov Report

Merging #1329 (7cd9480) into develop (4525462) will decrease coverage by 0.06%. The diff coverage is 11.11%.

@@             Coverage Diff             @@
##           develop    #1329      +/-   ##
===========================================
- Coverage    72.66%   72.61%   -0.06%     
===========================================
  Files          133      133              
  Lines        23675    23692      +17     
===========================================
  Hits         17204    17204              
- Misses        6471     6488      +17     
Impacted Files              Coverage Δ
sleap/nn/paf_grouping.py    91.55% <0.00%> (ø)
sleap/nn/inference.py       80.96% <12.50%> (-0.36%) ⬇

... and 1 file with indirect coverage changes
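The report's totals are internally consistent: hits plus misses equals lines on both branches, and the coverage percentages equal hits divided by lines when truncated (not rounded) to two decimals. Rounding would give 72.67% and 72.62% instead. The truncation is an inference from these numbers, not documented Codecov behavior, and the helper truncated_percent below is hypothetical.

```python
import math


def truncated_percent(hits, lines, decimals=2):
    """Coverage as a percentage, truncated (not rounded) to `decimals` places."""
    scale = 10 ** decimals
    return math.floor(hits / lines * 100 * scale) / scale


# Totals copied from the coverage diff above.
develop = {"lines": 23675, "hits": 17204, "misses": 6471}
pr_1329 = {"lines": 23692, "hits": 17204, "misses": 6488}

for name, report in (("develop", develop), ("#1329", pr_1329)):
    # Every tracked line is either a hit or a miss.
    assert report["hits"] + report["misses"] == report["lines"]
    print(name, truncated_percent(report["hits"], report["lines"]))
```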
