pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0
8.5k stars 982 forks source link

Error when using `infer_discrete` for a model with local variable that depends on discrete variable #2860

Open ordabayevy opened 3 years ago

ordabayevy commented 3 years ago

Issue Description

I'm trying to use `infer_discrete` on a model (code snippet below) that has a local latent variable (`locs`) which depends on a discrete variable (`assignment`). This leads to an error because `site["log_prob"].shape` and `dim_to_symbol` don't match:

  File "/home/ordabayev/repos/pyro/pyro/poutine/trace_struct.py", line 376, in pack_tensors
    packed["log_prob"] = pack(site["log_prob"], dim_to_symbol)
  File "/home/ordabayev/repos/pyro/pyro/ops/packed.py", line 29, in pack
    raise ValueError('\n  '.join([
ValueError: Error while packing tensors at site 'locs':
  Invalid tensor shape.
  Allowed dims: -1
  Actual shape: (2, 5)

Using pdb:

-> packed["log_prob"] = pack(site["log_prob"], dim_to_symbol)
(Pdb) p site["log_prob"].shape
torch.Size([2, 5])
(Pdb) p dim_to_symbol
{-1: 'a'}

Environment

Code Snippet

import torch
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.optim import Adam
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate, infer_discrete

# Toy 1-D observations: two well-separated groups, one near 0 and one near 10.
data = torch.tensor([0.0, 1.0, 10.0, 11.0, 12.0])

K = 2  # Number of mixture components (fixed, not inferred).

@config_enumerate
def model(data):
    """Gaussian mixture whose per-datum location is itself a latent variable.

    Global sites: mixture ``weights`` (Dirichlet prior) and the observation
    noise ``scale`` (LogNormal prior). Cluster centres are fixed at 0 and 10.

    Inside the ``data`` plate, each point draws:

    * a discrete ``assignment`` from ``Categorical(weights)``;
    * a continuous ``locs`` centred on the assigned cluster's centre;
    * the observation ``obs ~ Normal(locs, scale)``.

    Because ``locs`` depends on the discrete ``assignment``, this model
    reproduces the ``infer_discrete`` packing error described above.
    """
    # Global variables.
    mix_weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
    obs_scale = pyro.sample('scale', dist.LogNormal(0., 2.))
    centers = torch.tensor([0., 10.])

    with pyro.plate('data', len(data)):
        # Local variables.
        cluster = pyro.sample('assignment', dist.Categorical(mix_weights))
        # Index the fixed centres with the (possibly enumerated) assignment.
        prior_loc = centers[cluster]
        loc = pyro.sample('locs', dist.Normal(prior_loc, 2.))
        pyro.sample('obs', dist.Normal(loc, obs_scale), obs=data)

def guide(data):
    """Guide for every latent, non-discrete site of ``model``.

    The globals ``weights`` and ``scale`` are drawn from the same priors the
    model uses. Each local ``locs`` is drawn from a fixed ``Normal(10., 2.)``.
    The discrete ``assignment`` site is deliberately absent here; the model
    marks it for enumeration via ``@config_enumerate``.
    """
    n_points = len(data)

    # Global variables.
    pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
    pyro.sample('scale', dist.LogNormal(0., 2.))

    # Local variables.
    with pyro.plate('data', n_points):
        pyro.sample('locs', dist.Normal(10., 2.))

# Run the guide once and record every site it samples. That includes the
# globals ``weights``/``scale`` *and* the local ``locs``.
guide_trace = poutine.trace(guide).get_trace(data)
# Replay the model against that trace so that ``weights``, ``scale`` and
# ``locs`` all take the guide's values. Only ``assignment`` is left latent.
# A replayed local (``locs``) that depends on the discrete site is what
# triggers the packing error reported in this issue.
trained_model = poutine.replay(model, trace=guide_trace)

# temperature=1 samples assignments (temperature=0 would give MAP values).
# The only plate, ``data``, occupies dim -1, so enumeration dims must be
# allocated starting at -2 to avoid colliding with it.
inferred_model = infer_discrete(trained_model, temperature=1,
                                first_available_dim=-2)
trace = poutine.trace(inferred_model).get_trace(data)
print(trace.nodes["assignment"]["value"])
ordabayevy commented 3 years ago

To construct a `dim_to_symbol` that matches the shape of `site["log_prob"]`, I'm using the following hack. I'm not sure whether it is a correct solution or whether it might break the code elsewhere:

class EnumMessenger(Messenger):
    ...
    def _pyro_sample(self, msg):        """
         :param msg: current message at a trace site.
         :returns: a sample from the stochastic function at the site.
         """
+        param_dims = _ENUM_ALLOCATOR.dim_to_id.copy()  # enum dim -> unique id
+        self._param_dims[msg["name"]] = param_dims
         if msg["done"] or not isinstance(msg["fn"], TorchDistributionMixin):
             return

         # Compute upstream dims in scope; these are unsafe to use for this site's target_dim.
         scope = msg["infer"].get("_markov_scope")  # site name -> markov depth
-        param_dims = _ENUM_ALLOCATOR.dim_to_id.copy()  # enum dim -> unique id
         if scope is not None:
             for name, depth in scope.items():
                 if self._markov_depths[name] == depth:  # hide sites whose markov context has exited
                     param_dims.update(self._value_dims[name])
             self._markov_depths[msg["name"]] = msg["infer"]["_markov_depth"]
-        self._param_dims[msg["name"]] = param_dims
         if msg["is_observed"] or msg["infer"].get("enumerate") != "parallel":
             return