Open AZhou00 opened 2 months ago
Hmm, I think this is what is causing the behaviour you are seeing:
https://github.com/CKrawczyk/MultiHMCGibbs/blob/main/MultiHMCGibbs/multihmcgibbs.py#L188-L199
The code uses the shape of the key to decide whether or not to vmap. When it gets one key it assumes sequential sampling, and when it gets multiple keys it assumes vectorized sampling (the extra progress bar info is only shown for sequential). You should be safe passing just one of the keys in the mcmc.post_warmup_state.rng_key list into mcmc.run, and that should work as expected in your case.
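A minimal sketch of that suggestion, assuming legacy uint32 JAX keys (a single key has one dimension, a stack of keys has two). The helper name pick_single_key is hypothetical and not part of MultiHMCGibbs or NumPyro, and the split keys below are only a stand-in for mcmc.post_warmup_state.rng_key:

```python
import jax


def pick_single_key(rng_key):
    """Return one un-batched key from a possibly stacked array of keys.

    Assumption: legacy uint32 JAX keys, where a single key has ndim == 1
    and a stack of keys has ndim == 2. Hypothetical helper, not library API.
    """
    if rng_key.ndim == 1:
        # Already a single key: pass it through unchanged.
        return rng_key
    # Stacked keys: take the first one so the sampler sees a single key.
    return rng_key[0]


# A regular key, as one would pass to the first mcmc.run call.
single = jax.random.PRNGKey(0)

# Stand-in for the stacked keys stored in the saved state (assumed layout).
stacked = jax.random.split(single, 3)

resume_key = pick_single_key(stacked)
print(single.shape, stacked.shape, resume_key.shape)
```

In the setting of this issue, the second call would then presumably look like mcmc.run(pick_single_key(mcmc.post_warmup_state.rng_key), ...) with the model arguments unchanged. Taking the first key is an arbitrary choice; any one key from the stack should do.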
Hi, thank you for the great package.
I am trying to run a model for a few samples, save the state, and keep sampling. A minimal example would be along the lines of
This gives the progress bars
sample: 100%|██████████| 200/200 [00:02<00:00, 68.52it/s, 3/3 steps of size 8.04e-01/6.58e-01. acc. prob=0.94/0.96]
sample: 100%|██████████| 100/100 [00:00<00:00, 508.15it/s]
As you can see, the progress bar for the second run somehow does not reflect all the partitions of the kernel.
The code below seems to fix the issue.
sample: 100%|██████████| 200/200 [00:02<00:00, 68.79it/s, 3/3 steps of size 8.04e-01/6.58e-01. acc. prob=0.94/0.96]
sample: 100%|██████████| 100/100 [00:02<00:00, 38.47it/s, 7/3 steps of size 8.04e-01/6.58e-01. acc. prob=0.94/0.95]
rng_key here is just a regular JAX rng_key, while mcmc.post_warmup_state.rng_key is a few keys stacked together. I am not exactly certain how the kernel treats the two cases differently, or what the proper way is (or whether there are other unintended side effects). I would really appreciate any advice! Thank you in advance. - Alan