lindermanlab / ssm-jax

Bayesian learning and inference for state space models (SSMs) using Google Research's JAX as a backend
MIT License

Hamiltonian Monte Carlo for Gaussian LDS example #29

Open · opened by slinderman 2 years ago

slinderman commented 2 years ago

SSM's Gaussian linear dynamical system (LDS) objects expose a function that computes the marginal likelihood of the data, integrating over the continuous latent states. Because this function is written in JAX, it can be differentiated automatically with jax.grad. Use TensorFlow Probability's Hamiltonian Monte Carlo (HMC) functionality to perform Bayesian inference over the LDS parameters, targeting the unnormalized log posterior: the log marginal likelihood plus the log prior on the parameter values.