msaintja closed this issue 7 months ago.
Hi @msaintja, a simple fix for this type of issue is to remove unnecessary attributes in the object's __getstate__ method. Running mcmc.sampler.__getstate__() shows that the state contains some JAX tracers. Do you want to make a PR to remove them?
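The general pattern suggested here (override __getstate__ so that transient attributes are dropped before pickling) can be sketched in plain Python. ToySampler and _transient are hypothetical names, and the threading.Lock only stands in for the tracer attributes. This is an illustration of the technique, not the change that was actually merged.

```python
import pickle
import threading


class ToySampler:
    """Hypothetical stand-in for a sampler object; not NumPyro's actual class."""

    def __init__(self):
        self.num_steps = 10
        # Unpicklable attribute standing in for the JAX tracers
        # that were found in the sampler's state.
        self._transient = threading.Lock()

    def __getstate__(self):
        # Copy the instance dict and drop the attribute that must not be serialized.
        state = self.__dict__.copy()
        state.pop("_transient", None)
        return state

    def __setstate__(self, state):
        # Restore the serialized attributes and reset the dropped one.
        self.__dict__.update(state)
        self._transient = None


sampler = ToySampler()
restored = pickle.loads(pickle.dumps(sampler))
print(restored.num_steps, restored._transient)  # 10 None
```

Without the __getstate__ override, pickle.dumps(sampler) would fail because the lock cannot be pickled.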
Thank you for the pointers, @fehiepsi! I'm not sure about the underlying cause (some attributes are tracers when the chains run in parallel but not when they run sequentially), but your suggestion helped me pinpoint which attributes were causing the error.
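One possible way to do this kind of pinpointing, sketched here as an assumption rather than the method actually used, is to try pickling each entry of an object's state on its own and collect the names that fail. find_unpicklable and Holder are hypothetical helpers for illustration.

```python
import pickle
import threading


def find_unpicklable(obj):
    """Return the names of attributes in obj's state that fail to pickle on their own."""
    getstate = getattr(obj, "__getstate__", None)
    state = getstate() if getstate is not None else None
    # Fall back to the instance dict when __getstate__ is missing (Python < 3.11)
    # or returns something other than a dict.
    if not isinstance(state, dict):
        state = vars(obj)
    bad = []
    for name, value in state.items():
        try:
            pickle.dumps(value)
        except Exception:  # any failure while pickling marks the attribute
            bad.append(name)
    return bad


class Holder:
    """Hypothetical object with one picklable and one unpicklable attribute."""

    def __init__(self):
        self.samples = [1.0, 2.0, 3.0]
        self.lock = threading.Lock()  # locks cannot be pickled


print(find_unpicklable(Holder()))  # ['lock']
```

Applied to something like mcmc.sampler, this would list the attribute names whose values raise during pickling; since the helper catches any exception, errors raised while pickling a tracer would be reported the same way.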
Closed via #1746
Hello, I'm encountering an issue when trying to pickle a numpyro.infer.mcmc.MCMC object after sampling with MixedHMC (with either pickle or dill, which I've seen being used here). I've included some example code for a model, a mix of what is in the MixedHMC documentation and the Gaussian Mixture Model tutorial:

The error thrown is a ConcretizationTypeError (from JAX), and it seems to point to the line with the Dirichlet distribution. Specifically, the error I get is the following:
The pickling issue does not occur when the chains are run sequentially (chain_method="sequential").

I have this issue locally with:

I have also reproduced it on Colab with JAX 0.4.23 and the latest NumPyro commit.
Apologies if I'm missing anything obvious; I only started using NumPyro recently.