dattalab / keypoint-moseq

https://keypoint-moseq.readthedocs.io

Error message during model initiation #101

Closed Aexolowski closed 10 months ago

Aexolowski commented 1 year ago

Hi there, I am running into this error message when trying to initialize the model: 'ValueError: Size of label 'j' for operand 1 (31) does not match previous terms (19).' Can you help me figure out what the issue is? Thanks!

calebweinreb commented 1 year ago

Please paste the full traceback

Aexolowski commented 1 year ago

ValueError                                Traceback (most recent call last)
Cell In[15], line 3
      1 # 1) initialize the model
----> 3 model = kpms.init_model(data, pca=pca, **config())
      5 # optionally modify kappa
      6 # model = kpms.update_hypparams(model, kappa=NUMBER)

File ~\AppData\Roaming\Python\Python39\site-packages\jax_moseq\models\keypoint_slds\initialize.py:294, in init_model(data, states, params, hypparams, noise_prior, seed, pca, whiten, PCA_fitting_num_frames, anterior_idxs, posterior_idxs, conf_threshold, error_estimator, trans_hypparams, ar_hypparams, obs_hypparams, cen_hypparams, verbose, exclude_outliers_for_pca, fix_heading, **kwargs)
    292     print('Keypoint SLDS: Initializing states')
    293     obs_hypparams = hypparams['obs_hypparams']
--> 294     states = init_states(seed, Y, mask, params, noise_prior,
    295                          obs_hypparams, Y_flat, v, h, fix_heading)
    296 else:
    297     states = jax.device_put(states)

File ~\AppData\Roaming\Python\Python39\site-packages\jax_moseq\models\keypoint_slds\initialize.py:55, in init_states(seed, Y, mask, params, noise_prior, obs_hypparams, Y_flat, v, h, fix_heading, **kwargs)
     52 Y_flat, v, h = preprocess_for_pca(Y, fix_heading, **kwargs)
     54 x = slds.init_continuous_stateseqs(Y_flat, params['Cd'])
---> 55 states = arhmm.init_states(seed, x, mask, params)
     57 states['x'] = x
     58 states['v'] = v

File ~\AppData\Roaming\Python\Python39\site-packages\jax_moseq\models\arhmm\initialize.py:73, in init_states(seed, x, mask, params, **kwargs)
     50 def init_states(seed, x, mask, params, **kwargs):
     51     """
     52     Initialize the latent states of the ARHMM from the
     53     data and parameters.
   (...)
     71     State values for each latent variable.
     72     """
---> 73     z = resample_discrete_stateseqs(seed, x, mask, **params)
     74     return {'z': z}

[... skipping hidden 14 frame]

File ~\AppData\Roaming\Python\Python39\site-packages\jax_moseq\models\arhmm\gibbs.py:52, in resample_discrete_stateseqs(seed, x, mask, Ab, Q, pi, **kwargs)
     49 nlags = get_nlags(Ab)
     50 num_samples = mask.shape[0]
---> 52 log_likelihoods = jax.lax.map(partial(ar_log_likelihood, x), (Ab, Q))
     53 _, z = jax.vmap(sample_hmm_stateseq, in_axes=(0,na,0,0))(
     54     jr.split(seed, num_samples),
     55     pi,
     56     jnp.moveaxis(log_likelihoods,0,-1),
     57     mask.astype(float)[:,nlags:])
     58 return z

[... skipping hidden 12 frame]

File ~\AppData\Roaming\Python\Python39\site-packages\jax_moseq\utils\autoregression.py:22, in ar_log_likelihood(x, params)
     20 Ab, Q = params
     21 nlags = get_nlags(Ab)
---> 22 mu = apply_ar_params(x, Ab)
     23 x = x[..., nlags:, :]
     24 return tfd.MultivariateNormalFullCovariance(mu, Q).log_prob(x)

File ~\AppData\Roaming\Python\Python39\site-packages\jax_moseq\utils\autoregression.py:16, in apply_ar_params(x, Ab)
     14 nlags = get_nlags(Ab)
     15 x_in = get_lags(x, nlags)
---> 16 return apply_affine(x_in, Ab)

File ~\AppData\Roaming\Python\Python39\site-packages\jax_moseq\utils\utils.py:76, in apply_affine(x, Ab)
     75 def apply_affine(x, Ab):
---> 76     return jnp.einsum('...ij, ...j->...i', Ab, pad_affine(x))

File ~\AppData\Roaming\Python\Python39\site-packages\jax\_src\numpy\lax_numpy.py:2923, in einsum(out, optimize, precision, _use_xeinsum, *operands)
   2921 ty = next(iter(non_constant_dim_types))
   2922 contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler)
-> 2923 operands, contractions = contract_path(
   2924     *operands, einsum_call=True, use_blas=True, optimize=optimize)
   2926 contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)
   2928 _einsum_computation = jax.named_call(
   2929     _einsum, name=spec) if spec is not None else _einsum

File ~\AppData\Roaming\Python\Python39\site-packages\opt_einsum\contract.py:238, in contract_path(*operands, **kwargs)
    236     size_dict[char] = dim
    237 elif dim not in (1, size_dict[char]):
--> 238     raise ValueError("Size of label '{}' for operand {} ({}) does not match previous "
    239                      "terms ({}).".format(char, tnum, size_dict[char], dim))
    240 else:
    241     size_dict[char] = dim

ValueError: Size of label 'j' for operand 1 (31) does not match previous terms (19).
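
For anyone reading this traceback later, here is a rough way to interpret the two numbers. Judging by the `.format(char, tnum, size_dict[char], dim)` call in the opt_einsum frame, 31 appears to be the size already recorded for label 'j' (the last axis of `Ab`) and 19 the size found on operand 1 (`pad_affine(x)`). The sketch below assumes, without confirmation from the thread, that this axis holds `nlags` lagged copies of a `latent_dim`-sized state plus one constant column, so its size is `latent_dim * nlags + 1`. The helper name `candidate_shapes` is made up for illustration.

```python
def candidate_shapes(j_size):
    """List every (latent_dim, nlags) pair with latent_dim * nlags + 1 == j_size.

    Assumption (not confirmed in the thread): the contracted axis 'j' holds
    nlags lagged copies of the latent state plus one constant column.
    """
    total = j_size - 1  # drop the assumed constant (affine) column
    return [(d, total // d) for d in range(1, total + 1) if total % d == 0]


# Sizes taken from the error message above.
ab_size = 31  # read here as the last axis of Ab (operand 0)
x_size = 19   # read here as the last axis of pad_affine(x) (operand 1)

ab_pairs = candidate_shapes(ab_size)
x_pairs = candidate_shapes(x_size)

# Show the combinations where both operands share nlags and differ only
# in latent_dim, which is one plausible way such a mismatch could arise.
for d_ab, lags_ab in ab_pairs:
    for d_x, lags_x in x_pairs:
        if lags_ab == lags_x:
            print(f"nlags={lags_ab}: Ab expects latent_dim={d_ab}, "
                  f"x has latent_dim={d_x}")
```

Under that assumption, one consistent reading is `nlags=3` with `Ab` built for a 10-dimensional latent state (10 * 3 + 1 = 31) while `x` has 6 dimensions (6 * 3 + 1 = 19). Other factorizations fit the numbers too, so this only narrows down where to look.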

calebweinreb commented 1 year ago

Hmm weird. What version of the code do you have?

import keypoint_moseq, jax_moseq
print(keypoint_moseq.__version__, jax_moseq.__version__)
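
If either package in an older install does not expose `__version__`, the installed versions can usually be read from package metadata instead. This is a minimal sketch using only the standard library; the distribution names `keypoint-moseq` and `jax-moseq` are assumptions and may need adjusting for your install, and `installed_version` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# Assumed distribution names; change them if pip lists the packages differently.
for name in ("keypoint-moseq", "jax-moseq"):
    print(name, installed_version(name))
```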

Also can you paste the content of your config?

print(kpms.load_config(project_dir))

Aexolowski commented 1 year ago

I think I am using a slightly older version of the code; I installed it back in March this year.

calebweinreb commented 1 year ago

Does the error still happen if you update the code? (Note that you might have to change the notebook in a few places to match the updated code; see the modeling tutorial in the docs.)