dattalab / keypoint-moseq

https://keypoint-moseq.readthedocs.io

visualization steps fail after applying to new data #180

Open mshallow opened 3 days ago

mshallow commented 3 days ago

I trained the model on a small portion of the data (a couple hundred thousand frames, as suggested) and then applied it to the rest of the dataset. After doing this, `generate_trajectory_plots`, `generate_grid_movies` and `plot_similarity_dendrogram` all fail in the same way. When I run these functions on the full set of coordinates and the results loaded from the results.h5 file, they all raise "KeyError: 'syllable'", even though the results file contains the key 'syllable'. Attached is an example of what the results file looks like, the code I ran, and the errors.

For example, for the trajectory plots:

```python
model_name='2024_11_14-09_28_57'
results = kpms.load_results(project_dir, model_name)
kpms.generate_trajectory_plots(coordinates, results, project_dir, model_name, **config())
```

```
Saving trajectory plots to /Volumes/Projects/PreyCapture/ZIActivation/kpms/2024_11_14-09_28_57/trajectory_plots
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[8], line 3
      1 model_name='2024_11_14-09_28_57'
      2 results = kpms.load_results(project_dir, model_name)
----> 3 kpms.generate_trajectory_plots(coordinates, results, project_dir, model_name, **config())

File /opt/anaconda3/envs/keypoint_moseq/lib/python3.9/site-packages/keypoint_moseq/viz.py:1726, in generate_trajectory_plots(coordinates, results, project_dir, model_name, output_dir, pre, post, min_frequency, min_duration, skeleton, bodyparts, use_bodyparts, keypoint_colormap, plot_options, get_limits_pctl, padding, lims, save_individually, save_gifs, save_mp4s, fps, projection_planes, interactive, density_sample, sampling_options, **kwargs)
   1723     os.makedirs(output_dir)
   1724 print(f"Saving trajectory plots to {output_dir}")
-> 1726 typical_trajectories = get_typical_trajectories(
   1727     coordinates,
   1728     results,
   1729     pre,
   1730     post,
   1731     min_frequency,
   1732     min_duration,
   1733     bodyparts,
   1734     use_bodyparts,
   1735     density_sample,
   1736     sampling_options,
   1737 )
   1739 syllable_ixs = sorted(typical_trajectories.keys())
   1740 titles = [f"Syllable{s}" for s in syllable_ixs]

File /opt/anaconda3/envs/keypoint_moseq/lib/python3.9/site-packages/keypoint_moseq/util.py:1054, in get_typical_trajectories(coordinates, results, pre, post, min_frequency, min_duration, bodyparts, use_bodyparts, density_sample, sampling_options)
   1051 if bodyparts is not None and use_bodyparts is not None:
   1052     coordinates = reindex_by_bodyparts(coordinates, bodyparts, use_bodyparts)
-> 1054 syllables = {k: v["syllable"] for k, v in results.items()}
   1055 centroids = {k: v["centroid"] for k, v in results.items()}
   1056 headings = {k: v["heading"] for k, v in results.items()}

File /opt/anaconda3/envs/keypoint_moseq/lib/python3.9/site-packages/keypoint_moseq/util.py:1054, in <dictcomp>(.0)
   1051 if bodyparts is not None and use_bodyparts is not None:
   1052     coordinates = reindex_by_bodyparts(coordinates, bodyparts, use_bodyparts)
-> 1054 syllables = {k: v["syllable"] for k, v in results.items()}
   1055 centroids = {k: v["centroid"] for k, v in results.items()}
   1056 headings = {k: v["heading"] for k, v in results.items()}

KeyError: 'syllable'
```

The first entry in the results.h5 file looks like this:

```
{'Sky_mouse-0893_2022-07-11T08_28_13DLC_dlcrnetms5_optopreycapFeb16shuffle1_150000_el_mouse': {
    'centroid': array([[662.46379994, 918.74729083],
                       [662.4360982 , 918.74458729],
                       [662.45538483, 918.59319527],
                       ...,
                       [423.89075159, 810.7915252 ],
                       [423.87865671, 810.87515602],
                       [424.06456624, 810.84802222]]),
    'heading': array([ 8.46206572e-02, -1.88801896e-02,  2.44848130e-01, ...,
                       2.00854684e+01,  1.99765343e+01,  1.98199775e+01]),
    'latent_state': array([[ 16.67395697,   4.29372412,  -2.36620644,  -3.11100955],
                           [ 33.6296708 ,   1.76510193,  -6.2437582 ,   6.01336587],
                           [ 44.14367765,  13.04365683, -15.10881315,   5.05403441],
                           ...,
                           [ -0.09015983,   0.28800721,  -0.33345008,   0.47476777],
                           [ -0.1354536 ,   0.41153289,  -0.249936  ,   0.30653053],
                           [ -0.14849604,   0.475582  ,  -0.12138603,   0.25759183]]),
    'syllable': array([41, 41, 41, ..., 73, 73, 73])},
```

[Screenshot 2024-11-14 at 4 44 12 PM]

The only possible reason I can think of is that 'syllable' is a key in each individual mouse's dictionary, not in the top-level dictionary (which is a dictionary of dictionaries). The grid movies and similarity dendrogram fail with the same KeyError as the trajectory plots.
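As a sketch of what the traceback shows, using toy data (the session names and values below are made up): the comprehension on util.py line 1054 indexes each per-session dictionary, so the nested, dictionary-of-dictionaries layout is what the code expects. The KeyError is raised as soon as any single session lacks 'syllable'. `extract_syllables` is a hypothetical helper that mirrors that one line.

```python
# Toy stand-in for the dict returned by kpms.load_results:
# session name -> per-session dict. Names and values are invented for illustration.
results = {
    "session_A": {"centroid": [[0.0, 0.0]], "heading": [0.1],
                  "latent_state": [[0.0, 0.0]], "syllable": [41]},
    "session_B": {"centroid": [[1.0, 1.0]], "heading": [0.2]},  # no 'syllable'
}


def extract_syllables(results):
    # Same pattern as util.py line 1054: each inner (per-session) dict is
    # indexed, so the nesting itself is fine.
    return {k: v["syllable"] for k, v in results.items()}


try:
    extract_syllables(results)
except KeyError as e:
    # One incomplete session is enough to make the whole comprehension fail.
    print("KeyError:", e)

# Dropping the incomplete session lets the same comprehension succeed.
complete = {k: v for k, v in results.items() if "syllable" in v}
print(extract_syllables(complete))
```

So the thing to look for is a session whose entry is missing keys, not a problem with the top-level structure.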

calebweinreb commented 2 days ago

Hmm, that's weird. Can you check all the keys and list the ones where 'syllable' is missing? e.g.

for k,v in results.keys():
    if not 'syllable' in v:
        print(k)

Are the printed keys from the original batch or the new batch? Is there any pattern?

mshallow commented 2 days ago

I just tried those lines, and they raised a ValueError.

[Screenshot 2024-11-15 at 12 51 18 PM]
calebweinreb commented 2 days ago

Ah, sorry, that should be `results.items()`.
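Iterating over `results.keys()` yields bare session-name strings, so `for k, v in ...` tries to unpack each long name string into two values and raises a ValueError; `.items()` yields `(name, entry)` pairs instead. A slightly extended version of the corrected check, which also reports which keys each session lacks, might look like this. It is a sketch: `find_incomplete_sessions` is a hypothetical helper, and the expected key set is taken from the complete entry printed earlier in the thread.

```python
# Per-session keys expected in results.h5, taken from the complete entry printed
# earlier in this thread (an assumption; adjust if your results contain other keys).
EXPECTED_KEYS = {"centroid", "heading", "latent_state", "syllable"}


def find_incomplete_sessions(results, expected=EXPECTED_KEYS):
    """Map each session name to the sorted list of expected keys it is missing.

    `results` is assumed to be the session-keyed dict returned by
    kpms.load_results(project_dir, model_name).
    """
    incomplete = {}
    # .items() yields (session_name, per_session_dict) pairs, so unpacking works
    for name, entry in results.items():
        absent = expected - set(entry.keys())
        if absent:
            incomplete[name] = sorted(absent)
    return incomplete


# Usage with the real results:
# for name, absent in find_incomplete_sessions(results).items():
#     print(name, absent)
```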

mshallow commented 2 days ago

That prints a single file name. It's from the newly added data, not the original training dataset.

mshallow commented 2 days ago

For some reason, the entry for that file contains only centroids and headings, with no syllables or latent states.

[Screenshot 2024-11-15 at 1 11 01 PM]

But if I open the CSV that was saved for that file, it does contain syllables and latent states.

[Screenshot 2024-11-15 at 1 12 33 PM]
calebweinreb commented 2 days ago

Huh, that's really weird. You could load the syllables from the CSV, add them to the results, and then save the results (using `kpms.save_hdf5`), or just reapply the model to that one session. Hopefully that will fix the issue. It seems like a random fluke rather than a systematic issue.
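A minimal sketch of the first option (patching the in-memory results from the exported CSV), under stated assumptions: the CSV has one row per frame, syllables sit in a column named `syllable`, and latent-state dimensions sit in columns whose names start with `latent_state`. Those column names are assumptions, so check the header of your exported CSV and adjust the keyword arguments. `patch_session` and `patch_session_from_csv` are hypothetical helpers, not part of keypoint-moseq.

```python
import csv

import numpy as np


def patch_session(results, session, rows,
                  syllable_col="syllable", latent_prefix="latent_state"):
    """Fill in 'syllable' (and 'latent_state', if present) for one session.

    ASSUMPTION: `rows` are dicts with one row per frame (e.g. from csv.DictReader),
    syllables are in `syllable_col`, and latent dimensions are in columns whose
    names start with `latent_prefix`. Other columns are ignored.
    """
    rows = list(rows)
    if not rows:
        raise ValueError("no rows to read from")
    latent_cols = [c for c in rows[0] if c.startswith(latent_prefix)]

    entry = results[session]
    syllables = np.array([int(float(r[syllable_col])) for r in rows])
    # Refuse to patch if the CSV does not line up frame-for-frame with the stored centroids
    if "centroid" in entry and len(entry["centroid"]) != len(syllables):
        raise ValueError(
            f"{session}: CSV has {len(syllables)} rows but centroid has "
            f"{len(entry['centroid'])} frames")

    entry["syllable"] = syllables
    if latent_cols:
        entry["latent_state"] = np.array(
            [[float(r[c]) for c in latent_cols] for r in rows])
    return results


def patch_session_from_csv(results, session, csv_path, **kwargs):
    """Read the CSV exported for `session` and patch `results` in place."""
    with open(csv_path, newline="") as f:
        return patch_session(results, session, csv.DictReader(f), **kwargs)
```

After patching, the dict can be written back with `kpms.save_hdf5` as suggested above. Its argument order is worth confirming with `help(kpms.save_hdf5)` for the installed version, and writing to a copy of results.h5 first avoids clobbering the original.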

mshallow commented 2 days ago

I'll start by reapplying the model to just that session and see if that fixes it.
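Restricting the reapply step to one session could start by subsetting the session-keyed inputs, then merging the new entry back into the full results. This is a sketch assuming `coordinates` (and `confidences`, if used) are dicts keyed by session name, like `results`; both helper names are hypothetical, and the formatting and apply calls themselves are not shown because they follow the same steps used for the rest of the data.

```python
def subset_sessions(data, keep):
    """Return a new dict containing only the sessions named in `keep`.

    Assumes `data` is keyed by session name, like `results` (and, presumably,
    the coordinates/confidences dicts loaded for the project).
    """
    keep = set(keep)
    unknown = keep - set(data)
    if unknown:
        raise KeyError(f"sessions not found: {sorted(unknown)}")
    return {k: v for k, v in data.items() if k in keep}


def merge_session_results(results, new_results):
    """Return a copy of `results` where same-named sessions are replaced by `new_results`."""
    merged = dict(results)
    merged.update(new_results)
    return merged


# Usage sketch (`session` is the file name printed by the check earlier in the thread):
# one_coords = subset_sessions(coordinates, [session])
```

It is worth backing up results.h5 before reapplying, since this sketch does not check how the apply step writes to an existing results file.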