Open fonnesbeck opened 3 years ago
Hi @fonnesbeck, can you provide more details about the inference you are performing? It sounds like you might be using MCMC?
Assuming you are using MCMC: One issue with saving all-information-by-default is the memory overhead. Currently MCMC saves exactly the minimal information needed to reproduce traces, and even that is becoming too large for memory in many cases (see #2843). I'm curious, do you need the full set of samples at your deterministic sites, or would some moments of samples suffice, e.g. mean and variance or covariance?
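To make the "moments instead of full samples" idea concrete, here is a minimal sketch of a streaming mean/variance accumulator (Welford's algorithm). It is not part of Pyro; the class name `RunningMoments` and its methods are hypothetical and only illustrate how a site's moments could be tracked in constant memory rather than storing every sample.

```python
class RunningMoments:
    """Hypothetical helper: tracks mean and variance without storing samples."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self._m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, x):
        # Welford's update: numerically stable, one sample at a time.
        self.count += 1
        delta = x - self.mean
        self.mean = self.mean + delta / self.count
        self._m2 = self._m2 + delta * (x - self.mean)

    @property
    def variance(self):
        # Unbiased sample variance; undefined for fewer than two samples.
        if self.count < 2:
            return float("nan")
        return self._m2 / (self.count - 1)


# Usage: feed in the value of one deterministic site per posterior draw.
moments = RunningMoments()
for value in [1.0, 2.0, 3.0, 4.0]:
    moments.update(value)
print(moments.mean, moments.variance)
```

Because the update uses only arithmetic operators, the same sketch should also work elementwise if `x` is a tensor, although that is an assumption and not something tested here.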
Sorry for the lack of detail: I am using SVI, not NUTS/MCMC, so in my case moments would do, but maybe not in all cases. I agree that you don't want to save all intermediate calculations in the trace, but those deliberately wrapped in deterministic are a different story; they are often more important to the user than many of the stochastics. So, not all the information, just the information the user has indicated is relevant.
@fonnesbeck thanks, interesting, I wasn't aware that SVI was discarding deterministic sites. Could you provide a little code snippet to show what is going wrong? Is it specifically when using AutoGuide.median() or AutoGuide.quantiles()? I'm a little confused since the generic SVI workflow doesn't return any traces.
I will try and put together a toy example when I get a chance, but I am using the quantiles method on AutoDiagonalNormal in my particular case. This does not include any deterministics.
Thanks @fonnesbeck, that makes sense. No need for an example.
I think we can support deterministic sites via a new .predict_quantiles() method.
    class AutoGuide(PyroModule):
        ...

        @torch.no_grad()
        def predict_quantiles(self, quantiles, *args, **kwargs):
            data = self.quantiles(quantiles, *args, **kwargs)
            vectorize = pyro.plate(...)  # as in .quantiles() method
            model = condition(vectorize(self.model), data)
            trace = poutine.trace(model).get_trace(*args, **kwargs)
            return {
                name: site["value"]
                for name, site in trace.nodes.items()
                if site["type"] == "sample"
                if not site_is_subsample(site)
            }

        def predict_median(...):
            # ...similar...
Note there would be a couple of difficulties inspecting deterministic sites by default:
For these reasons I think this issue deserves a new separate interface.
Currently, nodes that are wrapped in deterministic are not traced, and obtaining inferences from these nodes requires running the model through a Predictive object. When a user deliberately creates these deterministics they are signalling that their values are of central interest, so it would be convenient to provide them immediately and not force users to execute additional steps.