pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Support for transformed distributions, based on stacking or concatenation transforms, in SplitReparam #3390

Closed: BenZickel closed this issue 3 months ago

BenZickel commented 3 months ago

The Problem

The SplitReparam reparameterization does not support transformed distributions that are based on stacking or concatenation transforms (StackTransform or CatTransform).
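For context, a StackTransform applies its i-th component transform to the i-th slice along `dim`, so it only accepts tensors whose size along `dim` equals the number of components. The sketch below uses plain PyTorch transforms (ExpTransform, SigmoidTransform, and an empty ComposeTransform as the identity), chosen here only as stand-ins for the transforms in the example that follows, and mimics a [2, 1] split along dim=-1 with `Tensor.split`:

```python
import torch
from torch.distributions import transforms as T

# Stand-in component transforms, one per slice along the last dimension.
stack = T.StackTransform(
    [T.ExpTransform(), T.SigmoidTransform(), T.ComposeTransform([])], dim=-1
)

x = torch.randn(10, 3)
y = stack(x)

# Equivalent manual computation: transform each slice, then stack them back.
manual = torch.stack(
    [x[..., 0].exp(), torch.sigmoid(x[..., 1]), x[..., 2]], dim=-1
)
assert torch.allclose(y, manual)

# Splitting the last dimension into sections [2, 1] yields slices of size 2
# and 1, neither of which matches the 3 components the stack transform expects.
y_a, y_b = y.split([2, 1], dim=-1)
print(y_a.shape, y_b.shape)  # torch.Size([10, 2]) torch.Size([10, 1])
```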

Consider the following code:

import pyro
import pyro.distributions as dist
from pyro.infer.reparam import SplitReparam

import torch

batch_shape = (6, 5)
num_samples = 10

# Apply a different transform to each of the 3 slices along the last dimension.
transform = dist.transforms.StackTransform([
        dist.transforms.OrderedTransform(),
        dist.transforms.DiscreteCosineTransform(),
        dist.transforms.HaarTransform()], dim=-1)

num_transforms = len(transform.transforms)

def model():
    scale_tril = pyro.sample("scale_tril", dist.LKJCholesky(num_transforms, 1))
    with pyro.plate_stack("plates", batch_shape):
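        # Event shape is (num_samples, num_transforms) = (10, 3); the stack
        # transform acts along the last event dimension.
        x_dist = dist.TransformedDistribution(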
            dist.MultivariateNormal(
                torch.zeros(num_samples, num_transforms), scale_tril=scale_tril
            ).to_event(1),
            [transform])
        return pyro.sample("x", x_dist)

# Split the last dimension of "x" into sections of sizes 2 and 1.
split_model = pyro.poutine.reparam(model, config={"x": SplitReparam([2, 1], -1)})

pyro.clear_param_store()
guide = pyro.infer.autoguide.AutoMultivariateNormal(split_model)
guide_sites = guide()  # fails while the guide sets up its prototype trace

Running it raises the following error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\SW\pyro-ppl\pyro\nn\module.py", line 527, in __call__
    result = super().__call__(*args, **kwargs)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl  
    return forward_call(*args, **kwargs)
  File "C:\SW\pyro-ppl\pyro\infer\autoguide\guides.py", line 759, in forward
    self._setup_prototype(*args, **kwargs)
  File "C:\SW\pyro-ppl\pyro\infer\autoguide\guides.py", line 875, in _setup_prototype
    super()._setup_prototype(*args, **kwargs)
  File "C:\SW\pyro-ppl\pyro\infer\autoguide\guides.py", line 644, in _setup_prototype
    biject_to(site["fn"].support).inv(site["value"]).shape
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\distributions\transforms.py", line 263, in __call__
    return self._inv._inv_call(x)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\distributions\transforms.py", line 170, in _inv_call
    return self._inverse(y)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\distributions\transforms.py", line 455, in _inverse
    return self.base_transform.inv(y)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\distributions\transforms.py", line 263, in __call__
    return self._inv._inv_call(x)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\distributions\transforms.py", line 170, in _inv_call
    return self._inverse(y)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\distributions\transforms.py", line 1172, in _inverse
    assert y.size(self.dim) == len(self.transforms)
AssertionError

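The error occurs because SplitReparam does not create the right support for the sites of the split reparameterization. As the traceback shows, AutoMultivariateNormal calls biject_to(site["fn"].support).inv(site["value"]) for each site. For a split site this support still describes all three stacked components, so the resulting StackTransform checks y.size(self.dim) == len(self.transforms) on a slice whose size along dim=-1 is only 2 (or 1), and the assertion fails.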
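The failure can be reproduced without Pyro by building a stacked support directly. This is a simplified sketch: it uses PyTorch's `constraints.stack` with stand-in component constraints (positive, unit_interval, real) instead of the supports from the example, and it omits the independent wrapper seen in the traceback, but it hits the same assertion in StackTransform._inverse:

```python
import torch
from torch.distributions import biject_to, constraints

# Stand-in support: three components stacked along the last dimension.
full_support = constraints.stack(
    [constraints.positive, constraints.unit_interval, constraints.real], dim=-1
)

# Values strictly inside every component's support.
value = torch.rand(10, 3).clamp(0.05, 0.95)
value_a, value_b = value.split([2, 1], dim=-1)

# The full, unsplit value maps to unconstrained space without problems.
unconstrained = biject_to(full_support).inv(value)
assert unconstrained.shape == (10, 3)

# Reusing the unsplit support for a split slice fails: the StackTransform
# built by biject_to expects 3 entries along dim=-1 but receives 2.
try:
    biject_to(full_support).inv(value_a)
    failed = False
except AssertionError:
    failed = True
print(failed)  # True
```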

The Solution

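Change the way SplitReparam computes the support of each slice of a transformed distribution that is based on stacking or concatenation transforms, so that each split site gets a support made of only the component constraints covering that slice (for the [2, 1] split above, the supports of the first two stacked components for the first site and the support of the third component for the second).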
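One possible approach, sketched under assumptions: `split_support` below is a hypothetical helper (not part of Pyro or PyTorch, and not necessarily how the actual fix is implemented) that slices the component constraints of a stack or cat support according to the split sections, unwrapping an independent constraint if present. It only handles splits whose boundaries line up with component boundaries and raises NotImplementedError otherwise. The component constraints in the usage part are again stand-ins:

```python
import torch
from torch.distributions import biject_to, constraints


def split_support(support, sections, dim):
    """Hypothetical helper: return one support per section of a split along dim.

    Handles stack/cat constraints (optionally wrapped in an independent
    constraint) whose own dim equals ``dim``; both are assumed to use the
    same sign convention, e.g. dim=-1. Other supports are reused as-is.
    """
    if isinstance(support, constraints.independent):
        inner = split_support(support.base_constraint, sections, dim)
        return [
            constraints.independent(c, support.reinterpreted_batch_ndims)
            for c in inner
        ]
    if isinstance(support, constraints.stack) and support.dim == dim:
        lengths = [1] * len(support.cseq)  # each stacked component has size 1
    elif isinstance(support, constraints.cat) and support.dim == dim:
        lengths = list(support.lengths)
    else:
        return [support] * len(sections)

    # Offsets of the component boundaries along the split dimension.
    offsets = [sum(lengths[:i]) for i in range(len(lengths) + 1)]
    result, start = [], 0
    for size in sections:
        end = start + size
        if start not in offsets or end not in offsets:
            raise NotImplementedError("split does not align with components")
        i, j = offsets.index(start), offsets.index(end)
        parts = support.cseq[i:j]
        if isinstance(support, constraints.stack):
            result.append(constraints.stack(parts, dim))
        else:
            result.append(constraints.cat(parts, dim, lengths[i:j]))
        start = end
    return result


# Usage: each split slice now round-trips through its own bijection.
full_support = constraints.independent(
    constraints.stack(
        [constraints.positive, constraints.unit_interval, constraints.real], dim=-1
    ),
    1,
)
value = torch.rand(10, 3).clamp(0.05, 0.95)
supports = split_support(full_support, [2, 1], dim=-1)
for support, part in zip(supports, value.split([2, 1], dim=-1)):
    bijection = biject_to(support)
    unconstrained = bijection.inv(part)
    assert torch.allclose(bijection(unconstrained), part, atol=1e-5)
```

Splits that cut through a single concatenated component would require slicing that component's own constraint, which this sketch deliberately does not attempt.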