lindermanlab / ssm-jax

Bayesian learning and inference for state space models (SSMs) using Google Research's JAX as a backend
MIT License

Hamiltonian Monte Carlo (HMC) for HMM example #28

Open slinderman opened 2 years ago

slinderman commented 2 years ago

SSM's hidden Markov model (HMM) objects expose a function that computes the marginal likelihood of the data by summing over the discrete latent states. This function can be automatically differentiated with jax.grad, so it can be used by gradient-based samplers. Use TensorFlow Probability's Hamiltonian Monte Carlo (HMC) functionality to perform Bayesian inference over the HMM parameters, taking the log marginal likelihood plus a log prior on the parameter values as the target log density.
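
To make the idea concrete, here is a minimal sketch in plain JAX. It does not use the SSM library's HMM objects: `hmm_log_marginal` is a hypothetical stand-in that runs the forward algorithm for a two-state Gaussian HMM with unit emission variance, and the prior, transition matrix, and synthetic data are all assumptions made for illustration. To stay dependency-free it also hand-rolls one HMC transition instead of calling TensorFlow Probability; with TFP's JAX substrate, the same `target` function could be passed as the `target_log_prob_fn` of `tfp.mcmc.HamiltonianMonteCarlo`.

```python
import jax
import jax.numpy as jnp
from jax import lax, random
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def hmm_log_marginal(means, data, log_pi0, log_P):
    """Forward algorithm: log p(data | means), summing over the discrete states.

    Hypothetical stand-in for the marginal likelihood function that the
    library's HMM objects expose. Emissions are Gaussian with unit variance.
    """
    log_liks = norm.logpdf(data[:, None], loc=means[None, :], scale=1.0)  # (T, K)

    def step(log_alpha, ll_t):
        # log_alpha[j] = log p(x_{1:t-1}, z_{t-1} = j); push it through the
        # transition matrix, then weight by the likelihood of x_t.
        log_pred = logsumexp(log_alpha[:, None] + log_P, axis=0)
        return log_pred + ll_t, None

    log_alpha, _ = lax.scan(step, log_pi0 + log_liks[0], log_liks[1:])
    return logsumexp(log_alpha)


def log_joint(means, data, log_pi0, log_P):
    """Unnormalized log posterior: log prior + log marginal likelihood."""
    log_prior = jnp.sum(norm.logpdf(means, loc=0.0, scale=5.0))  # assumed prior
    return log_prior + hmm_log_marginal(means, data, log_pi0, log_P)


def hmc_step(key, position, log_prob_fn, step_size=0.05, num_leapfrog=5):
    """One HMC transition with an identity mass matrix (hand-rolled, not TFP)."""
    key_mom, key_acc = random.split(key)
    grad_fn = jax.grad(log_prob_fn)
    momentum0 = random.normal(key_mom, position.shape)

    def leapfrog(carry, _):
        q, p = carry
        p = p + 0.5 * step_size * grad_fn(q)  # half step for momentum
        q = q + step_size * p                 # full step for position
        p = p + 0.5 * step_size * grad_fn(q)  # second half step for momentum
        return (q, p), None

    (q_new, p_new), _ = lax.scan(
        leapfrog, (position, momentum0), None, length=num_leapfrog
    )
    # Metropolis correction using the Hamiltonian (potential + kinetic energy).
    h_old = -log_prob_fn(position) + 0.5 * jnp.sum(momentum0 ** 2)
    h_new = -log_prob_fn(q_new) + 0.5 * jnp.sum(p_new ** 2)
    accept = jnp.log(random.uniform(key_acc)) < h_old - h_new
    return jnp.where(accept, q_new, position), accept


# Synthetic data (an assumption for this sketch): two states visited in blocks.
key_data, key_chain = random.split(random.PRNGKey(0))
states = jnp.repeat(jnp.array([0, 1, 0, 1]), 15)  # T = 60
true_means = jnp.array([-2.0, 2.0])
data = true_means[states] + random.normal(key_data, states.shape)
log_pi0 = jnp.log(jnp.array([0.5, 0.5]))
log_P = jnp.log(jnp.array([[0.9, 0.1], [0.1, 0.9]]))


def target(means):
    return log_joint(means, data, log_pi0, log_P)


def chain_body(position, step_key):
    position, accepted = hmc_step(step_key, position, target)
    return position, (position, accepted)


init = jnp.array([-1.0, 1.0])
_, (samples, accepted) = lax.scan(chain_body, init, random.split(key_chain, 200))

posterior_mean = samples[50:].mean(axis=0)  # discard the first 50 draws as burn-in
print("acceptance rate:", accepted.mean())
print("posterior mean of emission means:", posterior_mean)
```

Only the emission means are sampled here, so the state stays in an unconstrained space. Sampling the transition matrix or the variances as well would need a reparameterization (for example softmax or log transforms) so that HMC can move freely.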