gioelelm opened this issue 2 years ago
@gioelelm good observation! I think we could address this by applying pyro.subsample to observed values inside of ConditionMessenger._pyro_sample and pyro.sample (before any handlers are applied).
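In the meantime, a possible user-level workaround might be a small ConditionMessenger subclass that pushes the conditioned value through pyro.subsample before attaching it to the site. This is only an untested sketch: the class name and the event_dim=0 choice are assumptions, it covers only the poutine.condition path (not the obs= keyword), and it sidesteps the "before any handlers are applied" ordering point above.

```python
import pyro
from pyro.poutine.condition_messenger import ConditionMessenger


class SubsampleConditionMessenger(ConditionMessenger):
    """Sketch: like poutine.condition, but slices the conditioned value
    along any enclosing subsampled plates via pyro.subsample."""

    def _pyro_sample(self, msg):
        name = msg["name"]
        if name in self.data:
            # pyro.subsample indexes the full tensor along the dims of the
            # plates that enclose this sample statement at runtime.
            msg["value"] = pyro.subsample(self.data[name], event_dim=0)
            msg["is_observed"] = True
        return None


# Hypothetical usage: condition on the full data tensor once and let the
# plates' subsample indices pick out each mini-batch ("model" and "full_data"
# are placeholder names).
# conditioned_model = SubsampleConditionMessenger(data={"X": full_data})(model)
```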
To add some extra perspective: if a latent variable has the same shape as the samples, it is also quite difficult to fit the model above correctly, even with obs. In particular, because of that latent variable I could not find a way to do it successfully using subsample_size, and I had to use subsample while paying attention to several non-obvious details.
Code example, with notes on the important parts that I had to learn through much trial and error:
```python
import torch
import pyro
from pyro.infer import Trace_ELBO, SVI, Predictive
from pyro.infer.autoguide import AutoDiagonalNormal


def inference_model(data, ind=None):
    feature_plate = pyro.plate("feature_plate", size=90, dim=-1)
    sample_plate = pyro.plate("sample_plate", size=1000, subsample=ind, dim=-2)  # Note 1: I need to use subsample here
    u = pyro.sample("u", pyro.distributions.Normal(0, 10))
    with sample_plate:
        mu = pyro.sample("mu", pyro.distributions.Normal(u, 1))
    with feature_plate:
        with sample_plate as ind:
            X = pyro.sample("X", pyro.distributions.Normal(mu[ind], 1), obs=data[ind])  # Note 2: I need to index the data AND the latent variable


# Create fake data
data = pyro.distributions.Normal(0, 1).sample((1000, 90)) + torch.linspace(0, 10, 1000)[:, None]

# Do inference
pyro.clear_param_store()
guide = AutoDiagonalNormal(inference_model)
optim = pyro.optim.Adam({"lr": 0.005, "betas": (0.95, 0.999)})
elbo = Trace_ELBO()
svi = SVI(inference_model, guide, optim, loss=elbo)

# Note 3: Important step below!
# Here I compute the loss the first time for the entire data with ind=None.
# This is the only way I was able to get the trace to initialize correctly so
# that it considers the full latent variable mu.
init_loss = svi.loss(inference_model, guide, data, ind=None)

for i in range(10000):
    loss = svi.step(data, torch.randint(0, 1000, (40,)))  # Note 4: here I can feed random indices for mini-batch training

# Verify the shape of the posterior
pps = Predictive(model=inference_model, guide=guide,
                 num_samples=30, return_sites=("mu",))(data)
pps["mu"].shape
```
Pyro seems to encourage the use of effect handlers and, indeed, I generally prefer to use poutine.condition rather than the keyword argument obs in pyro.sample. There are many typical use cases where this pays off: for example, using the effect handler to quickly convert a generative model into a conditioned model for inference, or to compare a model where a variable is assumed known against one where it is latent.
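For concreteness, here is a minimal sketch of the two styles being compared (toy model, made-up names):

```python
import torch
import pyro
import pyro.poutine as poutine


def model():
    mu = pyro.sample("mu", pyro.distributions.Normal(0., 1.))
    return pyro.sample("x", pyro.distributions.Normal(mu, 1.))


x_obs = torch.tensor(0.5)

# Style 1: keep the generative model untouched and condition it with an effect handler.
conditioned_model = poutine.condition(model, data={"x": x_obs})


# Style 2: bake the observation into the model with the obs keyword.
def model_with_obs():
    mu = pyro.sample("mu", pyro.distributions.Normal(0., 1.))
    return pyro.sample("x", pyro.distributions.Normal(mu, 1.), obs=x_obs)
```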
However, poutine.condition does not seem to be compatible with (i.e. commute with) subsampling, e.g. for training with mini-batches. This means that if a user wants mini-batch training they are forced to switch style and use obs. To clarify what I mean, consider the code below; it does not return a valid model that can be used for inference.
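The snippet originally referenced here is missing from this extraction. A minimal sketch of the kind of setup being described (model structure, names, and shapes are illustrative, not the original code) might look like:

```python
import pyro
import pyro.poutine as poutine


def generative_model():
    with pyro.plate("sample_plate", size=1000, subsample_size=40, dim=-2):
        mu = pyro.sample("mu", pyro.distributions.Normal(0., 1.))
        with pyro.plate("feature_plate", size=90, dim=-1):
            return pyro.sample("X", pyro.distributions.Normal(mu, 1.))


data = pyro.distributions.Normal(0., 1.).sample((1000, 90))

# The plate subsamples the distribution at site "X" to batch shape (40, 90),
# but poutine.condition attaches the full (1000, 90) tensor as the site's
# value, so the value and its distribution no longer line up.
conditioned_model = poutine.condition(generative_model, data={"X": data})
```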
Is there a way to use the effect-handler logic to condition a model (i.e. poutine.condition) and still allow mini-batch training (automatic or custom)? I was thinking that a way of doing it consistent with Pyro's effect-handler composability would be:
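(The code that followed here is also missing from this extraction; presumably it was something along these lines, wrapping the conditioned value in pyro.subsample. This is an illustrative guess, not the original snippet.)

```python
# Hypothetical reconstruction: pyro.subsample is evaluated here, outside any
# plate context, when the conditioned model is built.
conditioned_model = poutine.condition(
    generative_model,
    data={"X": pyro.subsample(data, event_dim=0)},
)
```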
That of course does not work, because the pyro.subsample statement is not inside a pyro.plate. I think this would be a feature worth adding, for consistency with the effect handler framework.
(from this forum post)