Open Deathn0t opened 3 months ago
Thanks @Deathn0t! It is a bug at this line https://github.com/pyro-ppl/numpyro/blob/f6eb6ce152bd8e903dd56eeb5909ae0b59e24abe/numpyro/infer/hmc_gibbs.py#L283
We should pass in the enumerated support there, something like:
support_size = enumerate_support.shape[0]
proposal_idx = random.randint(rng_proposal, (), minval=0, maxval=support_size)
proposal = enumerate_support[proposal_idx]
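A minimal self-contained sketch of the suggestion above, assuming the support is available as a 1-D array. The helper name `propose_from_support` and the example support `[1, 2]` are hypothetical and not part of the NumPyro API; the point is only that drawing an index and then looking it up in the support keeps every proposal inside the support.

```python
import jax
import jax.numpy as jnp
from jax import random


def propose_from_support(rng_key, enumerate_support):
    """Draw one value uniformly from an enumerated support array (hypothetical helper)."""
    support_size = enumerate_support.shape[0]
    # maxval is exclusive, so the index always falls in [0, support_size).
    proposal_idx = random.randint(rng_key, (), minval=0, maxval=support_size)
    # Map the index back to an actual value of the support.
    return enumerate_support[proposal_idx]


# Assumed support of a DiscreteUniform over {1, 2}.
support = jnp.array([1, 2])
keys = random.split(random.PRNGKey(0), 100)
draws = jax.vmap(lambda k: propose_from_support(k, support))(keys)
```

Here every entry of `draws` is an element of `support`, whereas the raw indices would lie in {0, 1}.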
Do you want to try to fix the issue?
Hi @fehiepsi, thank you for the hint! I will look at it today and keep you updated if I get blocked. If things work, I will open a PR.
Hi @fehiepsi, I had a look at it today. It seems the Gibbs proposal is used instead of the random-walk (RW) one:
For simplicity and minimal code changes, I was thinking of mapping z_discrete to the enumerate_support values here:
https://github.com/pyro-ppl/numpyro/blob/f6eb6ce152bd8e903dd56eeb5909ae0b59e24abe/numpyro/infer/mixed_hmc.py#L304
What do you think?
I tried the following and it seems to work:
z_discrete = jax.tree.map(
    lambda idx, support: support[idx],
    z_discrete,
    self._support_enumerates,
)
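To illustrate what this mapping does, here is a small standalone sketch. The dictionaries `z_discrete` and `support_enumerates` are hypothetical stand-ins for the sampler's internal state, and the site names and values are made up for the example. It uses `jax.tree_util.tree_map`, which `jax.tree.map` aliases in recent JAX releases.

```python
import jax
import jax.numpy as jnp

# Hypothetical sampled indices, one per discrete site.
z_discrete = {"x": jnp.array(1), "y": jnp.array(0)}
# Hypothetical enumerated support of each site, with the same tree structure.
support_enumerates = {
    "x": jnp.array([1, 2]),
    "y": jnp.array([10, 20, 30]),
}

# Replace each index by the support value it points to, leaf by leaf.
z_values = jax.tree_util.tree_map(
    lambda idx, support: support[idx],
    z_discrete,
    support_enumerates,
)
```

With these inputs, index 1 for `x` maps to the support value 2, and index 0 for `y` maps to 10. Both trees must share the same structure for the leaf-wise lookup to work.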
Hello,
I noticed that samples take values outside the support of the DiscreteUniform distribution. Here is a simple reproducible example:
Which outputs:
I was expecting the values of x to be in [1, 2]. Am I using it wrongly, or is it a real bug?
Thank you very much for your help.