Closed: heikestein closed this issue 2 months ago.
When running a full MoSeq model with a single AR lag (`n_lags=1`), Gibbs sampling of the latents `x` fails with a shape-mismatch error in the following concatenation:
```python
# =========================================================================
# 5. Reformat sampled trajectories back into L'th order AR dynamics in R^D
# =========================================================================
x = jnp.concatenate(
    [
        x[:, 0, : (n_lags - 1) * latent_dim].reshape(-1, n_lags - 1, latent_dim),
        x[:, :, -latent_dim:],
    ],
    axis=1,
)
```
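With `n_lags=1` the first slice, `x[:, 0, :0]`, is empty, and reshaping a zero-size array with a `-1` axis leaves that axis undetermined, so the "history" piece cannot be built with a shape that concatenates cleanly. The sketch below is one possible guard, not necessarily what the linked PR does. It assumes `x` has shape `(batch, T, n_lags * latent_dim)` with the current frame in the last `latent_dim` entries of each row; the helper name `unstack_lags` is hypothetical.

```python
import jax.numpy as jnp


def unstack_lags(x, n_lags, latent_dim):
    """Hypothetical helper: turn stacked AR states back into per-frame latents.

    Assumes x has shape (batch, T, n_lags * latent_dim), where the last
    latent_dim entries of each row hold the current frame and the leading
    (n_lags - 1) * latent_dim entries of row 0 hold the preceding frames.
    """
    # Current frame at every time step: (batch, T, latent_dim)
    current = x[:, :, -latent_dim:]
    if n_lags == 1:
        # No history frames exist, so skip the empty slice and its
        # ambiguous -1 reshape entirely.
        return current
    # History frames from the first time step: (batch, n_lags - 1, latent_dim).
    # An explicit batch size is used instead of -1 so the shape is never inferred.
    history = x[:, 0, : (n_lags - 1) * latent_dim].reshape(
        x.shape[0], n_lags - 1, latent_dim
    )
    # Result: (batch, n_lags - 1 + T, latent_dim)
    return jnp.concatenate([history, current], axis=1)
```

For `n_lags=1` the helper returns only the per-frame slice, so no empty array ever reaches `jnp.concatenate`; for larger lags it behaves like the original expression.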
Thanks for the heads up. This is addressed here https://github.com/dattalab/jax-moseq/pull/36