For context, this issue arose in a discussion with @ndgAtUber about the type signature of marginal inference (see also #6): should `Marginal(...)` return a callable that constructs a distribution object, or a callable that returns samples from the marginal distribution?
Focusing on the implementation and behavior of primitive distributions, maybe we could make both options available, as in webPPL's distribution library, and implement the distribution classes in terms of the sample and score functions while exposing both APIs to the user. This feels more PyTorch-y to me, since distribution objects could handle resizing for different batches, caching, and bookkeeping under the hood, optionally do operations in-place to use memory more efficiently, and be reused more naturally inside `nn.Module`s independently of Pyro. This is a common pattern in PyTorch; cf. `torch.nn.functional` and `torch.nn`.
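To make that concrete, here's a minimal sketch of what the dual API might look like (hypothetical class and function, not Pyro's current code), with the functional form implemented in terms of the object form:

```python
import torch


class Binomial(object):
    """Sketch of a distribution object: it owns sampling and scoring, so it
    could cache, resize for batches, or be reused inside an nn.Module."""

    def __init__(self, p):
        self.p = p  # per-element success probability, a tensor in [0, 1]

    def sample(self):
        # draw a 0/1 sample for each element of p
        return torch.bernoulli(self.p)

    def log_pdf(self, x):
        # log-probability of a batch of 0/1 samples under p
        return (x * torch.log(self.p) + (1 - x) * torch.log(1 - self.p)).sum()


def binomial(p):
    # functional API, implemented in terms of the distribution class
    return Binomial(p).sample()
```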
Now, jumping up a level of abstraction, there are two different ways to conceptualize `pyro.sample`. The first, which I'll call the webPPL way, is to think of `sample` as a utility for registering `sample` method calls of a fixed set of primitive distributions with an inference engine. The second, which I'll call the "Venture" way (quotes since this isn't really an accurate description of how things work in Venture, but the name came from my Venture reading odyssey), is to think of `pyro.sample` as `apply(fn, ...)` for arbitrary stochastic functions, plus side effects added dynamically at runtime by inference engines, so that it has the signature `val = pyro.sample("name", fn, *args, **kwargs)`.
In that case, `val = pyro.sample("name", binomial, p)` and `val = pyro.sample("name", Binomial(p))` are semantically equivalent, because in Pyro all primitive distributions currently have their `__call__` operator aliased to their `sample` method. `val = binomial(p)` would not be an equivalent replacement: it would get ignored by an inference engine, since it's not wrapped in a `pyro.sample`, and there can't be a `pyro.sample` inside it without some sort of automatic name generation, which we've decided against for the moment. `observe` has a natural counterpart here as well.
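That aliasing is the crux of the equivalence; the pattern is roughly (a sketch, not the actual base class):

```python
class Distribution(object):
    # calling a distribution object is aliased to sampling from it, which
    # is why pyro.sample("name", binomial, p) and
    # pyro.sample("name", Binomial(p)) are interchangeable above
    def sample(self, *args, **kwargs):
        raise NotImplementedError

    def log_pdf(self, x, *args, **kwargs):
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        return self.sample(*args, **kwargs)
```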
I think what really happened in our original discussion is that I tripped over a design question for which I don't have a fully-formed answer: if we choose the Venture way of conceptualizing `pyro.sample`, what should happen when we call `pyro.sample("name", fn, ...)` and `fn` is a callable that itself has `pyro.sample` statements inside? A natural scenario where this happens is sampling from a marginal distribution produced by an inference engine with auxiliary variables (e.g. SMC).
Is this a problem worth solving, or have I posed it incorrectly? What are the arguments for and against the Venture way? One argument in favor: doing things that way means that we can write models that (at least attempt to) `observe` arbitrary stochastic functions, which seems like an extremely powerful and natural idiom. One argument against: the webPPL way makes some inference optimizations easier (e.g. we always know how to score samples).
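For concreteness, the idiom in favor might look something like this sketch (the model and the functional-style `normal` sampler are assumptions; whether the engine can actually score `simulator` is exactly what's at issue):

```python
import pyro
from pyro.distributions import normal  # assumed functional-style sampler

def simulator(scale):
    # an arbitrary stochastic function: deterministic python plus pyro.sample
    z = pyro.sample("z", normal, 0.0, scale)
    return z * z

# the idiom in question: observe the stochastic function directly.
# simulator has no closed-form scorer, so an inference engine can only
# attempt this, e.g. via likelihood-free / ABC-style techniques.
pyro.observe("obs", simulator, 2.5, 1.0)
```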
I'm coming around to the idea that distributions should be able to directly take arguments (so they are really stochastic functions that also have a scoring method). This will require upgrading `pyro.sample` and `pyro.observe`, but I don't think that's a big deal.
It is coherent to get params into distributions either in a constructor (`pyro.sample("name", Binomial(p))`) or directly (`pyro.sample("name", binomial, p)`). My preference would be that by convention we use constructors only for distributions that maintain internal state (e.g. CRP, GP).
This change requires:
- [x] upgrade `pyro.sample` and `pyro.observe` to pass args on to the distribution sampler and scorer methods (see the sketch after this list). this will require adjusting all current poutines / inference algorithms.
- [x] change stateless primitive distributions (i.e. all current ones) to take params at call time, not construct time.
- [x] provide `pyro.distributions.*` functions for the stateless distributions? (so that we don't need lots of `binomial = Binomial()` lines.)
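a rough sketch of what the first item could mean for `pyro.sample` (`_PYRO_STACK` and `handle_sample` are assumed names, and dispatch details are ignored):

```python
def sample(name, fn, *args, **kwargs):
    # sketch: forward call-time params to the sampler, so stateless
    # distributions no longer need their params at construction time
    if not _PYRO_STACK:
        # no inference engine active: just run the sampler directly
        return fn(*args, **kwargs)
    # otherwise hand the call, args included, to the innermost poutine
    return _PYRO_STACK[-1].handle_sample(name, fn, *args, **kwargs)
```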
For the record, my thinking is based partly on considering the kinds of objects we have in pyro and their relations. Here are some different kinds of things:
- primitive sampler: a function that returns a random value, whose randomness is not registered via `pyro.sample`. these must come together with a scoring function, otherwise inference methods can't properly account for the randomness.
- stochastic function: a callable that includes deterministic python and `pyro.sample`. these are normalized, so we can sample from them, but "implicit" in the sense that we don't know the marginal probability of a return value a priori.
- scoring method: reports the log_pdf of a return value from a stochastic function.
- distribution: a stochastic function or primitive sampler with a scoring method. these can be either a primitive sampler with a scorer, or a pyro stochastic function that has acquired a scorer, eg by marginalization.
- unnormalized stochastic function: a callable that includes deterministic python, `pyro.sample`, and `pyro.observe`. you can't do anything with these directly. inference methods normalize them into trace posteriors; marginal methods make them into either stochastic functions (no scorer) or distributions.
under this view `Marginal(...)` returns a distribution (a function with a sampler and scorer that takes the same args as the original model), as @eb8680 suggested. interestingly, there may be cases where it is enough to construct an implicit marginal, which is a stochastic function. (especially since smoothing the pdf from samples is an open question.)
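under this taxonomy, a marginal could be sketched as a wrapper like the following (all names here are assumed for illustration, not the actual implementation):

```python
class Marginal(Distribution):
    # wraps a trace posterior produced by an inference engine and exposes
    # it through the same sampler/scorer interface as a primitive
    def __init__(self, posterior):
        self.posterior = posterior

    def sample(self, *args, **kwargs):
        # run the engine with the original model's args and return the
        # return value of the resulting trace
        trace = self.posterior.draw_trace(*args, **kwargs)
        return trace.return_value

    def log_pdf(self, val, *args, **kwargs):
        # estimated marginal score of a return value; for continuous values
        # this runs into the pdf-smoothing question mentioned above
        return self.posterior.estimate_log_prob(val, *args, **kwargs)
```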
> what should happen when we call `pyro.sample("name", fn, ...)` and `fn` is a callable that itself has `pyro.sample` statements inside?
my view is that the poutine handling this sample call is responsible for deciding whether it wants to treat `fn` as a primitive (using the score method to account for its randomness) or as a compound stochastic function (exposing the random choices within a sample).
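schematically, that decision might live in the poutine's sample handler (assumed interface; `treat_as_primitive` and `trace_internals` are made up for illustration):

```python
def handle_sample(self, name, fn, *args, **kwargs):
    # sketch of a poutine choosing between the two treatments of fn
    if self.treat_as_primitive(fn):
        # primitive: run fn opaquely and rely on fn.log_pdf to account
        # for all of the randomness inside it
        return fn(*args, **kwargs)
    # compound: re-enter the runtime so the pyro.sample statements inside
    # fn are exposed to (and named by) this poutine
    return self.trace_internals(name, fn, *args, **kwargs)
```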
> doing things that way means that we can write models that (at least attempt to) `observe` arbitrary stochastic functions, which seems like an extremely powerful and natural idiom.
i don't think we should allow this in general, though it is nice to be able to experiment with likelihood-free and ABC techniques.
per (pleasant, unheated) discussion with @eb8680, we could consider upgrading distribution objects to be callables (i.e. functions) that return a sample, and also have a scoring function attached. in particular, we could allow them to have args:
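for instance, usage could look like this sketch (params move to call time; a stateless `Binomial` class is assumed):

```python
binomial = Binomial()          # stateless, so no params at construction
x = binomial(p)                # calling with args returns a sample
lp = binomial.log_pdf(x, p)    # the attached scorer takes the same args
```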
there are reasons to like this: it simplifies distributions (don't need to construct and then use), and makes it more clear that stateful dists (eg CRP) are the ones that require a constructor. (presumably if we go this route, we will change the `Binomial` class to a provided `binomial` function.) on the other hand, it requires making the signatures of `pyro.sample` and `pyro.observe` more complex: they must take the fn args as input. are there other complications to worry about?
if we do make this change, should a distribution automatically call `pyro.sample` when it is called in the context of a pyro program? i.e. should `pyro.sample(binomial, args)` be writable as `binomial(args)`?