dattalab / keypoint-moseq

https://keypoint-moseq.readthedocs.io

get_durations(z, mask) encountering IndexError when running kappa scan but not when training AR-HMM or full model #174

Open amorsi1 opened 1 month ago

amorsi1 commented 1 month ago

Hi! I'm encountering a problem during my kappa scan when it tries to plot the median duration of syllables.

 ---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[15], line 17
     15 # stage 1: fit the model with AR only
     16 model = kpms.update_hypparams(model, kappa=kappa)
---> 17 model = kpms.fit_model(
     18     model,
     19     data,
     20     metadata,
     21     project_dir,
     22     model_name,
     23     ar_only=True,
     24     num_iters=num_ar_iters,
     25     save_every_n_iters=25,
     26     parallel_message_passing=False
     27 )[0];
     29 # stage 2: fit the full model
     30 model = kpms.update_hypparams(model, kappa=kappa/decrease_kappa_factor)

File ~/miniforge3/envs/keypoint_moseq/lib/python3.9/site-packages/keypoint_moseq/fitting.py:272, in fit_model(model, data, metadata, project_dir, model_name, num_iters, start_iter, verbose, ar_only, parallel_message_passing, jitter, generate_progress_plots, save_every_n_iters, location_aware, **kwargs)
    270                 save_hdf5(checkpoint_path, model, f"model_snapshots/{iteration}")
    271                 if generate_progress_plots:
--> 272                     plot_progress(
    273                         model,
    274                         data,
    275                         checkpoint_path,
    276                         iteration,
    277                         project_dir,
    278                         model_name,
    279                         savefig=True,
    280                     )
    282 return model, model_name

File ~/miniforge3/envs/keypoint_moseq/lib/python3.9/site-packages/keypoint_moseq/viz.py:620, in plot_progress(model, data, checkpoint_path, iteration, project_dir, model_name, path, savefig, fig_size, window_size, min_frequency, min_histogram_length)
    618         z = np.array(f[f"model_snapshots/{i}/states/z"])
    619         sample_state_history.append(z[batch_ix, start : start + window_size])
--> 620         median_durations.append(np.median(get_durations(z, mask)))
    622 axs[2].scatter(saved_iterations, median_durations)
    623 axs[2].set_ylim([-1, np.max(median_durations) * 1.1])

File ~/miniforge3/envs/keypoint_moseq/lib/python3.9/site-packages/jax_moseq/utils/utils.py:82, in get_durations(stateseqs, mask)
     80 print(mask)
     81 #AM edits
---> 82 stateseq_flat = concatenate_stateseqs(stateseqs, mask=mask).astype(int)
     83 stateseq_padded = np.hstack([[-1], stateseq_flat, [-1]])
     84 changepoints = np.diff(stateseq_padded).nonzero()[0]

File ~/miniforge3/envs/keypoint_moseq/lib/python3.9/site-packages/jax_moseq/utils/utils.py:40, in concatenate_stateseqs(stateseqs, mask)
     38     stateseq_flat = np.hstack(stateseqs)
     39 elif mask is not None:
---> 40     stateseq_flat = stateseqs[mask[:, -stateseqs.shape[1] :] > 0]
     41 else:
     42     stateseq_flat = stateseqs.flatten()

IndexError: boolean index did not match indexed array along dimension 0; dimension is 102 but corresponding boolean dimension is 116

This is on keypoint_moseq version 0.4.10 and jax_moseq version 0.2.2. I put some print statements into the get_durations function in jax_moseq's utils (since debugging in notebooks can be cumbersome) and let the run crash. The stateseqs and mask shapes (labeled "len" in the output) consistently looked like this:

stateseqs len (116, 10027)
[[ 3  3  3 ... 79 79 46]
 [95 95  5 ... 90 90 90]
 [90 90 90 ... 58 58 58]
 ...
 [90 90 90 ... 54 54 54]
 [36 36 36 ... 90 90 90]
 [90 90 90 ... 67 90 90]]
mask len (116, 10030)
[[1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 0. 0. 0.]
 ...
 [1. 1. 1. ... 0. 0. 0.]
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 0. 0. 0.]]
stateseqs len (102, 10027)
[[72 72 96 ... 91 91 91]
 [72 72 72 ... 52 52 52]
 [52 52 52 ... 23 23 23]
 ...
 [23 23 23 ... 77 77 77]
 [23 23 23 ... 12 12 12]
 [12 12 12 ... 87 87 87]]
mask len (116, 10030)
[[1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 0. 0. 0.]
 ...
 [1. 1. 1. ... 0. 0. 0.]
 [1. 1. 1. ... 1. 1. 1.]
 [1. 1. 1. ... 0. 0. 0.]]
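
The failing indexing expression can be reproduced in isolation. The sketch below is only an illustration with small made-up arrays, not the real data: `masked_flatten` is a hypothetical helper that copies the expression from line 40 of `jax_moseq/utils/utils.py` as shown in the traceback, and the shapes (4, 10), (4, 7) and (3, 7) are assumptions chosen to mimic the column and row mismatches seen above.

```python
import numpy as np


def masked_flatten(stateseqs, mask):
    # Hypothetical helper: same expression as the line the traceback points at.
    # The mask is trimmed to its last stateseqs.shape[1] columns before indexing.
    return stateseqs[mask[:, -stateseqs.shape[1]:] > 0]


# Made-up mask: 4 sequences, 10 frames, with the tail of one row masked out.
mask = np.ones((4, 10))
mask[2, 7:] = 0

# Case 1: same number of rows, fewer columns than the mask. This works.
flat = masked_flatten(np.arange(28).reshape(4, 7), mask)

# Case 2: fewer rows than the mask. This raises IndexError.
try:
    masked_flatten(np.arange(21).reshape(3, 7), mask)
    row_error = None
except IndexError as err:
    row_error = str(err)

print(flat.shape)
print(row_error)
```

The first call succeeds because the column slice brings the mask down to the width of the state sequences; nothing comparable happens along the first axis, so the second call fails in the same way as the traceback.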

Keep in mind that only the first dimension of the shape contributes to this mismatch: the function trims the mask to its last stateseqs.shape[1] columns, so it can handle the difference in the second dimension (10030 vs. 10027). I don't encounter this issue when training the AR-HMM or the full model, so as a workaround I ran the AR-HMM cell a few times with different kappa values to decide on a good one and then trained the full model. I wanted to open this issue in case anyone else runs into it, or in case Caleb has any thoughts on why this is happening.
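
A check along these lines could help narrow down which saved snapshots disagree with the mask. It is a rough sketch under assumptions: `find_mismatched_snapshots` is a hypothetical helper, the snapshots are passed in as a plain dict of arrays (in the real checkpoint they would come from `model_snapshots/{i}/states/z`, as in the traceback, and loading them is left out here), and the iteration keys "0" and "25" are made up. Only the array shapes are taken from the printed output above.

```python
import numpy as np


def find_mismatched_snapshots(snapshots, mask):
    # Hypothetical helper: report snapshots whose number of rows
    # differs from the number of rows in the mask.
    n_rows = mask.shape[0]
    return {
        iteration: z.shape
        for iteration, z in snapshots.items()
        if z.shape[0] != n_rows
    }


# Shapes copied from the printed output; values and iteration keys are placeholders.
mask = np.ones((116, 10030))
snapshots = {
    "0": np.zeros((116, 10027), dtype=int),
    "25": np.zeros((102, 10027), dtype=int),
}

mismatched = find_mismatched_snapshots(snapshots, mask)
print(mismatched)
```

Running something like this over every snapshot in the checkpoint would show whether the 102-row arrays belong to particular iterations or appear throughout.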