When using `MixedHMC` and sampling with `chain_method="parallel"` and more than one chain, saving the `MCMC` object with `dill`/`pickle` resulted in a `ConcretizationTypeError`, because some attributes may contain JAX tracers.
From limited testing, these appear to be the `_support_sizes_flat` attribute of the `MCMC` object and some attributes of `_prototype_trace[...]["fn"]`.
Setting those to `None` in the state copy passed to `dill` circumvents the JAX error.
I'm not sure how critical it is to keep those attributes in a pickled save; however, a simple dump-load-run seems to work fine with the example code presented in #1742:
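```python
import dill

# `mcmc`, `key`, and `data` come from the example code in #1742.
with open("chains_test.pkl", "wb") as f:
    dill.dump(mcmc, f)

with open("chains_test.pkl", "rb") as f:
    mcmc = dill.load(f)

# Resume sampling from the last saved state.
mcmc.post_warmup_state = mcmc.last_state
mcmc.run(key, data)
```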
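The attribute-nulling idea described above can be illustrated with a minimal, self-contained sketch using plain `pickle`. The classes `FakeTracer`, `FakeDist`, and `SamplerLike`, and the `support_size` attribute, are hypothetical stand-ins, not NumPyro's actual classes or the code in this PR. The `__getstate__` hook returns a modified copy of the instance state, so the live object keeps its attributes while the pickled copy carries `None` in their place. Since `dill` extends `pickle`, the same hook should apply there as well.

```python
import copy
import pickle


class FakeTracer:
    """Hypothetical stand-in for a JAX tracer: refuses to be pickled."""

    def __reduce_ex__(self, protocol):
        raise TypeError("cannot pickle a tracer")


class FakeDist:
    """Hypothetical distribution whose attribute holds a tracer."""

    def __init__(self):
        self.support_size = FakeTracer()  # hypothetical attribute name


class SamplerLike:
    """Hypothetical MCMC-like object; not NumPyro's actual class."""

    def __init__(self):
        self._support_sizes_flat = FakeTracer()
        self._prototype_trace = {"x": {"fn": FakeDist()}}
        self.num_chains = 2

    def __getstate__(self):
        # Copy the instance dict so the live object is never mutated.
        state = self.__dict__.copy()
        state["_support_sizes_flat"] = None
        trace = {}
        for name, site in self._prototype_trace.items():
            site = dict(site)
            # Shallow-copy the distribution, then drop its traced attribute.
            fn = copy.copy(site["fn"])
            fn.support_size = None
            site["fn"] = fn
            trace[name] = site
        state["_prototype_trace"] = trace
        return state


sampler = SamplerLike()
restored = pickle.loads(pickle.dumps(sampler))
print(restored._support_sizes_flat, restored.num_chains)
```

After loading, the dropped attributes are `None`, so any code that relies on them would need to recompute them.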
This PR addresses issue #1742.