pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Correct control_flow.cond usage #1763

Closed · tims457 closed this issue 4 months ago

tims457 commented 6 months ago

I'm trying to implement this Pyro model in NumPyro, and I'm running into issues with the conditional on feeling_lazy. The guide I'm trying to follow is here.

Pyro model

import pyro.distributions as dist
from pyro import sample

def sleep_model():
    # Very likely to feel lazy
    feeling_lazy = sample("feeling_lazy", dist.Bernoulli(0.9))
    if feeling_lazy:
        # Only going to (possibly) ignore my alarm if I'm feeling lazy
        ignore_alarm = sample("ignore_alarm", dist.Bernoulli(0.8))
        # Will sleep more if I ignore my alarm
        amount_slept = sample("amount_slept",
                              dist.Normal(8 + 2 * ignore_alarm, 1))
    else:
        amount_slept = sample("amount_slept", dist.Normal(6, 1))
    return amount_slept

I found the contrib.control_flow.cond functionality and tried the following model; however, I'm running into an error because the two branches produce different PyTree structures. Is there an appropriate way to implement this model in NumPyro, or is this something still in development?

Thanks.

NumPyro model

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import cond

def sleep_model(data):
    feeling_lazy = numpyro.sample("feeling_lazy", dist.Bernoulli(probs=0.9))

    def lazy_branch():
        ignore_alarm = numpyro.sample("ignore_alarm", dist.Bernoulli(probs=0.8))
        amount_slept = numpyro.sample("amount_slept", dist.Normal(loc=8 + 2 * ignore_alarm, scale=1))
        return amount_slept

    def not_lazy_branch():
        return numpyro.sample("amount_slept", dist.Normal(loc=6, scale=1))

    amount_slept = cond(feeling_lazy == 1,
                        lambda x: lazy_branch(),
                        lambda x: not_lazy_branch(),
                        data)
    return amount_slept

# with jax.checking_leaks():
data = jnp.ones(10)
numpyro.render_model(sleep_model, model_args=(data,))

Error message

    "name": "TypeError",
    "message": "true_fun and false_fun output must have same type structure, got PyTreeDef((*, CustomNode(PytreeTrace[({'ignore_alarm': {'_control_flow_done': True, 'type': 'sample', 'name': 'ignore_alarm', 'kwargs': {'rng_key': None, 'sample_shape': ()}, 'scale': None, 'is_observed': False, 'cond_indep_stack': [], 'infer': {}}, 'amount_slept': {'_control_flow_done': True, 'type': 'sample', 'name': 'amount_slept', 'kwargs': {'rng_key': None, 'sample_shape': ()}, 'scale': None, 'is_observed': False, 'cond_indep_stack': [], 'infer': {}}}, ['ignore_alarm', 'amount_slept'])], [{'amount_slept': {'args': (), 'fn': CustomNode(Normal[((), None, ())], [*, *]), 'intermediates': [], 'value': *}, 'ignore_alarm': {'args': (), 'fn': CustomNode(BernoulliProbs[((), None, ())], [*]), 'intermediates': [], 'value': *}}]))) and PyTreeDef((*, CustomNode(PytreeTrace[({'amount_slept': {'_control_flow_done': True, 'type': 'sample', 'name': 'amount_slept', 'kwargs': {'rng_key': None, 'sample_shape': ()}, 'scale': None, 'is_observed': False, 'cond_indep_stack': [], 'infer': {}}}, ['amount_slept'])], [{'amount_slept': {'args': (), 'fn': CustomNode(Normal[((), None, ())], [*, *]), 'intermediates': [], 'value': *}}]))).",
fehiepsi commented 6 months ago

I think you can move ignore_alarm outside of the lazy branch so that the two branches have the same latent variables.
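For concreteness, a minimal sketch of that suggestion (assembled here rather than taken from the thread, reusing the imports from the model above) might look like the following: ignore_alarm is sampled unconditionally, so both branches trace the same sites and both use a Normal for "amount_slept".

def sleep_model(data):
    feeling_lazy = numpyro.sample("feeling_lazy", dist.Bernoulli(probs=0.9))
    # Sampled unconditionally, so it appears in the trace regardless of the branch taken.
    ignore_alarm = numpyro.sample("ignore_alarm", dist.Bernoulli(probs=0.8))

    def lazy_branch(_):
        return numpyro.sample("amount_slept", dist.Normal(loc=8 + 2 * ignore_alarm, scale=1))

    def not_lazy_branch(_):
        return numpyro.sample("amount_slept", dist.Normal(loc=6, scale=1))

    # Both branches now produce the same pytree structure.
    return cond(feeling_lazy == 1, lazy_branch, not_lazy_branch, data)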

tims457 commented 6 months ago

That works in this particular case, since both branches sample from a normal distribution, but not in general. For example, what if one branch has a normal distribution and the other has a uniform distribution? The model below throws the same error, even though the equivalent is possible in Pyro (shown further below).

def sleep_model(data):
    feeling_lazy = numpyro.sample("feeling_lazy", dist.Bernoulli(probs=0.9))
    ignore_alarm = numpyro.sample("ignore_alarm", dist.Bernoulli(probs=0.8))

    def lazy_branch():
        amount_slept = numpyro.sample("amount_slept", dist.Normal(loc=8 + 2 * ignore_alarm, scale=1))
        return amount_slept

    def not_lazy_branch():
        # amount_slept = numpyro.sample("amount_slept", dist.Normal(loc=6, scale=1))  # this works
        amount_slept = numpyro.sample("amount_slept", dist.Uniform(low=6., high=10.))  # this doesn't
        return amount_slept

    amount_slept = cond(feeling_lazy == 1,
                        lambda x: lazy_branch(),
                        lambda x: not_lazy_branch(),
                        data)
    return amount_slept

This works using Pyro.

def sleep_model():
    feeling_lazy = pyro.sample("feeling_lazy", dist.Bernoulli(0.9))

    if feeling_lazy:
        ignore_alarm = pyro.sample("ignore_alarm", dist.Bernoulli(0.8))
        amount_slept = pyro.sample("amount_slept",
                                   dist.Normal(8 + 2 * ignore_alarm, 1))
    else:
        amount_slept = pyro.sample("amount_slept", dist.Uniform(1, 6))

    return amount_slept
fehiepsi commented 6 months ago

Yeah, it is a limitation of cond: JAX requires the two branches to have the same pytree structure. I don't have a good solution for it. Depending on the algorithm, something like this (suitable for MCMC) might be required:

def lazy_branch():
    amount_slept = numpyro.sample("amount_slept", dist.Uniform(-100., 100.))
    numpyro.factor("amount_slept_factor",
                   dist.Normal(loc=8 + 2 * ignore_alarm, scale=1).log_prob(amount_slept))
    return amount_slept

Similarly for the other branch:

numpyro.factor("amount_slept_factor", 0.)
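Putting the pieces together, a rough end-to-end sketch of this workaround (assembled here, not posted in the thread) might look like the following. The Uniform(-100., 100.) bounds come from the snippet above; the zero factor in the other branch mirrors the factor site of the lazy branch so both branches trace the same site names, and jnp.zeros(()) is used in place of the plain 0. to keep the branch outputs' array types aligned. As noted above, this is intended for MCMC-style inference.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import cond

def sleep_model(data):
    feeling_lazy = numpyro.sample("feeling_lazy", dist.Bernoulli(probs=0.9))
    ignore_alarm = numpyro.sample("ignore_alarm", dist.Bernoulli(probs=0.8))

    def lazy_branch(_):
        # Sample from a wide Uniform so this site has the same distribution type
        # in both branches, then add the Normal log-density as a factor.
        amount_slept = numpyro.sample("amount_slept", dist.Uniform(-100., 100.))
        numpyro.factor("amount_slept_factor",
                       dist.Normal(loc=8 + 2 * ignore_alarm, scale=1).log_prob(amount_slept))
        return amount_slept

    def not_lazy_branch(_):
        # The intended Uniform is used directly; the zero factor keeps the set of
        # site names identical across the two branches.
        amount_slept = numpyro.sample("amount_slept", dist.Uniform(low=6., high=10.))
        numpyro.factor("amount_slept_factor", jnp.zeros(()))
        return amount_slept

    return cond(feeling_lazy == 1, lazy_branch, not_lazy_branch, data)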
tims457 commented 6 months ago

Unfortunate, but thank you.

fehiepsi commented 6 months ago

Just curious, what is the issue that you ran into?

fehiepsi commented 4 months ago

Closed. Please feel free to ask questions in our forum: forum.pyro.ai