pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

FR Return deterministics in AutoGuide.quantiles() by default #2848

Open fonnesbeck opened 3 years ago

fonnesbeck commented 3 years ago

Currently, sites wrapped in deterministic are not traced by AutoGuide.quantiles(), so obtaining inferences for them requires an extra run through a Predictive object. When a user deliberately creates these deterministic sites, they are signalling that their values are of central interest, so it would be convenient to provide them immediately rather than forcing users to take additional steps.

fritzo commented 3 years ago

Hi @fonnesbeck, can you provide more details about the inference you are performing? It sounds like you might be using MCMC?

Assuming you are using MCMC: One issue with saving all-information-by-default is the memory overhead. Currently MCMC saves exactly the minimal information needed to reproduce traces, and even that is becoming too large for memory in many cases (see #2843). I'm curious, do you need the full set of samples at your deterministic sites, or would some moments of samples suffice, e.g. mean and variance or covariance?

fonnesbeck commented 3 years ago

Sorry for the lack of detail: I am using SVI, not NUTS/MCMC, so in my case moments would do, but maybe not in all cases. I agree that you don't want to save all intermediate calculations in the trace, but those deliberately wrapped in deterministic are a different story; they are often more important to the user than many of the stochastics. So, not all the information, just the information the user has indicated is relevant.

fritzo commented 3 years ago

@fonnesbeck thanks, interesting, I wasn't aware that SVI was discarding deterministic sites. Could you provide a little code snippet to show what is going wrong? Is it specifically when using AutoGuide.median() or AutoGuide.quantiles()? I'm a little confused, since the generic SVI workflow doesn't return any traces.

fonnesbeck commented 3 years ago

I will try to put together a toy example when I get a chance, but in my particular case I am using the quantiles method on AutoDiagonalNormal. Its output does not include any deterministic sites.

fritzo commented 3 years ago

Thanks @fonnesbeck, that makes sense. No need for an example.

I think we can support deterministic sites via a new .predict_quantiles() method.

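import torch
import pyro
from pyro import poutine
from pyro.nn import PyroModule
from pyro.poutine import condition
from pyro.poutine.util import site_is_subsample

class AutoGuide(PyroModule):
    ...
    @torch.no_grad()
    def predict_quantiles(self, quantiles, *args, **kwargs):
        # Guide-side quantiles of every latent site.
        data = self.quantiles(quantiles, *args, **kwargs)
        vectorize = pyro.plate(...)  # as in .quantiles() method
        # Run the model once, vectorized over quantiles and conditioned on them.
        model = condition(vectorize(self.model), data)
        trace = poutine.trace(model).get_trace(*args, **kwargs)
        # Deterministic sites are recorded as "sample" sites, so they are kept;
        # plate subsample sites are skipped.
        return {
            name: site["value"]
            for name, site in trace.nodes.items()
            if site["type"] == "sample"
            if not site_is_subsample(site)
        }

    def predict_median(...):
        # ...similar...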

Note there would be a couple of difficulties with inspecting deterministic sites by default:

  1. Deterministic sites require running not only the guide but also the model. While it is easy to vectorize the guide, not all models support vectorization. I think the default should make minimal assumptions about the model.
  2. Some models have enormous observation statements. In those cases, computing simple guide quantiles of latent variables can be much cheaper than running the model to generate additional deterministic values. Again, I think the default behavior should be the cheap, minimal guide-side computation.

For these reasons, I think this issue deserves a new, separate interface.