Open yozhikoff opened 3 years ago
Hi @yozhikoff. I've also seen this come up and my usual solution is to add a poutine.mask around relevant sites in the model and add an include_prior kwarg to the model:
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam

def model(data, include_prior=True):
    with poutine.mask(mask=include_prior):
        # Priors (latent sample statements) appear inside of the mask.
        z = pyro.sample("z", dist.Normal(0, 1))
    # Likelihoods (i.e. observe statements) appear outside of the mask.
    pyro.sample("x", dist.Normal(z, 1), obs=data)

# Train with MAP.
guide = AutoDelta(model)
svi = SVI(model, guide, Adam({"lr": 1e-3}), Trace_ELBO())
for step in range(steps):
    svi.step(data)

# Train with MLE.
guide = AutoDelta(model)
svi = SVI(model, guide, Adam({"lr": 1e-3}), Trace_ELBO())
for step in range(steps):
    svi.step(data, include_prior=False)  # <----- this is the only change
This approach also allows fine-grained toggling of different parts of the prior, by separating into multiple masks. WDYT?
Mentioning this in the docs?
We're always happy to accept doc improvements :smile:
Adding a TraceMLE loss to pyro?
Hmm, this is a little complex because Pyro often uses is_observed for things other than observations: e.g. pyro.factor(), pyro.deterministic(), and poutine.reparam() are all treated as auxiliary observations. I don't see any immediate problems...
Would you strongly prefer TraceMLE over the poutine.mask solution?
Thanks, @fritzo!
poutine.mask seems to be a good alternative, but I can imagine cases where modifying the objective would be preferable (e.g. in complex modular packages, if one doesn't want to dig into low-level model code).
Hmm, this a little complex...
At least pyro.deterministic() shouldn't be a problem with its log_prob=0, not sure about the rest.
Anyway, MLE is probably a relatively rare use case for a PPL framework, so you are right here and a docs entry should be enough.
We're always happy to accept doc improvements
I think I'll write a short note about using AutoDelta in general for MLE/MAP, covering both the poutine.mask and TraceMLE solutions, then.
Hi Pyro team!
The current tutorial for MLE suggests writing a separate model with pyro.param in its body, which is probably not the most convenient option if one has a large model and wants to check MLE performance. One alternative approach is to use the AutoDelta autoguide with an objective that takes into account only the is_observed parts of the trace, like this
What are your thoughts on:
Mentioning this in the docs?
Adding a TraceMLE loss to pyro?
I can make a PR in both cases.