pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Samples are outside the support for DiscreteUniform distribution #1834

Open Deathn0t opened 3 months ago

Deathn0t commented 3 months ago

Hello,

I noticed that samples can take values outside the support of the DiscreteUniform distribution. Here is a simple reproducible example:

import jax.random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import HMC, MCMC, MixedHMC

def model():
    x = numpyro.sample("x", dist.DiscreteUniform(1, 2))

num_samples = 10
kernel = HMC(model, trajectory_length=1.2)
kernel = MixedHMC(kernel, num_discrete_updates=20)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples, progress_bar=False)
key = jax.random.PRNGKey(0)
mcmc.run(key)
samples = mcmc.get_samples()

print(samples)

Which outputs:

{'x': Array([1, 1, 0, 0, 0, 0, 0, 0, 0, 1], dtype=int32)}

I was expecting the values of x to be in {1, 2}.

Am I using it incorrectly, or is this a real bug?

Thank you very much for your help.

fehiepsi commented 3 months ago

Thanks @Deathn0t! It is a bug at this line https://github.com/pyro-ppl/numpyro/blob/f6eb6ce152bd8e903dd56eeb5909ae0b59e24abe/numpyro/infer/hmc_gibbs.py#L283

We should pass in the enumerated support there and index into it. Something like:

support_size = enumerate_support.shape[0]
proposal_idx = random.randint(rng_proposal, (), minval=0, maxval=support_size)
proposal = enumerate_support[proposal_idx]

Do you want to try to fix the issue?
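
The suggestion above can be sketched as a standalone function. This is only an illustration, not the NumPyro implementation: `propose_from_support` is a hypothetical helper, and the `support` array is a hand-written stand-in for what `DiscreteUniform(1, 2).enumerate_support()` would provide.

```python
import jax
import jax.numpy as jnp
from jax import random


def propose_from_support(rng_key, enumerate_support):
    """Draw one uniform proposal from the enumerated support (hypothetical helper)."""
    support_size = enumerate_support.shape[0]
    # Sample an index in [0, support_size), then map it to an actual support value.
    proposal_idx = random.randint(rng_key, (), minval=0, maxval=support_size)
    return enumerate_support[proposal_idx]


# Assumed stand-in for the enumerated support of DiscreteUniform(1, 2).
support = jnp.arange(1, 3)

keys = random.split(random.PRNGKey(0), 100)
proposals = jax.vmap(lambda k: propose_from_support(k, support))(keys)
print(proposals.min(), proposals.max())
```

Because every proposal is read out of `support`, it cannot fall outside {1, 2}, whereas returning the raw index would yield values in {0, 1}.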

Deathn0t commented 3 months ago

Hi @fehiepsi, thank you for the hint! I will try to look at it today and will let you know if I get blocked. If I get things working, I will open a PR.

Deathn0t commented 3 months ago

Hi @fehiepsi, I had a look at it today. It seems the Gibbs proposal is used instead of the random-walk (RW) one:

https://github.com/pyro-ppl/numpyro/blob/f6eb6ce152bd8e903dd56eeb5909ae0b59e24abe/numpyro/infer/hmc_gibbs.py#L219

For simplicity and minimal code changes, I was thinking of mapping z_discrete to the enumerate_support values here: https://github.com/pyro-ppl/numpyro/blob/f6eb6ce152bd8e903dd56eeb5909ae0b59e24abe/numpyro/infer/mixed_hmc.py#L304

What do you think?

I tried the following and it seems to work:

z_discrete = jax.tree.map(
    lambda idx, support: support[idx],
    z_discrete,
    self._support_enumerates,
)
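
For anyone who wants to try the mapping outside the kernel, here is a minimal self-contained sketch. It assumes, as the discussion above suggests, that the stored discrete values are indices into the enumerated support. The `z_discrete` and `support_enumerates` dictionaries are hand-written stand-ins, not the kernel's real state, and `jax.tree_util.tree_map` is used as the longer-standing spelling of `jax.tree.map`.

```python
import jax
import jax.numpy as jnp

# Stand-in for the sampled indices (the values reported in the issue output).
z_discrete = {"x": jnp.array([1, 1, 0, 0, 0, 0, 0, 0, 0, 1])}

# Stand-in for the per-site enumerated supports; DiscreteUniform(1, 2) covers {1, 2}.
support_enumerates = {"x": jnp.array([1, 2])}

# Both pytrees share the same structure, so each leaf of indices is paired
# with the support array of the same site.
mapped = jax.tree_util.tree_map(
    lambda idx, support: support[idx],
    z_discrete,
    support_enumerates,
)
print(mapped)
```

With these inputs, index 0 maps to 1 and index 1 maps to 2, so every mapped value lies in {1, 2}.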