rlouf opened this issue 3 years ago
There are several algorithms in https://arxiv.org/abs/2001.05033 that should be simple to implement.
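A simple starting point is to couple two HMC chains by sharing a single `rng_key`:

```python
import jax

def hmc_coupled(rng_key, states):
    # Map `kernel` over the chain axis of `states`, passing the same `rng_key` to every chain.
    states, infos = jax.vmap(kernel, in_axes=(None, 0))(rng_key, states)
    return states, infos
```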
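The shared-key coupling in `hmc_coupled` can be checked on a toy problem. This is a minimal sketch under stated assumptions: the HMC `kernel` is replaced by a hypothetical Gaussian autoregressive move, `x' = RHO * x + sqrt(1 - RHO**2) * eps`, which leaves N(0, 1) invariant, and `RHO`, `kernel`, `coupled_step` and `run_coupled` are illustrative names, not BlackJAX APIs. Because both chains receive the same key, they draw the same noise `eps`. The noise cancels in their difference, so the gap between the chains shrinks by a factor `RHO` at every step.

```python
import jax
import jax.numpy as jnp

RHO = 0.9  # hypothetical autocorrelation of the toy kernel

def kernel(rng_key, state):
    # Toy stand-in for an HMC step (assumption): Gaussian AR(1) move targeting N(0, 1).
    eps = jax.random.normal(rng_key, jnp.shape(state))
    new_state = RHO * state + jnp.sqrt(1.0 - RHO**2) * eps
    return new_state, {"eps": eps}

def coupled_step(rng_key, states):
    # Same structure as `hmc_coupled`: the key is broadcast (None) and states are mapped over axis 0.
    return jax.vmap(kernel, in_axes=(None, 0))(rng_key, states)

def run_coupled(rng_key, states, num_steps):
    def body(states, key):
        states, _ = coupled_step(key, states)
        return states, states  # carry the states forward and record them
    keys = jax.random.split(rng_key, num_steps)  # a fresh key per step, shared across chains
    _, trace = jax.lax.scan(body, states, keys)
    return trace  # shape (num_steps, num_chains)

trace = run_coupled(jax.random.PRNGKey(0), jnp.array([-5.0, 5.0]), 20)
gap = trace[:, 1] - trace[:, 0]
print(gap[-1])  # close to 10 * 0.9**20, i.e. about 1.2158
```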
This may seem unrelated, but Metropolis-within-Gibbs has been shown to work well (https://github.com/blackjax-devs/blackjax/discussions/275), so this is definitely feasible.