gui11aume opened 10 months ago
Hi @fritzo! I had a closer look at the issue and it's a little more complicated than I thought... Are there already some tests for the creation of parameters in auto guides?
I think that the same phenomenon happens for AutoLowRankMultivariateNormal.
```python
import pyro
import torch

def model():
    with pyro.plate("dummy", 20, subsample_size=3):
        pyro.sample("x", pyro.distributions.Normal(0, 1))

guide = pyro.infer.autoguide.AutoLowRankMultivariateNormal(model, rank=2)
elbo = pyro.infer.Trace_ELBO()
with pyro.poutine.trace(param_only=True) as param_capture:
    elbo.differentiable_loss(model, guide)

print(pyro.param("AutoLowRankMultivariateNormal.loc").shape)
# torch.Size([3])
print(pyro.param("AutoLowRankMultivariateNormal.scale").shape)
# torch.Size([3])
print(pyro.param("AutoLowRankMultivariateNormal.cov_factor").shape)
# torch.Size([3, 2])
```
The parameters should have 20 rows but they have 3.
Following the suggestion in the docs, we can initialize the parameters with pyro.param(...) before calling the guide, hoping to get the correct dimensions. However, this fails because Pyro expects the number of rows to be 3 (if you initialize the parameters with 3 rows, the code runs fine).
```python
import pyro
import torch

def model():
    with pyro.plate("dummy", 20, subsample_size=3):
        pyro.sample("x", pyro.distributions.Normal(0, 1))

pyro.param("AutoLowRankMultivariateNormal.loc", torch.zeros(20))
pyro.param("AutoLowRankMultivariateNormal.scale", torch.ones(20))
pyro.param("AutoLowRankMultivariateNormal.cov_factor", torch.ones(20, 2))

guide = pyro.infer.autoguide.AutoLowRankMultivariateNormal(model, rank=2)
elbo = pyro.infer.Trace_ELBO()
with pyro.poutine.trace(param_only=True) as param_capture:
    elbo.differentiable_loss(model, guide)
# ...
# AssertionError
```
AutoGuides + data subsampling requires using create_plates; see e.g. this test.
Thanks @martinjankowiak! I have tried creating the plates manually in different contexts, but I did not have any luck. Have a look at the example below: am I doing it wrong?
```python
import pyro
import torch

def model():
    with pyro.plate("dummy", 20, subsample_size=3):
        pyro.sample("x", pyro.distributions.Categorical(torch.ones(1)))

def create_plate_x():
    return pyro.plate("dummy", 20, subsample_size=3, dim=-1)

guide = pyro.infer.autoguide.AutoDiscreteParallel(model, create_plates=create_plate_x)
elbo = pyro.infer.TraceEnum_ELBO()
with pyro.poutine.trace(param_only=True) as param_capture:
    elbo.differentiable_loss(model, guide)

print(pyro.param("AutoDiscreteParallel.x_probs").shape)
# torch.Size([3, 1])
```
Thanks for the link to the test! It seems to run with AutoDelta and AutoNormal, but I never had problems with those; I think they work fine.
Issue Description
Auto guides need to create parameters in the background. The shape of those parameters is determined by the plates in the model. When plates are subsampled, the parameters should have the dimension of the full plate, not the subsampled plate. This is the case for some auto guides, but for AutoDiscreteParallel the shape of the parameters is wrong.
Code Snippet
The code below shows the difference in behavior between AutoNormal and AutoDiscreteParallel. In both cases, the model creates a plate of size 20 and subsamples it to size 3. Upon gathering the parameters, AutoNormal produces parameters with 20 rows, whereas AutoDiscreteParallel produces parameters with 3 rows.

I believe that the issue is in the _setup_prototype functions in pyro/infer/autoguide/guides.py. Below is the code from AutoNormal (see here). There is no equivalent in the _setup_prototype function of AutoDiscreteParallel (see here).

I will work on a pull request to fix this. I would also like to create some additional tests for this and other cases, but I am not too sure where to start. Any help would be appreciated.