Currently, the "vectorized" chain method is only supported in HMC/NUTS. This PR allows `chain_method` to be a callable, so that users can get the vectorized behavior via `chain_method=jax.vmap`.
I found this simpler than supporting vmap explicitly for each kernel.
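To illustrate why a callable is enough, here is a minimal plain-JAX sketch of the dispatch idea. It is not NumPyro's implementation. `sample_one_chain` (a toy random-walk "kernel") and `run_chains` (the dispatcher) are hypothetical names. A string selects a built-in strategy. A callable is treated as a transform that maps a single-chain function over the leading (chain) axis of the RNG keys. Passing `jax.vmap` therefore vectorizes chains without any kernel-specific code.

```python
import jax
import jax.numpy as jnp


def sample_one_chain(rng_key, num_samples=100):
    # Hypothetical stand-in for a single-chain MCMC kernel:
    # a Gaussian random walk driven entirely by rng_key.
    steps = jax.random.normal(rng_key, (num_samples,))
    return jnp.cumsum(steps)


def run_chains(sample_fn, rng_keys, chain_method="sequential"):
    # Hypothetical dispatcher (assumption, not the PR's code).
    if callable(chain_method):
        # e.g. chain_method=jax.vmap: transform the single-chain function
        # and apply it to all chain keys at once.
        return chain_method(sample_fn)(rng_keys)
    if chain_method == "sequential":
        # Run the chains one after another over the leading axis.
        return jax.lax.map(sample_fn, rng_keys)
    raise ValueError(f"unknown chain_method: {chain_method!r}")


keys = jax.random.split(jax.random.PRNGKey(0), 4)  # one key per chain
seq = run_chains(sample_one_chain, keys, chain_method="sequential")
vec = run_chains(sample_one_chain, keys, chain_method=jax.vmap)
print(seq.shape)  # (4, 100)
```

Because JAX random numbers depend only on the key, the sequential and vmapped runs should produce matching samples. Any other transform with the same "function in, batched function out" shape could be passed the same way.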
Fixes #1725