probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

NaNs returned by `lgssm_posterior_sample` #320

Closed calebweinreb closed 1 year ago

calebweinreb commented 1 year ago

We have been trying to incorporate dynamax into jax-moseq, a tool for unsupervised analysis of animal behavior. Specifically, we would like to replace our custom Kalman sampling code with the `lgssm_posterior_sample` function in dynamax. @ezhang94 has already done all the heavy lifting and tested it on some small-scale examples. However, we are still getting all-NaN outputs for more realistically sized datasets.

The problem appears to be solvable by adding a small constant to the diagonal of the posterior covariance at each backward sampling step. Below is a brief recipe to reproduce the issue, along with a diagnosis of where the NaNs first appear.
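To illustrate the idea, here is a minimal sketch of one backward step of forward-filter backward-sample (FFBS) written in plain JAX. It is not dynamax's implementation: the function name `backward_sample_step`, its argument layout, the default `jitter=1e-6`, and the re-symmetrization step are all assumptions made for illustration. The step conditions the filtered belief at time t on a sampled next state, adds a small constant to the diagonal of the conditional covariance, and then draws a sample through its Cholesky factor. The second half shows why the fix helps: in JAX, the Cholesky factorization of a matrix that is not positive definite (for example, one pushed slightly indefinite by float32 round-off) returns NaNs instead of raising an error, so the NaNs propagate silently through the sampler.

```python
import jax.numpy as jnp
import jax.random as jr


def backward_sample_step(key, filtered_mean, filtered_cov, F, Q, next_state, jitter=1e-6):
    """Hypothetical single FFBS backward step: sample x_t given y_{1:t} and x_{t+1}.

    `jitter` (an assumed default) is added to the diagonal of the conditional
    covariance so that its Cholesky factor stays finite.
    """
    pred_cov = F @ filtered_cov @ F.T + Q                  # S = Cov[x_{t+1} | y_{1:t}]
    # Smoother gain G = P F^T S^{-1}; S and P are symmetric, so solve(S, F P) = G^T.
    gain = jnp.linalg.solve(pred_cov, F @ filtered_cov).T
    cond_mean = filtered_mean + gain @ (next_state - F @ filtered_mean)
    cond_cov = filtered_cov - gain @ pred_cov @ gain.T     # can lose definiteness in float32
    cond_cov = 0.5 * (cond_cov + cond_cov.T)               # remove round-off asymmetry
    cond_cov = cond_cov + jitter * jnp.eye(cond_cov.shape[0])
    chol = jnp.linalg.cholesky(cond_cov)
    return cond_mean + chol @ jr.normal(key, cond_mean.shape)


# Toy 2-D model; all values are illustrative, not taken from jax-moseq.
key = jr.PRNGKey(0)
F = jnp.array([[1.0, 0.1], [0.0, 1.0]])
Q = 1e-4 * jnp.eye(2)
m_t = jnp.zeros(2)
P_t = jnp.array([[1.0, 0.999], [0.999, 1.0]])   # nearly singular filtered covariance
x_next = jnp.array([0.5, -0.2])
x_t = backward_sample_step(key, m_t, P_t, F, Q, x_next)

# The failure mode: this matrix has one eigenvalue slightly below zero.
bad = jnp.array([[1.0, 1.0], [1.0, 1.0 - 1e-6]])
chol_bad = jnp.linalg.cholesky(bad)                        # NaN-filled, no exception
chol_fixed = jnp.linalg.cholesky(bad + 1e-5 * jnp.eye(2))  # finite after adding jitter
```

The jitter trades a tiny, deliberate inflation of the sampling covariance for numerical robustness. It should be kept small relative to the covariance scale of the data so that it does not noticeably bias the samples.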

slinderman commented 1 year ago

Thanks for digging into this and finding/fixing these issues. I've merged your PRs. Hope we can get jax-moseq to work with dynamax!