pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Fix for pickling an MCMC object with HMCGibbs (and MixedHMC) samplers and parallel chains #1746

Closed: msaintja closed this 4 months ago

msaintja commented 4 months ago

This PR addresses issue #1742.

When using MixedHMC and sampling with chain_method="parallel" and more than one chain, saving the MCMC object with dill/pickle raised a ConcretizationTypeError, because some of its attributes may contain JAX tracers. From limited testing, the offending attributes appear to be _support_sizes_flat on the MCMC object and some attributes of _prototype_trace[...]["fn"].
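One way to narrow down which attributes break pickling is to try pickling each instance attribute on its own. The sketch below assumes a hypothetical helper, find_unpicklable_attrs, and a toy ToySampler class in which a threading.Lock stands in for a leaked JAX tracer; neither is part of numpyro.

```python
import pickle
import threading


def find_unpicklable_attrs(obj):
    """Hypothetical helper: return {attribute_name: exception} for each attribute pickle rejects."""
    failures = {}
    for name, value in vars(obj).items():
        try:
            pickle.dumps(value)
        except Exception as exc:  # pickle may raise TypeError, PicklingError, AttributeError, ...
            failures[name] = exc
    return failures


class ToySampler:
    """Hypothetical stand-in for a sampler; the lock plays the role of a leaked JAX tracer."""

    def __init__(self):
        self.step_size = 0.1
        self._support_sizes_flat = threading.Lock()  # unpicklable on purpose


bad = find_unpicklable_attrs(ToySampler())
print(sorted(bad))  # ['_support_sizes_flat']
```

This only inspects top-level attributes, so a problem nested inside a container such as _prototype_trace is reported against the container as a whole.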

Setting these attributes (_support_sizes_flat and the affected attributes of _prototype_trace[...]["fn"]) to None in the state copy passed to dill circumvents the JAX error. I'm not sure how important it is to keep those attributes in a pickled save, but a simple dump-load-run cycle seems to work fine with the example code from #1742:

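```python
import dill

# `mcmc`, `key` and `data` come from the example in #1742.
with open("chains_test.pkl", "wb") as f:
    dill.dump(mcmc, f)
with open("chains_test.pkl", "rb") as f:
    mcmc = dill.load(f)
# Skip warmup and continue sampling from the last saved state.
mcmc.post_warmup_state = mcmc.last_state
mcmc.run(key, data)
```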
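The idea of blanking out problematic attributes in a copy of the object's state can be sketched generically with __getstate__. This is a minimal illustration of the pattern, not the actual numpyro patch: ToySampler and _transient_attrs are hypothetical names, and a threading.Lock again stands in for a JAX tracer.

```python
import pickle
import threading


class ToySampler:
    """Hypothetical sampler illustrating the 'drop unpicklable state' pattern."""

    # Attributes that may hold objects pickle cannot handle (leaked JAX tracers in the
    # MixedHMC case; a lock stands in for them here).
    _transient_attrs = ("_support_sizes_flat",)

    def __init__(self):
        self.step_size = 0.1
        self._support_sizes_flat = threading.Lock()

    def __getstate__(self):
        # Work on a shallow copy so pickling does not mutate the live object.
        state = self.__dict__.copy()
        for name in self._transient_attrs:
            state[name] = None
        return state


sampler = ToySampler()
restored = pickle.loads(pickle.dumps(sampler))
print(restored.step_size, restored._support_sizes_flat)  # 0.1 None
print(sampler._support_sizes_flat is not None)  # True: original untouched
```

Without a custom __setstate__, pickle simply updates the new instance's __dict__ with the returned state, so any code that reads a blanked attribute after loading must tolerate None or recompute it. Nested entries (such as those under _prototype_trace) would need their containers copied too before being modified, otherwise the live object would be changed.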