This package adds a new HMC-within-Gibbs sampler to NumPyro. The `HMCGibbs` sampler that NumPyro already provides needs an analytic form for one of your conditional distributions. This sampler is for situations where you do not have one. Instead, it uses an HMC/NUTS sampler to generate draws from each of the conditional distributions.
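In one sweep of an HMC-within-Gibbs sampler, each block of parameters is updated in turn. Each update is an HMC (or NUTS) step that targets that block's conditional distribution, with every other parameter held at its current value.

The pure-Python sketch below illustrates the idea on the two-parameter model from the example further down. That model has `x` and `y` with `Normal(0, 2)` priors and a single observation `obs = 1.0`. The sketch shows the general technique and is not MultiHMCGibbs' implementation:

- It uses a hand-rolled HMC step.
- The step size and number of leapfrog steps are assumed and fixed.
- There is no warm-up or adaptation.
- The helper names `hmc_step` and `hmc_within_gibbs` are hypothetical.

```python
import math
import random


def log_joint(x, y, obs=1.0):
    # log p(x) + log p(y) + log p(obs | x, y), dropping constants;
    # priors Normal(0, 2), likelihood Normal(x + y, 1)
    return -0.5 * (x / 2.0) ** 2 - 0.5 * (y / 2.0) ** 2 - 0.5 * (obs - x - y) ** 2


def grad_x(x, y, obs=1.0):
    return -x / 4.0 + (obs - x - y)


def grad_y(x, y, obs=1.0):
    return -y / 4.0 + (obs - x - y)


def hmc_step(q, logp, grad, rng, step_size=0.3, n_leapfrog=10):
    """One HMC transition for a scalar q targeting exp(logp(q)).

    step_size and n_leapfrog are assumed, untuned values.
    """
    p0 = rng.gauss(0.0, 1.0)
    q_new, p = q, p0
    # Leapfrog integration: half momentum step, alternating full steps,
    # final half momentum step
    p += 0.5 * step_size * grad(q_new)
    for i in range(n_leapfrog):
        q_new += step_size * p
        if i != n_leapfrog - 1:
            p += step_size * grad(q_new)
    p += 0.5 * step_size * grad(q_new)
    # Metropolis correction on the total energy
    log_accept = (logp(q_new) - 0.5 * p * p) - (logp(q) - 0.5 * p0 * p0)
    if rng.random() < math.exp(min(0.0, log_accept)):
        return q_new
    return q


def hmc_within_gibbs(num_samples, seed=0):
    rng = random.Random(seed)
    x, y = 0.0, 0.0
    samples = []
    for _ in range(num_samples):
        # Update x from p(x | y, obs), holding y fixed
        x = hmc_step(x, lambda q: log_joint(q, y), lambda q: grad_x(q, y), rng)
        # Update y from p(y | x, obs), using the freshly updated x
        y = hmc_step(y, lambda q: log_joint(x, q), lambda q: grad_y(x, q), rng)
        samples.append((x, y))
    return samples


draws = hmc_within_gibbs(5000, seed=0)
print(sum(x for x, _ in draws) / len(draws))  # estimate of the posterior mean of x
```

Each HMC step leaves its own conditional distribution invariant, so a full sweep leaves the joint posterior invariant. MultiHMCGibbs delegates these per-block steps to NumPyro's `HMC`/`NUTS` kernels rather than a hand-written step like the one above.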
To use `MultiHMCGibbs`, you need to create a list of `HMC` or `NUTS` kernels that all wrap the same model. Each kernel can have its own keyword arguments, such as `target_accept_prob` or `max_tree_depth`. The other argument is a list of lists containing the free parameters for each of the inner kernels.
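The sketch below shows how the list-of-lists argument maps onto the conditional distributions. It uses a hypothetical helper, `check_gibbs_sites`, which is not part of MultiHMCGibbs. The helper assumes that every latent site should be updated by exactly one inner kernel, as in a standard Gibbs sweep. The README does not say whether MultiHMCGibbs itself enforces this. The helper returns, for each inner kernel, the sites that kernel is conditioned on, meaning the sites held fixed while that kernel takes its step.

```python
def check_gibbs_sites(num_kernels, gibbs_sites, latent_sites):
    """Hypothetical sanity check for the list-of-lists argument.

    Assumes each latent site belongs to exactly one inner kernel.
    Returns, per inner kernel, the sorted list of sites held fixed.
    """
    if len(gibbs_sites) != num_kernels:
        raise ValueError(
            f"got {num_kernels} inner kernels but {len(gibbs_sites)} site lists"
        )
    if any(len(sites) == 0 for sites in gibbs_sites):
        raise ValueError("every inner kernel needs at least one free parameter")
    assigned = [site for sites in gibbs_sites for site in sites]
    duplicates = sorted({s for s in assigned if assigned.count(s) > 1})
    if duplicates:
        raise ValueError(f"sites assigned to more than one kernel: {duplicates}")
    missing = sorted(set(latent_sites) - set(assigned))
    if missing:
        raise ValueError(f"sites not assigned to any kernel: {missing}")
    unknown = sorted(set(assigned) - set(latent_sites))
    if unknown:
        raise ValueError(f"sites not in the model: {unknown}")
    # Everything a kernel does not update is conditioned on
    return [sorted(set(latent_sites) - set(sites)) for sites in gibbs_sites]


# Two NUTS kernels for a model with latent sites "x" and "y": the first
# updates "y" (with "x" fixed), the second updates "x" (with "y" fixed).
print(check_gibbs_sites(2, [["y"], ["x"]], ["x", "y"]))  # [['x'], ['y']]
```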
Internally the sampler will:
Documentation: https://ckrawczyk.github.io/MultiHMCGibbs/
GitHub: https://github.com/CKrawczyk/MultiHMCGibbs
You can install the package with `pip` after cloning the repository:

```bash
git clone https://github.com/CKrawczyk/MultiHMCGibbs.git
cd MultiHMCGibbs
pip install .
```
```python
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from MultiHMCGibbs import MultiHMCGibbs


def model():
    x = numpyro.sample("x", dist.Normal(0.0, 2.0))
    y = numpyro.sample("y", dist.Normal(0.0, 2.0))
    numpyro.sample("obs", dist.Normal(x + y, 1.0), obs=jnp.array([1.0]))


# One inner kernel per block of free parameters, all wrapping the same model
inner_kernels = [
    NUTS(model),
    NUTS(model)
]

# Free parameters for each inner kernel: the first updates "y", the second "x"
outer_kernel = MultiHMCGibbs(
    inner_kernels,
    [['y'], ['x']]
)

mcmc = MCMC(
    outer_kernel,
    num_warmup=100,
    num_samples=100,
    progress_bar=False
)
mcmc.run(random.PRNGKey(0))
mcmc.print_summary()
```
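The example model is linear-Gaussian, so its exact posterior is available in closed form. That gives a rough reference point for the numbers `print_summary` reports. A run with only 100 warm-up and 100 posterior samples will match it only approximately.

The sketch below computes that posterior for the example's settings: priors `Normal(0, 2)`, noise scale 1, and `obs = 1.0`. It uses plain 2x2 algebra. The function name `example_posterior` is hypothetical.

```python
import math


def example_posterior(obs=1.0, prior_sd=2.0, noise_sd=1.0):
    """Exact Gaussian posterior of (x, y) for obs ~ Normal(x + y, noise_sd)."""
    prior_prec = 1.0 / prior_sd**2
    like_prec = 1.0 / noise_sd**2
    # Posterior precision matrix is [[a, b], [b, a]]
    a = prior_prec + like_prec
    b = like_prec
    det = a * a - b * b
    # Invert the symmetric 2x2 precision matrix to get the covariance
    cov = [[a / det, -b / det], [-b / det, a / det]]
    # Posterior mean = cov @ (like_prec * obs * [1, 1])
    h = like_prec * obs
    mean = [(cov[0][0] + cov[0][1]) * h, (cov[1][0] + cov[1][1]) * h]
    return mean, cov


mean, cov = example_posterior()
sd_x = math.sqrt(cov[0][0])
corr = cov[0][1] / cov[0][0]
print(f"E[x] = E[y] = {mean[0]:.3f}, sd = {sd_x:.3f}, corr(x, y) = {corr:.2f}")
# E[x] = E[y] = 0.444, sd = 1.491, corr(x, y) = -0.80
```

The strong negative correlation between `x` and `y` is worth keeping in mind. Updating correlated parameters one block at a time typically mixes more slowly than updating them jointly.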
Install all the development dependencies:

```bash
pip install -e ".[dev]"
```
Run tests with:

```bash
coverage run
coverage report
```
Build documentation with:

```bash
./build_docs
```
If you use this sampler in your publication, you can cite the software as:
Coleman Krawczyk. (2024). CKrawczyk/MultiHMCGibbs: v1.0.0 (v1.0.0). Zenodo. https://doi.org/10.5281/zenodo.12167630
Or with BibTeX:

```bibtex
@software{coleman_krawczyk_2024_12167630,
  author    = {Coleman Krawczyk},
  title     = {CKrawczyk/MultiHMCGibbs: v1.0.0},
  month     = jun,
  year      = 2024,
  publisher = {Zenodo},
  version   = {v1.0.0},
  doi       = {10.5281/zenodo.12167630},
  url       = {https://doi.org/10.5281/zenodo.12167630}
}
```
Full citation information can be found on https://zenodo.org/records/12167630