eb8680 opened this issue 3 months ago
@ordabayevy what do you think about this?
This makes sense to me. My two comments are:

1. Maybe it is simpler to treat `PyroSample` the same as `pyro.sample` and not inline it. But I don't use `PyroSample` much, and I can see that inlining could be convenient if it is used a lot.
2. Should the plate scope be set with a handler like `PyroSamplePlateScope`, or per individual `PyroSample` (e.g. through `infer={"ignored_plates": ...}`, which would also work with `pyro.sample`)?

For example, if you want `self.loc` to be sampled inside of the `data` plate and `self.scale` sampled outside of the `data` plate:

```python
import pyro
import pyro.distributions
import pyro.nn


class Model(pyro.nn.PyroModule):
    @pyro.nn.PyroSample
    def loc(self):
        return pyro.distributions.Normal(0, 1)

    @pyro.nn.PyroSample(infer={"ignored_plates": ["data"]})  # new syntax
    def scale(self):
        return pyro.distributions.LogNormal(0, 1)

    def forward(self, x_obs):
        with pyro.plate("data", x_obs.shape[0], dim=-1):
            # self.loc is local and self.scale is global
            return pyro.sample(
                "x", pyro.distributions.Normal(self.loc, self.scale), obs=x_obs
            )
```
(One drawback of this approach is that `ignored_plates` is specified away from the `forward` method, hidden elsewhere, which can make the code harder to read.)
Problem
`PyroModule` and `PyroSample` make it straightforward to specify probabilistic models with random parameters compositionally. However, `PyroSample` has a somewhat awkward interaction with `pyro.plate`:

To ensure `loc` and `scale` are sampled globally, it is necessary to access them outside the `data` plate, as is done for `scale` in the example above; inlining `self.loc` in the final line samples a different `loc` for each datapoint. This behavior is semantically unambiguous, but it can cause confusion in more complex models and requires lots of ugly boilerplate code in the model that manually samples the random parameters of submodules in the correct plate context.

For example, in the code below the intuitive behavior for `Model.linear` is clearly for `linear.weight` to be sampled outside of the `data` plate, but because `self.linear` is invoked for the first time inside the plate, there will be separate random copies of `linear.weight` for each plate slice:

However, it would not be correct to simply ignore all plates when executing `PyroSample`s: in this example, we might want to use a multi-sample ELBO estimator when inferring `self.linear.weight` (e.g. `pyro.infer.Trace_ELBO(num_particles=10, vectorize_particles=True)`), which is implemented with another `plate` that should not be ignored.

Proposed fix
It would be nice to have a feature that enables the intuitive behavior in the second example above without breaking backwards compatibility with `PyroSample`'s existing semantics or its correctness in the presence of enclosing plates like the one introduced by the multi-sample ELBO.

This could potentially be achieved with a new handler, `PyroSamplePlateScope`, such that `PyroSample` statements executed inside its context are only modified by plates entered outside of it, while ordinary `pyro.sample` statements are unaffected and behave in the usual way: