pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Utilities for simplifying interactions between PyroSample and plates #3385

Open eb8680 opened 3 months ago

eb8680 commented 3 months ago

Problem

PyroModule and PyroSample make it straightforward to compositionally specify probabilistic models with random parameters. However, PyroSample has a somewhat awkward interaction with pyro.plate:

import pyro
import pyro.distributions
import pyro.nn

class Model(pyro.nn.PyroModule):

  @pyro.nn.PyroSample
  def loc(self):
    return pyro.distributions.Normal(0, 1)

  @pyro.nn.PyroSample
  def scale(self):
    return pyro.distributions.LogNormal(0, 1)

  def forward(self, x_obs):
    assert self.scale.shape == ()  # accessing self.scale triggers pyro.sample outside the plate
    with pyro.plate("data", x_obs.shape[0], dim=-1):
      assert self.loc.shape == (x_obs.shape[0],)  # accessing self.loc here triggers pyro.sample inside the plate
      return pyro.sample("x", pyro.distributions.Normal(self.loc, self.scale), obs=x_obs)

To ensure that loc and scale are sampled globally, they must first be accessed outside the data plate, as self.scale is above. Accessing self.loc inside the plate, as in the last two lines, samples a different loc for each datapoint. This behavior is semantically unambiguous, but it can cause confusion in more complex models and can require a lot of ugly boilerplate code that manually samples the random parameters of submodules in the correct plate context.

For example, in the code below the intuitive behavior of Model.linear is clearly for linear.weight to be sampled outside the data plate, but because self.linear is first invoked inside the plate, a separate random copy of linear.weight is sampled for each plate slice:

import torch
import pyro
import pyro.distributions as dist
import pyro.nn

class BayesianLinear(pyro.nn.PyroModule[torch.nn.Linear]):

  @pyro.nn.PyroSample
  def weight(self):
    return dist.Normal(0, 1).expand([self.out_features, self.in_features]).to_event(2)

class Model(pyro.nn.PyroModule):
  def __init__(self, num_inputs, num_outputs):
    super().__init__()
    self.linear = BayesianLinear(num_inputs, num_outputs)

  def forward(self, x):
    with pyro.plate("data", x.shape[-2], dim=-1):
      loc = self.linear(x)
      assert self.linear.weight.shape[-3] == x.shape[-2]
      return pyro.sample("y", dist.Normal(loc, 1).to_event(1))

However, it would not be correct to simply ignore all plates when executing PyroSamples. In this example, we might want to use a multi-sample ELBO estimator when inferring self.linear.weight (e.g. pyro.infer.Trace_ELBO(num_particles=10, vectorize_particles=True)), which is implemented with another plate that should not be ignored.

Proposed fix

It would be nice to have a feature that enables the intuitive behavior in the second example above without breaking backwards compatibility with PyroSample's existing semantics or its correctness in the presence of enclosing plates, such as the one introduced by the multi-sample ELBO.

This could potentially be achieved with a new handler, PyroSamplePlateScope, such that PyroSample attributes accessed inside its context are affected only by plates entered outside it, while ordinary pyro.sample statements are unaffected and behave as usual:

class Model(pyro.nn.PyroModule):
  def __init__(self, num_inputs, num_outputs):
    super().__init__()
    self.num_inputs = num_inputs
    self.num_outputs = num_outputs
    self.linear = BayesianLinear(num_inputs, num_outputs)

  @pyro.nn.PyroSample
  def scale(self):
    return pyro.distributions.LogNormal(0, 1).expand([self.num_outputs]).to_event(1)

  @PyroSamplePlateScope()
  def forward(self, x):
    with pyro.plate("data", x.shape[-2], dim=-1):
      loc = self.linear(x)
      assert self.linear.weight.shape[-3] == 1  # sampled outside data plate
      assert self.scale.shape[-2] == 1  # sampled outside data plate
      y = pyro.sample("y", dist.Normal(loc, self.scale).to_event(1))
      assert y.shape[-2] == x.shape[-2]  # ordinary pyro.sample statement
      return y
eb8680 commented 3 months ago

@ordabayevy what do you think about this?

ordabayevy commented 3 months ago

This makes sense to me. My two comments are:

  1. Personally, I find it more intuitive to treat PyroSample the same as pyro.sample and not to inline it. But I don't use PyroSample much, and I can see the convenience of inlining it if it is used a lot.
  2. As a consideration, should plate scoping be implemented as a context manager, as with PyroSamplePlateScope, or per individual PyroSample (e.g. through infer={"ignored_plates": ...}, which would also work with pyro.sample)? For example, if you want self.loc to be sampled inside the data plate and self.scale sampled outside it:
class Model(pyro.nn.PyroModule):

  @pyro.nn.PyroSample
  def loc(self):
    return pyro.distributions.Normal(0, 1)

  @pyro.nn.PyroSample(infer={"ignored_plates": ["data"]})  # new syntax
  def scale(self):
    return pyro.distributions.LogNormal(0, 1)

  def forward(self, x_obs):
    with pyro.plate("data", x_obs.shape[0], dim=-1):
      return pyro.sample("x", pyro.distributions.Normal(self.loc, self.scale), obs=x_obs)  # self.loc is local and self.scale is global

(One drawback of this approach is that ignored_plates is not in the forward method but hidden elsewhere, which can make the code harder to read.)