CKrawczyk / MultiHMCGibbs

A numpyro Gibbs sampler that uses conditioned HMC kernels for each step.
Apache License 2.0

rng_key input for sequential MCMC #9

Status: Open. AZhou00 opened this issue 2 months ago.

AZhou00 commented 2 months ago

Hi, thank you for the great package.

I am trying to run a model for a few samples, save the state, and keep sampling. A minimal example would be along the lines of:

```python
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from MultiHMCGibbs import MultiHMCGibbs

def model():
    x = numpyro.sample("x", dist.Normal(0.0, 2.0))
    y = numpyro.sample("y", dist.Normal(0.0, 2.0))
    numpyro.sample("obs", dist.Normal(x + y, 1.0), obs=jnp.array([1.0]))

# Pair each inner kernel with the sites it updates: ['y'] then ['x']
inner_kernels = [
    NUTS(model),
    NUTS(model)
]
outer_kernel = MultiHMCGibbs(
    inner_kernels,
    [['y'], ['x']]
)

rng_key = random.PRNGKey(0)

mcmc = MCMC(
    outer_kernel,
    num_warmup=100,
    num_samples=100,
    progress_bar=True
)
# First run: warmup + sampling, starting from a single key
mcmc.run(rng_key)
# Continue sampling from where the first run stopped
mcmc.post_warmup_state = mcmc.last_state
mcmc.run(mcmc.post_warmup_state.rng_key)
```

This produces the following progress bars:

```
sample: 100%|██████████| 200/200 [00:02<00:00, 68.52it/s, 3/3 steps of size 8.04e-01/6.58e-01. acc. prob=0.94/0.96]
sample: 100%|██████████| 100/100 [00:00<00:00, 508.15it/s]
```

As you can see, the progress bar for the second run somehow no longer shows the per-partition diagnostics of the kernel (steps, step sizes, and acceptance probabilities).

Passing the original `rng_key` to the second run instead seems to fix the issue:

```python
mcmc = MCMC(
    outer_kernel,
    num_warmup=100,
    num_samples=100,
    progress_bar=True
)
mcmc.run(rng_key)
mcmc.post_warmup_state = mcmc.last_state
mcmc.run(rng_key)  # reuse the original single key, not the state's stacked keys
```

```
sample: 100%|██████████| 200/200 [00:02<00:00, 68.79it/s, 3/3 steps of size 8.04e-01/6.58e-01. acc. prob=0.94/0.96]
sample: 100%|██████████| 100/100 [00:02<00:00, 38.47it/s, 7/3 steps of size 8.04e-01/6.58e-01. acc. prob=0.94/0.95]
```

Here `rng_key` is just a regular JAX PRNG key, while `mcmc.post_warmup_state.rng_key` is several keys stacked together. I'm not certain how the kernel treats the two cases differently, what the proper approach is, or whether there are other unintended side effects. I would really appreciate any advice! Thank you in advance. - Alan
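To make the shape difference concrete, here is a minimal JAX sketch (not taken from the thread and not using MultiHMCGibbs). It assumes the default raw `uint32` keys returned by `random.PRNGKey`; typed keys created with `random.key` have different shapes. The choice of three stacked keys is arbitrary, since the thread doesn't say how many keys the MultiHMCGibbs state actually holds.

```python
from jax import random

# A single raw PRNG key (the format returned by random.PRNGKey with the
# default threefry implementation) is a 1-D array of shape (2,).
single_key = random.PRNGKey(0)

# Splitting stacks several keys along a new leading axis,
# giving a 2-D array of shape (num_keys, 2).
stacked_keys = random.split(single_key, 3)

print(single_key.shape)    # (2,)
print(stacked_keys.shape)  # (3, 2)

# Indexing the leading axis recovers one ordinary key of shape (2,).
one_key = stacked_keys[0]
print(one_key.shape)       # (2,)
```

The only thing distinguishing "one key" from "many keys" here is the number of array dimensions, which is what a shape-based check would look at.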

CKrawczyk commented 1 month ago

Hmm, I think this is what is causing the behaviour you are seeing:

https://github.com/CKrawczyk/MultiHMCGibbs/blob/main/MultiHMCGibbs/multihmcgibbs.py#L188-L199

The code uses the shape of the key to decide whether or not to vmap. When it gets a single key it assumes sequential sampling, and when it gets multiple keys it assumes vectorized sampling (the extra progress-bar info is only shown in the sequential case). In your case you should be safe passing just one of the keys from `mcmc.post_warmup_state.rng_key` into `mcmc.run`, and that should work as expected.
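The dispatch described above can be sketched roughly as follows. This is a hypothetical illustration, not the code at the linked lines: `step` and `dispatch_step` are made-up names, the "update" is just added noise, and it assumes raw keys where `ndim == 1` means one key and `ndim == 2` means a stack of keys along the leading axis.

```python
import jax
import jax.numpy as jnp
from jax import random

def step(rng_key, x):
    # Hypothetical per-chain update: add a little Gaussian noise to x.
    return x + 0.1 * random.normal(rng_key, jnp.shape(x))

def dispatch_step(rng_key, x):
    """Hypothetical shape-based dispatch (not the MultiHMCGibbs code).

    A single raw key (ndim == 1) is treated as one sequential chain;
    a stack of keys (ndim == 2) is treated as vectorized chains, and the
    update is vmapped over the leading axis of both arguments.
    """
    if rng_key.ndim == 1:
        # Sequential path: the only place where per-kernel diagnostics
        # would be reported in this sketch.
        return step(rng_key, x), "sequential"
    # Vectorized path: one key and one state per chain.
    return jax.vmap(step)(rng_key, x), "vectorized"

key = random.PRNGKey(0)
stacked = random.split(key, 3)

_, mode_single = dispatch_step(key, jnp.zeros(()))
_, mode_stacked = dispatch_step(stacked, jnp.zeros(3))
# Passing one key out of the stack goes back to the sequential path.
_, mode_one = dispatch_step(stacked[0], jnp.zeros(()))

print(mode_single, mode_stacked, mode_one)  # sequential vectorized sequential
```

Under the same leading-axis assumption, the suggested fix in the original example would amount to something like `mcmc.run(mcmc.post_warmup_state.rng_key[0])`; this indexing has not been checked against MultiHMCGibbs itself.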