pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Make Predictive work with the SplitReparam reparameterizer [bugfix] #3388

Closed BenZickel closed 3 months ago

BenZickel commented 3 months ago

Problem

Using pyro.infer.Predictive with a model that uses the pyro.infer.reparam.SplitReparam reparameterizer raises an error, because pyro.infer.Predictive samples from the model to determine site shapes.

A model that uses the pyro.infer.reparam.SplitReparam reparameterizer cannot be sampled directly, because the reparameterizer introduces sites with the pyro.distributions.ImproperUniform distribution, which does not support sampling.

Solution

Wrap the model with the pyro.poutine.InitMessenger effect handler while site shapes are being determined. This solves the problem because the pyro.poutine.InitMessenger effect handler assigns values to the pyro.distributions.ImproperUniform sites before they are sampled.

The fix relies on a specific feature of pyro.poutine.ReparamMessenger, which applies the initialization provided by pyro.poutine.InitMessenger before sampling, even if pyro.poutine.InitMessenger appears last in the messenger stack (see #2876).
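The mechanism can be sketched in plain Python. This is only a toy illustration, not Pyro's implementation: `ImproperUniformToy`, `InitHandlerToy`, `HANDLER_STACK`, `sample`, and `init_to_zero_if_improper` are hypothetical names invented for this sketch, and the special stack-ordering behavior of ReparamMessenger is not modeled. It shows why a site without a sampler fails when traced directly, and why assigning a value first avoids the failure.

```python
import random


class Normal:
    """A proper distribution: it can draw samples."""

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def sample(self):
        return random.gauss(self.loc, self.scale)


class ImproperUniformToy:
    """Hypothetical stand-in for an improper distribution: it has no sampler."""

    def sample(self):
        raise NotImplementedError("improper distribution does not support sampling")


# Active handlers (toy version of an effect-handler stack).
HANDLER_STACK = []


class InitHandlerToy:
    """Hypothetical handler that assigns values to sites before they are sampled."""

    def __init__(self, init_fn):
        self.init_fn = init_fn

    def __enter__(self):
        HANDLER_STACK.append(self)
        return self

    def __exit__(self, *exc):
        HANDLER_STACK.pop()
        return False

    def process(self, site):
        # Only fill in sites that do not have a value yet.
        if site["value"] is None:
            site["value"] = self.init_fn(site)


def sample(name, dist):
    """Toy sample statement: handlers run first, the default sampler runs last."""
    site = {"name": name, "dist": dist, "value": None}
    for handler in reversed(HANDLER_STACK):
        handler.process(site)
    if site["value"] is None:
        site["value"] = dist.sample()  # fails for improper distributions
    return site["value"]


def init_to_zero_if_improper(site):
    """Assumed init strategy: give improper sites the value 0.0, leave others alone."""
    if isinstance(site["dist"], ImproperUniformToy):
        return 0.0
    return None


def model():
    x = sample("x", Normal(0.0, 1.0))
    y = sample("y_split", ImproperUniformToy())
    return x, y


# Without the init handler, running the model fails at the improper site.
try:
    model()
    failed = False
except NotImplementedError:
    failed = True

# With the init handler, the improper site already has a value when it is reached.
with InitHandlerToy(init_to_zero_if_improper):
    x, y = model()
```

In this sketch the handler only touches sites that cannot be sampled, so proper sites such as `x` are still drawn from their own distribution.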

Testing

The fix can be verified by running:

pytest tests/infer/reparam/test_split.py::test_predictive