pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Can't pickle MCMC object (MixedHMC kernel) when `chain_method="parallel"` #1742

Closed msaintja closed 7 months ago

msaintja commented 7 months ago

Hello, I'm encountering an issue when trying to pickle a numpyro.infer.mcmc.MCMC object after sampling with MixedHMC. It happens with both pickle and dill (which I've seen used here). Below is example code for a model, adapted from the MixedHMC documentation and the Gaussian Mixture Model tutorial:

import dill
import jax
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import HMC, MCMC, MixedHMC

# @title Parameters
PARALLEL_SAMPLING = True # @param {type:"boolean"}
NCHAINS = 2 # @param {type:"integer"}
NSAMPLES = 100 # @param {type:"integer"}

numpyro.set_platform('cpu')
if PARALLEL_SAMPLING:
    numpyro.set_host_device_count(NCHAINS)

key = random.PRNGKey(1)

N = 20
K = 2 
w = jnp.array([0.7, 0.3])
μ = jnp.array([-2, 3])

def model(data):
    # Global variables.
    weights = numpyro.sample("weights", dist.Dirichlet(jnp.ones(K)/K))
    with numpyro.plate("K", K):
        locs = numpyro.sample("locs", dist.Normal(0.0, 10.0))

    with numpyro.plate("data", len(data)):
        # Local variables.
        k = numpyro.sample("k", dist.Categorical(weights))
        numpyro.sample("obs", dist.Normal(locs[k], 1.0), obs=data)

y = random.choice(key, jnp.arange(K), shape=(N,), p=w)
data = random.normal(key, shape=(N,)) + jnp.take(μ, y)

kernel = MixedHMC(HMC(model))
mcmc = MCMC(
    kernel,
    num_warmup=NSAMPLES // 2,
    num_samples=NSAMPLES,
    num_chains=NCHAINS,
    progress_bar=True,
    chain_method="parallel" if PARALLEL_SAMPLING else "sequential",
)
mcmc.run(key, data)

with open("chains_test.pkl", 'wb') as f:
    dill.dump(mcmc, f)

This raises a JAX ConcretizationTypeError, which seems to point to the line in the model that samples from the Dirichlet distribution.

Specifically, the error that I get is the following:

/usr/local/lib/python3.10/dist-packages/jax/_src/core.py in __reduce__(self)
    794   # raises a useful error on attempts to pickle a Tracer.
    795   def __reduce__(self):
--> 796     raise ConcretizationTypeError(
    797       self, ("The error occurred in the __reduce__ method, which may "
    798              "indicate an attempt to serialize/pickle a traced value."))

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape float32[2].
The error occurred in the __reduce__ method, which may indicate an attempt to serialize/pickle a traced value.
This DynamicJaxprTracer was created on line <ipython-input-5-e1929d5da667>:8 (model)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

The pickling issue does not occur when chain_method="sequential".

I have this issue locally with:

JAX version:  0.4.17.dev20231003+g4cb5eeee5
Numpyro version:  0.13.2

and have reproduced this on Colab with JAX 0.4.23 and the latest commit for numpyro.

Apologies if I'm missing anything obvious, I have only started using NumPyro recently.

fehiepsi commented 7 months ago

Hi @msaintja, a simple fix for this type of issue is to remove unnecessary attributes in the __getstate__ method of the object. Running mcmc.sampler.__getstate__() gives me some jax tracers in the object. Do you want to make a PR to remove them?
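The __getstate__ approach suggested above can be sketched with the standard library alone. This is only an illustration of the general pattern: FakeTracer, Sampler, and the _cached_trace attribute are hypothetical stand-ins, not NumPyro's actual classes or attribute names, and this is not the change that was eventually merged.

```python
import pickle


class FakeTracer:
    """Hypothetical stand-in for a JAX tracer: it refuses to be pickled."""

    def __reduce__(self):
        raise TypeError("attempt to pickle a traced value")


class Sampler:
    """Hypothetical kernel-like object (not NumPyro's real class)."""

    def __init__(self, step_size):
        self.step_size = step_size
        # Transient state left over from a run; it is not needed after sampling.
        self._cached_trace = FakeTracer()

    def __getstate__(self):
        # Copy the instance dict and blank out attributes that cannot be pickled.
        state = self.__dict__.copy()
        state["_cached_trace"] = None
        return state


sampler = Sampler(step_size=0.1)
restored = pickle.loads(pickle.dumps(sampler))
print(restored.step_size, restored._cached_trace)  # prints: 0.1 None
```

Only the pickled copy loses the transient attribute; the live object keeps it, because __getstate__ works on a copy of the instance dict.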

msaintja commented 7 months ago

Thank you for the pointers @fehiepsi! I'm not too sure about the underlying causes (some attributes are tracers with parallel chains but not in the sequential case), but your suggestion has allowed me to pinpoint which elements were throwing the error.
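One way to pinpoint offending attributes is to try pickling each entry of the state separately. The sketch below is an assumption about how this could be done, not necessarily what was done here: find_unpicklable and Unpicklable are hypothetical names, and the function assumes the state is a plain dict.

```python
import pickle


def find_unpicklable(state):
    """Return {name: error summary} for each value in `state` that fails to pickle."""
    failures = {}
    for name, value in state.items():
        try:
            pickle.dumps(value)
        except Exception as exc:  # report any failure raised while pickling this value
            failures[name] = f"{type(exc).__name__}: {exc}"
    return failures


class Unpicklable:
    """Hypothetical stand-in for a traced value that refuses to be pickled."""

    def __reduce__(self):
        raise TypeError("attempt to pickle a traced value")


example_state = {"step_size": 0.1, "cached": Unpicklable()}
bad = find_unpicklable(example_state)
print(sorted(bad))  # prints: ['cached']
```

With a real run, the same helper could presumably be applied to the result of mcmc.sampler.__getstate__(), assuming it returns a dict; pickle can be swapped for dill if that is the serializer in use.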

fehiepsi commented 7 months ago

Closed via #1746