dattalab / keypoint-moseq

https://keypoint-moseq.readthedocs.io

Problems with Initializing Model #139

Closed: afbrokaw closed this issue 4 months ago

afbrokaw commented 4 months ago

Hi there, I am trying to run Keypoint-MoSeq on my own keypoint data from DeepLabCut. I am using a compute node at the University of Colorado, running JupyterLab in a keypoint-moseq environment.

After using a subset of my own data as a personal demo and successfully fitting a model, I am running into problems at the "Initialize model" stage with a new set of data. I also tried to reproduce the successful fit on the original data and got the same error (with no changes to the original data files).

The error is `ValueError: Size of label 'j' for operand 1 (31) does not match previous terms (19).` Here is the full traceback:

```
Traceback (most recent call last)
Cell In[6], line 2
      1 # initialize the model
----> 2 model = kpms.init_model(data, pca=pca, **config())

File /projects/albr4711/software/anaconda/envs/keypoint_moseq/lib/python3.9/site-packages/keypoint_moseq/fitting.py:104, in init_model(location_aware, allo_hypparams, trans_hypparams, *args, **kwargs)
     97         return allo_keypoint_slds.init_model(
     98             *args,
     99             allo_hypparams=allo_hypparams,
    100             trans_hypparams=trans_hypparams,
    101             **kwargs,
    102         )
    103     else:
--> 104         return keypoint_slds.init_model(
    105             *args, trans_hypparams=trans_hypparams, **kwargs
    106         )

File /projects/albr4711/software/anaconda/envs/keypoint_moseq/lib/python3.9/site-packages/jax_moseq/models/keypoint_slds/initialize.py:332, in init_model(data, states, params, hypparams, noise_prior, seed, pca, whiten, PCA_fitting_num_frames, anterior_idxs, posterior_idxs, conf_threshold, error_estimator, trans_hypparams, ar_hypparams, obs_hypparams, cen_hypparams, verbose, exclude_outliers_for_pca, fix_heading, **kwargs)
    330             print("Keypoint SLDS: Initializing states")
    331         obs_hypparams = hypparams["obs_hypparams"]
--> 332         states = init_states(
    333             seed,
    334             Y,
    335             mask,
    336             params,
    337             noise_prior,
    338             obs_hypparams,
    339             Y_flat,
    340             v,
    341             h,
    342             fix_heading,
    343         )
    344     else:
    345         states = jax.device_put(states)

File /projects/albr4711/software/anaconda/envs/keypoint_moseq/lib/python3.9/site-packages/jax_moseq/models/keypoint_slds/initialize.py:66, in init_states(seed, Y, mask, params, noise_prior, obs_hypparams, Y_flat, v, h, fix_heading, **kwargs)
     63         Y_flat, v, h = preprocess_for_pca(Y, fix_heading, **kwargs)
     65     x = slds.init_continuous_stateseqs(Y_flat, params["Cd"])
---> 66     states = arhmm.init_states(seed, x, mask, params)
     68     states["x"] = x
     69     states["v"] = v

File /projects/albr4711/software/anaconda/envs/keypoint_moseq/lib/python3.9/site-packages/jax_moseq/models/arhmm/initialize.py:72, in init_states(seed, x, mask, params, **kwargs)
     49 def init_states(seed, x, mask, params, **kwargs):
     50     """
     51     Initialize the latent states of the ARHMM from the
     52     data and parameters.
   (...)
     70         State values for each latent variable.
     71     """
---> 72     z = resample_discrete_stateseqs(seed, x, mask, **params)
     73     return {"z": z}

    [... skipping hidden 14 frame]

File /projects/albr4711/software/anaconda/envs/keypoint_moseq/lib/python3.9/site-packages/jax_moseq/models/arhmm/gibbs.py:55, in resample_discrete_stateseqs(seed, x, mask, Ab, Q, pi, **kwargs)
     29     """
     30     Resamples the discrete state sequence z.
     31 
   (...)
     52         Discrete state sequences.
     53     """
     54     nlags = get_nlags(Ab)
---> 55     log_likelihoods = jax.lax.map(partial(ar_log_likelihood, x), (Ab, Q))
     56     _, z = jax.vmap(sample_hmm_stateseq, in_axes=(0, na, 0, 0))(
     57         jr.split(seed, mask.shape[0]),
     58         pi,
     59         jnp.moveaxis(log_likelihoods, 0, -1),
     60         mask.astype(float)[:, nlags:],
     61     )
     62     return z

    [... skipping hidden 12 frame]

File /projects/albr4711/software/anaconda/envs/keypoint_moseq/lib/python3.9/site-packages/jax_moseq/utils/autoregression.py:22, in ar_log_likelihood(x, params)
     20     Ab, Q = params
     21     nlags = get_nlags(Ab)
---> 22     mu = apply_ar_params(x, Ab)
     23     x = x[..., nlags:, :]
     24     return tfd.MultivariateNormalFullCovariance(mu, Q).log_prob(x)

File /projects/albr4711/software/anaconda/envs/keypoint_moseq/lib/python3.9/site-packages/jax_moseq/utils/autoregression.py:16, in apply_ar_params(x, Ab)
     14     nlags = get_nlags(Ab)
     15     x_in = get_lags(x, nlags)
---> 16     return apply_affine(x_in, Ab)

File /projects/albr4711/software/anaconda/envs/keypoint_moseq/lib/python3.9/site-packages/jax_moseq/utils/utils.py:199, in apply_affine(x, Ab)
    198 def apply_affine(x, Ab):
--> 199     return jnp.einsum("...ij, ...j->...i", Ab, pad_affine(x))

File /projects/albr4711/software/anaconda/envs/keypoint_moseq/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:2923, in einsum(out, optimize, precision, _use_xeinsum, *operands)
   2921   ty = next(iter(non_constant_dim_types))
   2922   contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler)
-> 2923   operands, contractions = contract_path(
   2924         *operands, einsum_call=True, use_blas=True, optimize=optimize)
   2926   contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)
   2928   _einsum_computation = jax.named_call(
   2929       _einsum, name=spec) if spec is not None else _einsum

File /projects/albr4711/software/anaconda/envs/keypoint_moseq/lib/python3.9/site-packages/opt_einsum/contract.py:238, in contract_path(*operands, **kwargs)
    236         size_dict[char] = dim
    237     elif dim not in (1, size_dict[char]):
--> 238         raise ValueError("Size of label '{}' for operand {} ({}) does not match previous "
    239                          "terms ({}).".format(char, tnum, size_dict[char], dim))
    240     else:
    241         size_dict[char] = dim

ValueError: Size of label 'j' for operand 1 (31) does not match previous terms (19).
```

I have limited experience with Python. I have tried running the tutorial code on a few different datasets, including the data I had successfully modeled before, and I keep getting the same error. Any advice would be greatly appreciated!

Cheers, Alyson

afbrokaw commented 4 months ago

I just saw a closed issue about a similar problem, but it never shared a final solution. Here is the additional information that was requested in that issue. For reference, I installed and started using Keypoint-MoSeq within the last week (around Feb 26th, 2024).

```python
import keypoint_moseq, jax_moseq
print(keypoint_moseq.__version__, jax_moseq.__version__)
# 0.4.4 0.2.1
```

```python
print(kpms.load_config(project_dir))
```

```
{'bodyparts': ['odorspot', 'controlspot', 'nose', 'neck', 'bodycenter', 'tailbase'],
 'use_bodyparts': ['nose', 'neck', 'bodycenter', 'tailbase'],
 'skeleton': [['nose', 'neck'], ['neck', 'bodycenter'], ['bodycenter', 'tailbase']],
 'anterior_bodyparts': ['nose'],
 'posterior_bodyparts': ['tailbase'],
 'added_noise_level': 0.1,
 'PCA_fitting_num_frames': 1000000,
 'conf_threshold': 0.5,
 'error_estimator': {'intercept': 0.25, 'slope': -0.5},
 'obs_hypparams': {'nu_s': 5, 'nu_sigma': 100000.0, 'sigmasq_0': 0.1, 'sigmasq_C': 0.1},
 'ar_hypparams': {'K_0_scale': 10.0, 'S_0_scale': 0.01, 'latent_dim': 10, 'nlags': 3},
 'trans_hypparams': {'alpha': 5.7, 'gamma': 1000.0, 'kappa': 1000000.0, 'num_states': 100},
 'cen_hypparams': {'sigmasq_loc': 0.5},
 'recording_name_suffix': '',
 'verbose': False,
 'conf_pseudocount': 0.001,
 'video_dir': '/projects/albr4711/Keypoint_Oct17/videos/',
 'keypoint_colormap': 'autumn',
 'whiten': True,
 'fix_heading': False,
 'seg_length': 10000,
 'anterior_idxs': DeviceArray([0], dtype=int64),
 'posterior_idxs': DeviceArray([3], dtype=int64)}
```
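The two sizes in the error can be cross-checked against this config. In the `raise` at `contract.py:238`, the format arguments are `(char, tnum, size_dict[char], dim)`, so 31 is the size already recorded for `j` from operand 0 (`Ab`) and 19 is the size of operand 1 (`pad_affine(x)`). The sketch below assumes, based only on the helper names `get_lags` and `pad_affine` in the traceback, that the AR input stacks `nlags` lagged copies of the latent state and appends one constant column, giving a width of `latent_dim * nlags + 1`. The helper functions are hypothetical and not part of keypoint-moseq or jax_moseq.

```python
def ar_input_width(latent_dim: int, nlags: int) -> int:
    """Hypothetical helper: width of the lagged, affine-padded AR input,
    assuming nlags stacked copies of the latent state plus one bias column."""
    return latent_dim * nlags + 1


def latent_dim_from_width(width: int, nlags: int) -> int:
    """Hypothetical helper: invert ar_input_width under the same assumption."""
    dim, remainder = divmod(width - 1, nlags)
    if remainder:
        raise ValueError(f"width {width} is not of the form latent_dim * {nlags} + 1")
    return dim


nlags = 3  # from 'ar_hypparams' in the config above

# Size of 'j' for Ab, built from latent_dim=10 in the config
print(ar_input_width(10, nlags))
# Latent dimension implied by the size of 'j' for pad_affine(x)
print(latent_dim_from_width(19, nlags))
```

Under that assumption, 31 corresponds to `latent_dim=10` with `nlags=3`, while 19 corresponds to a latent trajectory with only 6 dimensions. In other words, the AR parameters were sized for more latent dimensions than the trajectory passed to them actually had.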

calebweinreb commented 4 months ago

Hi,

The problem is probably that `latent_dim=10` in your config, but you are only using 4 bodyparts (quite a small number, by the way), so the `pca` object has at most 8 dimensions. If you set `latent_dim` to 8 or less, I think this error will go away. We'll add a more interpretable error message in the next release.
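Until a clearer message is built in, a pre-flight check along these lines could catch the mismatch before `kpms.init_model` is called. This is a minimal sketch, not keypoint-moseq's own error handling: `check_latent_dim` and `safe_latent_dim` are hypothetical helpers. In the workflow comment, `pca.n_components_` assumes `pca` is a fitted scikit-learn `PCA` object, and `kpms.update_config(project_dir, latent_dim=...)` follows the pattern used in the keypoint-moseq tutorial. Checking against the fitted PCA, rather than a bound computed by hand from the bodypart count, compares `latent_dim` with the number of components actually available.

```python
def check_latent_dim(latent_dim: int, n_pca_components: int) -> None:
    """Hypothetical pre-flight check: fail early with a readable message
    instead of the opaque einsum shape error seen in the traceback."""
    if latent_dim > n_pca_components:
        raise ValueError(
            f"latent_dim={latent_dim} is larger than the {n_pca_components} "
            f"PCA components available; set latent_dim <= {n_pca_components}."
        )


def safe_latent_dim(latent_dim: int, n_pca_components: int) -> int:
    """Hypothetical helper: clamp the requested latent_dim to the PCA size."""
    return min(latent_dim, n_pca_components)


# Numbers from this thread: latent_dim=10 requested, at most 8 dimensions
# available from 4 bodyparts x 2 coordinates.
try:
    check_latent_dim(10, 8)
except ValueError as err:
    print(err)

# In the tutorial workflow this might look like (not run here; assumes
# `kpms`, `project_dir`, `config`, `data`, and a fitted scikit-learn `pca`):
#
#   n = pca.n_components_
#   latent_dim = config()["ar_hypparams"]["latent_dim"]
#   if latent_dim > n:
#       kpms.update_config(project_dir, latent_dim=safe_latent_dim(latent_dim, n))
#   model = kpms.init_model(data, pca=pca, **config())
```

Clamping is only a safety net. In practice you would choose `latent_dim` from the variance explained by the PCA rather than simply taking the maximum.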

afbrokaw commented 4 months ago

Okay, thanks! I will try that, and I appreciate the insight. And yes, I'm aware that at least 5 body parts are recommended; as I said, this is just my own personal demo for exploring the software and data-processing pipelines.