pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Custom SVI objective for MLE estimation [feature suggestion] #2765

yozhikoff commented 3 years ago

Hi Pyro team!

The current tutorial for MLE estimation suggests writing a separate model with pyro.param in its body, which is probably not the most convenient option if one has a large model and wants to check MLE performance. An alternative approach is to use the AutoDelta autoguide with an objective that takes into account only the is_observed parts of the trace, like this:

from pyro import poutine

def trace_mle(model, guide, *args, **kwargs):
    # Run the guide, then replay the model against the guide's sampled values.
    guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
    model_trace = poutine.trace(
        poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)

    # Compute log-probabilities for observed sites only.
    model_trace.compute_log_prob(lambda name, node: node['is_observed'])

    # Accumulate the negative log-likelihood over observed sites.
    loss = 0.0
    for name, node in model_trace.nodes.items():
        if node.get('is_observed'):
            loss += -node['log_prob_sum']

    return loss
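
For reference, such a loss can be passed straight to SVI, which also accepts a plain callable with signature loss(model, guide, *args, **kwargs). A minimal sketch with a toy model (the model and data here are illustrative):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam

def toy_model(data):
    # Latent location; trace_mle will ignore this prior's log-probability.
    z = pyro.sample("z", dist.Normal(0., 1.))
    with pyro.plate("data", len(data)):
        pyro.sample("x", dist.Normal(z, 1.), obs=data)

data = torch.randn(100) + 3.0
guide = AutoDelta(toy_model)
svi = SVI(toy_model, guide, Adam({"lr": 0.01}), loss=trace_mle)
for step in range(1000):
    svi.step(data)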

What are your thoughts on:

  1. Mentioning this in the docs?
  2. Adding a TraceMLE loss to pyro?

I can make a PR in either case.

fritzo commented 3 years ago

Hi @yozhikoff. I've also seen this come up; my usual solution is to wrap the relevant sites in a poutine.mask and add an include_prior kwarg to the model:

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam

def model(data, include_prior=True):
    with poutine.mask(mask=include_prior):
        # Priors (latent sample statements) appear inside of the mask.
        z = pyro.sample("z", dist.Normal(0, 1))
    # Likelihoods (i.e. observe statements) appear outside of the mask.
    pyro.sample("x", dist.Normal(z, 1), obs=data)

# Train with MAP.
guide = AutoDelta(model)
svi = SVI(model, guide, Adam({"lr": 1e-3}), Trace_ELBO())
for step in range(steps):
    svi.step(data)

# Train with MLE.
guide = AutoDelta(model)
svi = SVI(model, guide, Adam({"lr": 1e-3}), Trace_ELBO())
for step in range(steps):
    svi.step(data, include_prior=False)  # <----- this is the only change

This approach also allows fine-grained toggling of different parts of the prior by splitting it across multiple masks, as in the sketch below. WDYT?
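
For instance, a hypothetical sketch of the multi-mask variant, continuing with the imports above (the kwarg and site names are illustrative):

def model(data, include_scale_prior=True, include_loc_prior=True):
    # Each prior block sits under its own mask, so it can be toggled independently.
    with poutine.mask(mask=include_scale_prior):
        scale = pyro.sample("scale", dist.LogNormal(0., 1.))
    with poutine.mask(mask=include_loc_prior):
        loc = pyro.sample("loc", dist.Normal(0., 10.))
    pyro.sample("x", dist.Normal(loc, scale), obs=data)

# e.g. MLE for loc only, keeping the prior on scale:
# svi.step(data, include_loc_prior=False)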

> Mentioning this in the docs?

We're always happy to accept doc improvements :smile:

> Adding a TraceMLE loss to pyro?

Hmm, this is a little complex, because Pyro often uses is_observed for things other than observations: pyro.factor(), pyro.deterministic(), and poutine.reparam() sites are all treated as auxiliary observations. That said, I don't see any immediate problems...

Would you strongly prefer TraceMLE over the poutine.mask solution?

yozhikoff commented 3 years ago

Thanks, @fritzo! poutine.mask seems like a good alternative, but I can imagine cases where modifying the objective would be preferable (e.g. in complex modular packages, if one doesn't want to dig into low-level model code).

> Hmm, this is a little complex...

At least pyro.deterministic() shouldn't be a problem, since its log_prob is zero; I'm not sure about the rest. Anyway, MLE estimation is probably a relatively rare use case for a PPL framework, so you're right here and a docs entry should be enough.
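
For what it's worth, that's easy to check on a trace; a small sketch, assuming pyro.deterministic records a masked Delta site (as in recent Pyro versions):

import pyro
import pyro.distributions as dist
from pyro import poutine

def det_model():
    z = pyro.sample("z", dist.Normal(0., 1.))
    pyro.deterministic("z2", z * 2)  # recorded as an observed Delta site with mask=False

tr = poutine.trace(det_model).get_trace()
tr.compute_log_prob()
print(tr.nodes["z2"]["is_observed"])   # True -- picked up by trace_mle...
print(tr.nodes["z2"]["log_prob_sum"])  # tensor(0.) -- ...but contributes nothing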

> We're always happy to accept doc improvements

I think I'll add a short note about using AutoDelta for MLE/MAP in general, covering both the poutine.mask and TraceMLE solutions, then.