pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Make poutine.condition commute with subsample #3067

Open gioelelm opened 2 years ago

gioelelm commented 2 years ago

Pyro seems to encourage the use of effect handlers, and indeed I generally prefer poutine.condition over the obs keyword argument of pyro.sample.

There are many typical use cases where this pays off. For example, the effect handler makes it easy to quickly turn a generative model into a conditioned model for inference, or to compare a model where a variable is assumed known against one where it is latent.
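For concreteness, a minimal sketch of the two equivalent conditioning styles on a toy model (model, model_obs and model_cond are just illustrative names):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

def model():
    mu = pyro.sample("mu", dist.Normal(0., 1.))
    return pyro.sample("x", dist.Normal(mu, 1.))

x_obs = torch.tensor(2.5)

# Style 1: condition inline with the obs keyword
def model_obs():
    mu = pyro.sample("mu", dist.Normal(0., 1.))
    return pyro.sample("x", dist.Normal(mu, 1.), obs=x_obs)

# Style 2: leave the generative model untouched and condition it with an effect handler
model_cond = poutine.condition(model, data={"x": x_obs})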

However, poutine.condition does not seem to commute with subsampling (e.g. training with mini-batches). This means that a user who wants mini-batch training is forced to switch style and use obs instead.

To clarify what I mean, consider the code below: it does not produce a valid model that can be used for inference.

import pyro
from pyro import poutine

def generative_model():
    feature_plate = pyro.plate("feature_plate", size=90, dim=-1)
    sample_plate = pyro.plate("sample_plate", size=1000, subsample_size=20, dim=-2)
    mu = pyro.sample("mu", pyro.distributions.Normal(0, 1))
    with feature_plate:
        with sample_plate:
            X = pyro.sample("X", pyro.distributions.Normal(mu, 1))

Xdata = pyro.distributions.Normal(0, 1).sample((1000, 90))
# The full (1000, 90) tensor is attached to the "X" site, whose subsampled
# batch shape is (20, 90), so the conditioned model is broken.
inference_model = poutine.condition(generative_model, data={"X": Xdata})

Is there a way to use the effect handler logic to condition a model (i.e. poutine.condition) and still allow mini-batch training (automatic or custom)?

One way of doing it that would be consistent with Pyro's effect handler composability could be:

inference_model = poutine.condition(generative_model, data={"X": pyro.subsample(Xdata, 0)})

That of course does not work, because the pyro.subsample statement is not inside a pyro.plate.
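As I understand it, pyro.subsample is meant to be called inside the plates it should slice along. For reference, a sketch of the example above rewritten with obs (conditioned_with_subsample is just an illustrative name; Xdata is the tensor defined earlier):

def conditioned_with_subsample(Xdata):
    feature_plate = pyro.plate("feature_plate", size=90, dim=-1)
    sample_plate = pyro.plate("sample_plate", size=1000, subsample_size=20, dim=-2)
    mu = pyro.sample("mu", pyro.distributions.Normal(0, 1))
    with feature_plate, sample_plate:
        # pyro.subsample slices Xdata along the dims of the enclosing plates:
        # dim=-2 is cut down to the 20-row mini-batch, dim=-1 (size 90) is kept whole
        X_batch = pyro.subsample(Xdata, event_dim=0)
        pyro.sample("X", pyro.distributions.Normal(mu, 1), obs=X_batch)

This works, but it forces the observation logic back into the generative model, which is exactly what the effect handler style is meant to avoid.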

I think this would be a feature worth adding for consistency with the effect handler framework.

(from this forum post)

eb8680 commented 2 years ago

@gioelelm good observation! I think we could address this by applying pyro.subsample to observed values inside of ConditionMessenger._pyro_sample and pyro.sample (before any handlers are applied).
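To make that concrete, a rough sketch of what such a handler could look like (SubsampleConditionMessenger and subsample_condition are made-up names, and this is not Pyro's actual implementation):

import pyro
from pyro.poutine.messenger import Messenger

class SubsampleConditionMessenger(Messenger):
    """Like poutine.condition, but each observed value is passed through
    pyro.subsample so it is sliced along the plates active at its site."""

    def __init__(self, data):
        super().__init__()
        self.data = data

    def _pyro_sample(self, msg):
        if msg["name"] in self.data:
            # The enclosing plates are still active here, so pyro.subsample
            # cuts the full tensor down to the current mini-batch.
            event_dim = msg["fn"].event_dim
            msg["value"] = pyro.subsample(self.data[msg["name"]], event_dim=event_dim)
            msg["is_observed"] = True

def subsample_condition(fn, data):
    return SubsampleConditionMessenger(data)(fn)

# Hypothetical usage with the generative_model and Xdata from above:
# inference_model = subsample_condition(generative_model, data={"X": Xdata})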

gioelelm commented 2 years ago

To add some extra perspective: if a latent variable has the same shape as the samples, it is also quite difficult to fit the model above correctly even with obs.

In particular, because of that latent variable I could not find a way to do it successfully using subsample_size, and I had to use subsample instead, paying attention to several non-obvious details.

Code example, with notes on the important parts, which I had to learn through a lot of trial and error:

import torch
import pyro
from pyro.infer import Trace_ELBO, SVI, Predictive
from pyro.infer.autoguide import AutoDiagonalNormal

def inference_model(data, ind=None): 
    feature_plate = pyro.plate("feature_plate", size=90, dim=-1)
    sample_plate = pyro.plate("sample_plate", size=1000, subsample=ind, dim=-2) # Note 1: I need to use subsample here, not subsample_size
    u = pyro.sample("u", pyro.distributions.Normal(0, 10))
    with sample_plate:
        mu = pyro.sample("mu", pyro.distributions.Normal(u, 1))
    with feature_plate:
        with sample_plate as ind:
            X = pyro.sample("X", pyro.distributions.Normal(mu[ind], 1), obs=data[find]) # Note2: I need to index the data AND the latent variable

# Create fake data
data = pyro.distributions.Normal(0, 1).sample((1000,90)) + torch.linspace(0,10, 1000)[:, None]

# Do inference
pyro.clear_param_store()
guide = AutoDiagonalNormal(inference_model)
optim = pyro.optim.Adam({"lr": 0.005, "betas": (0.95, 0.999)})
elbo = Trace_ELBO()
svi = SVI(inference_model, guide, optim, loss=elbo)
# Note 3: important step below!
# Compute the loss once on the full dataset with ind=None; this was the only
# way I found to get the trace/guide initialized with the full-size latent mu.
init_loss = svi.loss(inference_model, guide, data, ind=None)

for i in range(10000):
    loss = svi.step(data, torch.randint(0, 1000, (40,))) # Note 4: here I can feed random indices for mini-batch training

# verify the shape of the posterior
pps = Predictive(model=inference_model, guide=guide,
                 num_samples=30, return_sites=("mu",))(data)
pps["mu"].shape