dattalab / keypoint-moseq

https://keypoint-moseq.readthedocs.io

reshaping problem in Gibbs sampler for n_lag=1 #168

Closed: heikestein closed this issue 2 months ago.

heikestein commented 2 months ago

When running a full MoSeq model with `n_lag=1`, Gibbs sampling of the latent states `x` fails with a shape-mismatch error in this concatenation:

```python
# =========================================================================
# 5. Reformat sampled trajectories back into L'th order AR dynamics in R^D
# =========================================================================
x = jnp.concatenate(
    [
        x[:, 0, : (n_lags - 1) * latent_dim].reshape(-1, n_lags - 1, latent_dim),
        x[:, :, -latent_dim:],
    ],
    axis=1,
)
```
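With `n_lags = 1`, the slice `x[:, 0, : (n_lags - 1) * latent_dim]` has zero width, so the `-1` in `reshape(-1, n_lags - 1, latent_dim)` must be inferred from an array of size 0 whose other target dimensions also multiply to 0. The batch size cannot be recovered that way. A minimal sketch of one workaround is to pass the batch size explicitly. It assumes `x` has shape `(batch, T - n_lags + 1, n_lags * latent_dim)`. `reformat_lagged_states` is a hypothetical helper, not the jax-moseq function, and this is not necessarily how the linked PR fixes it. The sketch uses NumPy so it runs standalone. `jax.numpy.concatenate` and `.reshape` take the same arguments.

```python
import numpy as np


def reformat_lagged_states(x, n_lags, latent_dim):
    """Unstack lifted AR states back into per-timestep latents.

    Hypothetical helper for illustration. Assumes x has shape
    (batch, T - n_lags + 1, n_lags * latent_dim): the first
    (n_lags - 1) * latent_dim entries at t=0 hold the initial lags, and the
    last latent_dim entries at every step hold the newest state.
    """
    batch = x.shape[0]
    # Using the explicit batch size instead of -1 keeps the reshape
    # well-defined when n_lags == 1 and this slice is empty (width 0).
    initial = x[:, 0, : (n_lags - 1) * latent_dim].reshape(
        batch, n_lags - 1, latent_dim
    )
    newest = x[:, :, -latent_dim:]
    # (batch, n_lags - 1, D) + (batch, T - n_lags + 1, D) -> (batch, T, D)
    return np.concatenate([initial, newest], axis=1)


rng = np.random.default_rng(0)
batch, T, latent_dim = 2, 10, 3
for n_lags in (1, 3):
    x = rng.normal(size=(batch, T - n_lags + 1, n_lags * latent_dim))
    print(n_lags, reformat_lagged_states(x, n_lags, latent_dim).shape)  # (2, 10, 3)

# For comparison, NumPy rejects the original "-1" form outright for
# n_lags == 1 (the issue reports the JAX code failing at the concatenation).
try:
    np.zeros((batch, 0)).reshape(-1, 0, latent_dim)
except ValueError as err:
    print("reshape(-1, 0, 3) failed:", err)
```

For `n_lags = 1` the empty `initial` block contributes nothing, so the output is just the sampled states.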
calebweinreb commented 2 months ago

Thanks for the heads up. This is addressed here https://github.com/dattalab/jax-moseq/pull/36