tims457 closed this issue 4 months ago.
I think you can move `ignore_alarm` outside of the lazy branch so that the two branches have the same latent variables.
That works in this particular case, since both branches sample from a normal distribution, but not in general. For example, what if one branch has a normal distribution and the other has a uniform distribution? The following NumPyro model throws the same error:
```python
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import cond


def sleep_model(data):
    feeling_lazy = numpyro.sample("feeling_lazy", dist.Bernoulli(probs=0.9))
    # ignore_alarm is sampled outside the branches so both branches share it
    ignore_alarm = numpyro.sample("ignore_alarm", dist.Bernoulli(probs=0.8))

    def lazy_branch():
        amount_slept = numpyro.sample("amount_slept", dist.Normal(loc=8 + 2 * ignore_alarm, scale=1))
        return amount_slept

    def not_lazy_branch():
        # amount_slept = numpyro.sample("amount_slept", dist.Normal(loc=6, scale=1))  # this works
        amount_slept = numpyro.sample("amount_slept", dist.Uniform(low=6., high=10.))  # this doesn't
        return amount_slept

    amount_slept = cond(feeling_lazy == 1,
                        lambda x: lazy_branch(),
                        lambda x: not_lazy_branch(),
                        data)
    return amount_slept
```
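For context, NumPyro's `cond` is built on `jax.lax.cond`, which traces both branches and requires their outputs to have the same pytree structure (and matching leaf shapes and dtypes). In the model above, the mismatch is presumably between the recorded `Normal` and `Uniform` sample sites. The pure-JAX sketch below reproduces the same class of error with hypothetical branch functions that return dicts of different structure; it illustrates the constraint and is not NumPyro's internal code.

```python
import jax
import jax.numpy as jnp


# Hypothetical branches whose outputs differ in pytree structure
# (two leaves vs. one leaf), mimicking branches with different sites.
def lazy_branch(x):
    return {"amount_slept": x + 2.0, "ignore_alarm": jnp.array(1.0)}


def not_lazy_branch(x):
    return {"amount_slept": x}


mismatch_error = None
try:
    jax.lax.cond(True, lazy_branch, not_lazy_branch, jnp.array(6.0))
except TypeError as err:
    # JAX traces both branches and rejects mismatched output structures.
    mismatch_error = err


# Branches returning the same structure (and leaf dtypes) are accepted.
def lazy_fixed(x):
    return {"amount_slept": x + 2.0}


def not_lazy_fixed(x):
    return {"amount_slept": x}


out = jax.lax.cond(True, lazy_fixed, not_lazy_fixed, jnp.array(6.0))
```

Both branches are traced regardless of the predicate's value, which is why the error appears even when only one branch would ever run for a given sample.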
The equivalent model works in Pyro:
```python
import pyro
import pyro.distributions as dist


def sleep_model():
    feeling_lazy = pyro.sample("feeling_lazy", dist.Bernoulli(0.9))
    # Plain Python control flow works here because Pyro executes eagerly
    if feeling_lazy:
        ignore_alarm = pyro.sample("ignore_alarm", dist.Bernoulli(0.8))
        amount_slept = pyro.sample("amount_slept",
                                   dist.Normal(8 + 2 * ignore_alarm, 1))
    else:
        amount_slept = pyro.sample("amount_slept", dist.Uniform(1, 6))
    return amount_slept
```
Yeah, it is a limitation of `cond`. JAX requires the two branches to have the same pytree structure. I don't have a good solution for it. Depending on the algorithm, something like this (suitable for MCMC) might be required:
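```python
# Replacement for lazy_branch inside sleep_model; ignore_alarm comes from the enclosing model.
def lazy_branch():
    # Sample from a wide base distribution that both branches share...
    amount_slept = numpyro.sample("amount_slept", dist.Uniform(-100., 100.))
    # ...then add this branch's log density as a factor
    numpyro.factor("amount_slept_factor", dist.Normal(loc=8 + 2 * ignore_alarm, scale=1).log_prob(amount_slept))
    return amount_slept
```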
Similarly for the other branch:

```python
numpyro.factor("amount_slept_factor", 0.)
```
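The idea behind this workaround is that both branches end up returning the same kind of value, a scalar log density, so their structures match. A minimal pure-JAX sketch of that quantity follows, assuming the `Uniform(6, 10)` non-lazy branch from the NumPyro model in the question; `amount_slept_log_density` is a hypothetical helper name, not a NumPyro API.

```python
import jax
import jax.numpy as jnp
from jax.scipy import stats


def amount_slept_log_density(amount_slept, feeling_lazy, ignore_alarm):
    """Log density of amount_slept under the branch selected by feeling_lazy.

    Hypothetical helper: the not-lazy branch is assumed to be Uniform(6, 10).
    """
    x = jnp.asarray(amount_slept, dtype=jnp.float32)
    alarm = jnp.asarray(ignore_alarm, dtype=jnp.float32)

    def lazy_branch(x):
        # Normal(8 + 2 * ignore_alarm, 1) log density
        return stats.norm.logpdf(x, loc=8.0 + 2.0 * alarm, scale=1.0)

    def not_lazy_branch(x):
        # Uniform on [6, 10] in jax.scipy's (loc, scale) form; -inf outside
        return stats.uniform.logpdf(x, loc=6.0, scale=4.0)

    # Both branches return a float32 scalar, so the pytree structures match.
    return jax.lax.cond(jnp.asarray(feeling_lazy) == 1,
                        lazy_branch, not_lazy_branch, x)
```

Inside the model, this value could be passed to a single `numpyro.factor("amount_slept_factor", ...)` call after sampling `amount_slept` from a wide base such as `Uniform(-100., 100.)`, so only one factor site is needed. The base support should cover every branch's support; otherwise the branch density is silently truncated.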
Unfortunate, but thank you.
Just curious: what is the issue that you got?
Closed. Please feel free to ask questions in our forum: forum.pyro.ai
I'm trying to implement this Pyro model in NumPyro, and I'm running into issues with the conditional on `feeling_lazy`. The guide I'm trying to follow is here.
Pyro model
I found the `contrib.control_flow.cond` functionality and have tried the following model; however, I'm running into an error because the two branches have different PyTree structures. Is there an appropriate way to implement this model in NumPyro, or is this something still in development? Thanks.
NumPyro model
Error message