blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0

Add HMC Swindles #38

Open rlouf opened 3 years ago

rlouf commented 3 years ago

There are several algorithms in https://arxiv.org/abs/2001.05033 that should be simple to implement:

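import jax

def hmc_coupled(rng_key, states):
    # `kernel(rng_key, state)` is an HMC step; in_axes=(None, 0) shares one key across all chains.
    states, infos = jax.vmap(kernel, in_axes=(None, 0))(rng_key, states)
    return states, infos
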
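
To make the coupling idea above concrete, here is a minimal, self-contained sketch in plain JAX. It does not use BlackJAX's API: `logdensity`, `kernel`, and `run` are hypothetical stand-ins written for illustration (a toy HMC step on a standard normal target). It only shows the shared-randomness coupling obtained with `in_axes=(None, 0)`, not any specific estimator from the paper.

```python
import jax
import jax.numpy as jnp


def logdensity(x):
    # Standard normal target, assumed here purely for illustration.
    return -0.5 * jnp.sum(x**2)


def kernel(rng_key, position, step_size=0.1, num_steps=10):
    """Toy HMC step (hypothetical stand-in for a real HMC kernel)."""
    key_momentum, key_accept = jax.random.split(rng_key)
    momentum = jax.random.normal(key_momentum, position.shape)
    grad = jax.grad(logdensity)

    def leapfrog(_, carry):
        q, p = carry
        p = p + 0.5 * step_size * grad(q)
        q = q + step_size * p
        p = p + 0.5 * step_size * grad(q)
        return q, p

    q_new, p_new = jax.lax.fori_loop(0, num_steps, leapfrog, (position, momentum))

    def energy(q, p):
        return -logdensity(q) + 0.5 * jnp.sum(p**2)

    # Metropolis correction using the energy difference.
    log_accept = energy(position, momentum) - energy(q_new, p_new)
    accept = jnp.log(jax.random.uniform(key_accept)) < log_accept
    return jnp.where(accept, q_new, position), accept


def hmc_coupled(rng_key, positions):
    # in_axes=(None, 0): the same key (hence the same momentum draw and the
    # same acceptance uniform) is broadcast to every chain, while positions
    # are mapped over their leading axis.
    return jax.vmap(kernel, in_axes=(None, 0))(rng_key, positions)


def run(rng_key, positions, num_samples):
    # Advance all coupled chains together, one fresh shared key per step.
    def one_step(positions, key):
        positions, _ = hmc_coupled(key, positions)
        return positions, positions

    keys = jax.random.split(rng_key, num_samples)
    _, trajectory = jax.lax.scan(one_step, positions, keys)
    return trajectory


# Two chains started far apart, driven by the same randomness.
start = jnp.array([[-2.0], [2.0]])
trajectory = run(jax.random.PRNGKey(0), start, 50)
```

`trajectory` has shape `(num_samples, num_chains, dim)`. Because every chain sees the same momentum draw at each step, chains started at the same position stay identical, and on this Gaussian toy target chains started apart tend to move toward each other. A real implementation would pass a BlackJAX kernel and its state objects instead of raw positions.
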
rlouf commented 1 year ago

Seemingly unrelated, but Metropolis-within-Gibbs has been shown to work well (https://github.com/blackjax-devs/blackjax/discussions/275). So this is definitely feasible.