pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Allow for more general chain_method in MCMC #1825

Status: closed by fehiepsi 1 week ago

fehiepsi commented 1 week ago

Fixes #1725

Currently, the "vectorized" chain method is only supported for HMC/NUTS. This PR allows chain_method to be a callable, so users can get vectorized chains by passing chain_method=jax.vmap.

I found this simpler than supporting vmap explicitly in each kernel.
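The following is a minimal sketch of the idea using only JAX. It is not numpyro's implementation: sample_chain and run_chains are hypothetical names, and a toy random-walk Metropolis kernel stands in for HMC/NUTS. A string selects a built-in strategy, while any callable (such as jax.vmap) is treated as a transform that maps the single-chain function over the leading chain axis, so the kernel itself needs no vmap-specific code.

```python
import jax
import jax.numpy as jnp


def sample_chain(rng_key, init, num_samples=100, step_size=0.5):
    """Hypothetical single-chain sampler: random-walk Metropolis on a standard normal."""

    def log_prob(x):
        return -0.5 * x ** 2

    def step(x, key):
        key_prop, key_acc = jax.random.split(key)
        proposal = x + step_size * jax.random.normal(key_prop)
        log_ratio = log_prob(proposal) - log_prob(x)
        accept = jnp.log(jax.random.uniform(key_acc)) < log_ratio
        x_new = jnp.where(accept, proposal, x)
        return x_new, x_new  # (new carry, recorded sample)

    keys = jax.random.split(rng_key, num_samples)
    _, samples = jax.lax.scan(step, init, keys)
    return samples


def run_chains(rng_key, inits, chain_method="sequential"):
    """Hypothetical dispatcher: a string picks a built-in strategy, a callable is used as a map transform."""
    keys = jax.random.split(rng_key, inits.shape[0])  # one key per chain
    if callable(chain_method):
        # e.g. chain_method=jax.vmap maps sample_chain over the chain axis of (keys, inits)
        return chain_method(sample_chain)(keys, inits)
    if chain_method == "sequential":
        # Run the chains one after another in a Python loop
        return jnp.stack([sample_chain(k, x0) for k, x0 in zip(keys, inits)])
    raise ValueError(f"unknown chain_method: {chain_method!r}")
```

For example, run_chains(jax.random.PRNGKey(0), jnp.zeros(4), chain_method=jax.vmap) returns an array of shape (4, 100). Because the callable only wraps the per-chain function, the same path would work for any kernel whose sampler is written for a single chain.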