pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Subsampling in some autoguides produces parameters with wrong shapes #3286

Open · gui11aume opened this issue 10 months ago

gui11aume commented 10 months ago

Issue Description

Auto guides need to create parameters in the background. The shapes of those parameters are determined by the plates in the model. When a plate is subsampled, the parameters should have the dimension of the full plate, not of the subsample. Some auto guides handle this correctly, but AutoDiscreteParallel creates parameters with the subsampled shape.

Code Snippet

The code below shows the difference in behavior between AutoNormal and AutoDiscreteParallel. In both cases, the model creates a plate of size 20 and subsamples it to size 3. Upon gathering the parameters, AutoNormal produces parameters with 20 rows, whereas AutoDiscreteParallel produces parameters with 3 rows.

import pyro
import torch

def model():
    with pyro.plate("dummy", 20, subsample_size=3):
        pyro.sample("x", pyro.distributions.Normal(0, 1))

guide = pyro.infer.autoguide.AutoNormal(model)

elbo = pyro.infer.Trace_ELBO()

with pyro.poutine.trace(param_only=True) as param_capture:
    elbo.differentiable_loss(model, guide)

print(pyro.param("AutoNormal.locs.x").shape)
# torch.Size([20])

def model():
    with pyro.plate("dummy", 20, subsample_size=3):
        pyro.sample("x", pyro.distributions.Categorical(torch.ones(1)))

guide = pyro.infer.autoguide.AutoDiscreteParallel(model)

elbo = pyro.infer.TraceEnum_ELBO()

with pyro.poutine.trace(param_only=True) as param_capture:
    elbo.differentiable_loss(model, guide)

print(pyro.param("AutoDiscreteParallel.x_probs").shape)
# torch.Size([3, 1])
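
Incidentally, the param_capture messenger from the snippets above records every parameter the guide touched, so the shape check can be extended to all parameters at once. A small usage sketch:

# Sketch: list every parameter recorded by the trace messenger above,
# together with its shape.
for name, site in param_capture.trace.nodes.items():
    if site["type"] == "param":
        print(name, tuple(site["value"].shape))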

I believe the issue lies in the _setup_prototype methods in pyro/infer/autoguide/guides.py. Below is the relevant code from AutoNormal (see here).

            # If subsampling, repeat init_value to full size.
            for frame in site["cond_indep_stack"]:
                full_size = getattr(frame, "full_size", frame.size)
                if full_size != frame.size:
                    dim = frame.dim - event_dim
                    init_loc = periodic_repeat(init_loc, full_size, dim).contiguous()
            init_scale = torch.full_like(init_loc, self._init_scale)

There is no equivalent in the _setup_prototype method of AutoDiscreteParallel (see here).
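
For illustration only, a fix might add an analogous loop to AutoDiscreteParallel._setup_prototype. The sketch below is hypothetical, not the actual patch; init and event_dim are placeholders for AutoDiscreteParallel's per-site initialization:

# Hypothetical sketch, mirroring the AutoNormal code above: repeat the
# per-site init value along each subsampled plate dim so the learned
# parameter covers the full plate.
for frame in site["cond_indep_stack"]:
    full_size = getattr(frame, "full_size", frame.size)
    if full_size != frame.size:
        dim = frame.dim - event_dim
        init = periodic_repeat(init, full_size, dim).contiguous()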

I will work on a pull request to fix this. I would also like to create some additional tests for this and other cases, but I am not quite sure where to start. Any help would be appreciated.

gui11aume commented 10 months ago

Hi @fritzo! I had a closer look at the issue and it's a little more complicated than I thought... Are there already some tests for the creation of parameters in auto guides?

gui11aume commented 10 months ago

I think that the same phenomenon happens for AutoLowRankMultivariateNormal.

import pyro
import torch

def model():
    with pyro.plate("dummy", 20, subsample_size=3):
        pyro.sample("x", pyro.distributions.Normal(0, 1))

guide = pyro.infer.autoguide.AutoLowRankMultivariateNormal(model, rank=2)

elbo = pyro.infer.Trace_ELBO()

with pyro.poutine.trace(param_only=True) as param_capture:
    elbo.differentiable_loss(model, guide)

print(pyro.param("AutoLowRankMultivariateNormal.loc").shape)
# torch.Size([3])
print(pyro.param("AutoLowRankMultivariateNormal.scale").shape)
# torch.Size([3])
print(pyro.param("AutoLowRankMultivariateNormal.cov_factor").shape)
# torch.Size([3, 2])

The parameters should have 20 rows but they have 3.

Following the documentation's suggestion, we can initialize the parameters with pyro.param(...) before calling the guide, hoping to obtain the correct dimensions. However, this fails because Pyro expects the number of rows to be 3 (if you initialize the parameters with 3 rows, the code runs fine).

import pyro
import torch

def model():
    with pyro.plate("dummy", 20, subsample_size=3):
        pyro.sample("x", pyro.distributions.Normal(0, 1))

pyro.param("AutoLowRankMultivariateNormal.loc", torch.zeros(20))
pyro.param("AutoLowRankMultivariateNormal.scale", torch.ones(20))
pyro.param("AutoLowRankMultivariateNormal.cov_factor", torch.ones(20, 2))

guide = pyro.infer.autoguide.AutoLowRankMultivariateNormal(model, rank=2)

elbo = pyro.infer.Trace_ELBO()

with pyro.poutine.trace(param_only=True) as param_capture:
    elbo.differentiable_loss(model, guide)
# ...
# AssertionError

martinjankowiak commented 10 months ago

AutoGuides + data subsampling require using create_plates; see e.g. this test.
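
For context, that pattern looks roughly like the sketch below (names are illustrative): the subsample indices are passed as an argument to both the model and the guide, and create_plates rebuilds the plate from them, so the guide can still size its parameters from the full plate:

import pyro
import torch

# Illustrative sketch of the create_plates pattern: the model takes the
# subsample indices as an argument instead of drawing them in the plate.
def model(ind):
    with pyro.plate("data", 20, subsample=ind):
        pyro.sample("x", pyro.distributions.Normal(0.0, 1.0))

def create_plates(ind):
    # The guide calls this with the same arguments as the model.
    return pyro.plate("data", 20, subsample=ind)

guide = pyro.infer.autoguide.AutoNormal(model, create_plates=create_plates)

elbo = pyro.infer.Trace_ELBO()
elbo.differentiable_loss(model, guide, torch.arange(3))

print(pyro.param("AutoNormal.locs.x").shape)
# torch.Size([20])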

gui11aume commented 10 months ago

Thanks @martinjankowiak! I have tried creating the plates manually in different contexts, but I have had no luck. Have a look at the example below: am I doing it wrong?

import pyro
import torch

def model():
    with pyro.plate("dummy", 20, subsample_size=3):
        pyro.sample("x", pyro.distributions.Categorical(torch.ones(1)))

def create_plate_x():
    return pyro.plate("dummy", 20, subsample_size=3, dim=-1)

guide = pyro.infer.autoguide.AutoDiscreteParallel(model, create_plates=create_plate_x)

elbo = pyro.infer.TraceEnum_ELBO()

with pyro.poutine.trace(param_only=True) as param_capture:
    elbo.differentiable_loss(model, guide)

print(pyro.param("AutoDiscreteParallel.x_probs").shape)
# torch.Size([3, 1])

Thanks for the link to the test! It runs with AutoDelta and AutoNormal, but I never had problems with those; I think they work fine.