Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\SW\pyro-ppl\pyro\nn\module.py", line 527, in __call__
    result = super().__call__(*args, **kwargs)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\SW\pyro-ppl\pyro\infer\autoguide\guides.py", line 759, in forward
    self._setup_prototype(*args, **kwargs)
  File "C:\SW\pyro-ppl\pyro\infer\autoguide\guides.py", line 875, in _setup_prototype
    super()._setup_prototype(*args, **kwargs)
  File "C:\SW\pyro-ppl\pyro\infer\autoguide\guides.py", line 644, in _setup_prototype
    biject_to(site["fn"].support).inv(site["value"]).shape
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\distributions\transforms.py", line 263, in __call__
    return self._inv._inv_call(x)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\distributions\transforms.py", line 170, in _inv_call
    return self._inverse(y)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\distributions\transforms.py", line 455, in _inverse
    return self.base_transform.inv(y)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\distributions\transforms.py", line 263, in __call__
    return self._inv._inv_call(x)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\distributions\transforms.py", line 170, in _inv_call
    return self._inverse(y)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\distributions\transforms.py", line 1172, in _inverse
    assert y.size(self.dim) == len(self.transforms)
AssertionError
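The failing check in the traceback above can be reproduced with PyTorch alone. The sketch below assumes a support built with `constraints.stack` from three element-wise constraints (chosen for illustration, not taken from the original model). Inverting a value that holds only some of the stacked entries hits the same `y.size(self.dim) == len(self.transforms)` assertion in `StackTransform._inverse`.

```python
import torch
from torch.distributions import biject_to, constraints

# Assumed example: a 3-vector whose entries have different constraints,
# stacked along the last dimension.
support = constraints.stack(
    [constraints.real, constraints.positive, constraints.unit_interval], dim=-1
)
transform = biject_to(support)  # a StackTransform with 3 component transforms

value = torch.tensor([0.0, 1.0, 0.5])
print(transform.inv(value))  # fine: all 3 stacked entries are present

# A split site that keeps the full support but only holds the first two
# entries fails the same check in StackTransform._inverse.
try:
    transform.inv(value[..., :2])
except AssertionError:
    print("AssertionError: 2 entries along dim=-1, StackTransform expects 3")
```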
The Problem
The SplitReparam reparameterization does not support transformed distributions that are based on stacking or concatenation transforms.
Consider the code
which raises the error shown in the traceback above.
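The error is due to SplitReparam not creating the right support for the sites of the split reparameterization. The support given to each split site still describes the whole, unsplit value, so biject_to builds a StackTransform (or CatTransform) that expects the full number of entries along the split dimension, while the site value only holds one section. The check y.size(self.dim) == len(self.transforms) therefore fails.
The Solution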
Change the way SplitReparam figures out the support of slices of transformed distributions that are based on stacking or concatenation transforms, so that each split site only receives the component constraints that cover its own slice.
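A minimal sketch of that idea, under stated assumptions: `slice_support` is a hypothetical helper, not Pyro's API and not necessarily the change made in this PR. It assumes the split dimension and the `dim` of any stack/cat constraint are both given as negative indices, and that any other constraint acts element-wise along the split dimension.

```python
import torch
from torch.distributions import biject_to, constraints


def slice_support(support, dim, start, size):
    """Constraint for ``value.narrow(dim, start, size)`` given ``support`` of ``value``.

    Hypothetical helper (not Pyro's API). ``dim`` and the ``dim`` of any
    stack/cat constraint are assumed to be negative so they compare directly.
    """
    if isinstance(support, constraints.independent):
        base = slice_support(support.base_constraint, dim, start, size)
        return constraints.independent(base, support.reinterpreted_batch_ndims)
    if isinstance(support, constraints.stack) and support.dim == dim:
        # One component constraint per index along ``dim``: keep the covered ones.
        return constraints.stack(support.cseq[start : start + size], dim=dim)
    if isinstance(support, constraints.cat) and support.dim == dim:
        # Component i covers lengths[i] consecutive indices along ``dim``;
        # keep the part of each component overlapping [start, start + size).
        cseq, lengths, offset = [], [], 0
        for c, length in zip(support.cseq, support.lengths):
            lo, hi = max(start, offset), min(start + size, offset + length)
            if hi > lo:
                cseq.append(c)
                lengths.append(hi - lo)
            offset += length
        return constraints.cat(cseq, dim=dim, lengths=lengths)
    # Anything else is assumed to act element-wise along ``dim``.
    return support


# Usage: split a stacked 3-vector into sections [1, 2] along dim=-1.
support = constraints.stack(
    [constraints.real, constraints.positive, constraints.unit_interval], dim=-1
)
value = torch.tensor([0.0, 1.0, 0.5])
parts, start = [], 0
for size in [1, 2]:
    piece = slice_support(support, -1, start, size)
    parts.append(biject_to(piece).inv(value.narrow(-1, start, size)))
    start += size
unconstrained = torch.cat(parts, dim=-1)
```

Inside SplitReparam, the idea would be to build each split site from `slice_support(fn.support, dim, start, size)` instead of the unsplit support; how that is wired into the reparameterizer is an assumption here. A cat component whose own constraint spans the split dimension (such as a simplex) cannot be narrowed this way, and the sketch does not handle that case.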