pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Should poutine.block block @broadcast and @independent? #2290

Open fritzo opened 4 years ago

fritzo commented 4 years ago

Nested inference (#1227) requires wrapping the nested inference in poutine.block, but these nested inference algorithms should arguably respect the outermost poutine.broadcast and possibly poutine.independent. (Note that the non-broadcasted TMC versions are also interesting, but they are neither the only nor the simplest way to nest inference.)

ConjugateReparam (#2289) needs to sample an internal variable in a broadcast-aware way, but the math in SVI and HMC appears to be correct when this internal variable is removed from the trace. This holds at least for reparameterized variables; the non-reparameterized case is more subtle and may require pyro.factor with DiCE factors (TBD). I believe the _do_not_trace hack could be removed if poutine.block did not block broadcast and independent. Indeed, ConjugateReparam can be seen as nesting exact inference inside SVI, similar to #1227.

poutine.do #2157 now uses site["stop"] = True (like poutine.block), thereby blocking broadcasting of the intervened site. Is the current implementation correct in the presence of multiple-sample inference algorithms, e.g. where the intervention depends on an upstream sampled variable?

Perhaps poutine.block should merely block tracing? Are there other cases for a more limited poutine.block_tracing?
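To make the distinction concrete, here is a toy sketch in plain Python. It is not Pyro's implementation: the handler stack, `ToyTrace`, `ToyBroadcast`, `ToyBlock`, `ToyBlockTracing`, and the `do_not_trace` and `batch_size` message fields are all hypothetical names chosen for illustration. It only assumes the general idea that handlers see a message innermost first and that a `stop` flag hides the message from every outer handler.

```python
STACK = []  # active toy handlers; the last entry is the innermost one


class Handler:
    """Base class for toy handlers (hypothetical, not a Pyro class)."""

    def __enter__(self):
        STACK.append(self)
        return self

    def __exit__(self, *exc):
        STACK.pop()

    def process(self, msg):
        pass

    def postprocess(self, msg):
        pass


class ToyTrace(Handler):
    """Records every site it sees unless the site is marked do_not_trace."""

    def __init__(self):
        self.sites = {}

    def postprocess(self, msg):
        if not msg["do_not_trace"]:
            self.sites[msg["name"]] = msg["value"]


class ToyBroadcast(Handler):
    """Stand-in for a broadcast-aware handler: requests `size` draws."""

    def __init__(self, size):
        self.size = size

    def process(self, msg):
        msg["batch_size"] = self.size


class ToyBlock(Handler):
    """Hides a site from ALL outer handlers by setting stop."""

    def __init__(self, hide):
        self.hide = set(hide)

    def process(self, msg):
        if msg["name"] in self.hide:
            msg["stop"] = True


class ToyBlockTracing(ToyBlock):
    """Hypothetical narrower variant: hides a site from tracing only."""

    def process(self, msg):
        if msg["name"] in self.hide:
            msg["do_not_trace"] = True


def sample(name, fn):
    msg = {"name": name, "value": None, "stop": False,
           "do_not_trace": False, "batch_size": 1}
    seen = 0
    for handler in reversed(STACK):  # innermost handler first
        seen += 1
        handler.process(msg)
        if msg["stop"]:  # outer handlers never see this message
            break
    msg["value"] = [fn() for _ in range(msg["batch_size"])]
    for handler in STACK[-seen:]:  # only handlers that processed the message
        handler.postprocess(msg)
    return msg["value"]


def run(blocker):
    # Outermost to innermost: trace, broadcast, then the blocker under test.
    with ToyTrace() as trace, ToyBroadcast(3), blocker:
        value = sample("aux", lambda: 0.0)
    return value, trace.sites


blocked_value, blocked_sites = run(ToyBlock(hide=["aux"]))
untraced_value, untraced_sites = run(ToyBlockTracing(hide=["aux"]))
```

In this toy, both variants keep `aux` out of the trace, but only the `stop`-based block also prevents the outer broadcast handler from expanding the site.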

eb8680 commented 4 years ago

> Perhaps poutine.block should merely block tracing? Are there other cases for a more limited poutine.block_tracing?

I don't see any reason to change the behavior of block right now, since it's used for things that aren't just nesting inference, but I guess we could add a separate block_tracing that has the behavior you're looking for. The special case of block_tracing added in #2289 is specific to auxiliary variables introduced by inference algorithms, which are not expected to be observed or replayed, but a generic block_tracing would interact with replay, condition, do and other effect handlers. Since those issues are rather abstract, I would prefer to start from a more general or unified design for nested inference and add or change effect handler behavior to support that.

> poutine.do #2157 now uses site["stop"] = True (like poutine.block), thereby blocking broadcasting of the intervened site. Is the current implementation correct in the presence of multiple-sample inference algorithms, e.g. where the intervention depends on an upstream sampled variable?

poutine.do currently only supports setting the site's value to a constant, which doesn't need to be broadcast, and it adds an auxiliary site that behaves in the usual broadcasting-aware way. As with the nested inference case, I would prefer to start from a more general design for causal and counterfactual inference under a larger class of interventions and add or change behavior to support that. It's certainly true that some changes will be necessary: for example, as you say, the current use of stop would be incorrect for stochastic interventions.