pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Support Distribution-valued observations? #988

Open fritzo opened 6 years ago

fritzo commented 6 years ago

What would it take to support Distribution-valued observations in pyro.sample() statements? The idea is to Rao-Blackwellize by interpreting observe statements with distribution-valued observations

pyro.sample("x", Normal(loc,scale), obs=Normal(0,1))

as inducing a log_prob term equal to -KL(obs, dist)

site["log_prob"] = -kl_divergence(Normal(0, 1), Normal(loc, scale))

This should be equivalent to, but lower-variance than, the current stochastic version:

y = pyro.sample("y", Normal(0, 1))
pyro.sample("x", Normal(loc, scale), obs=y)

In this interpretation, we can view our usual tensor-valued observations as observations from Delta distributions:

pyro.sample("x", Normal(loc, scale), obs=Delta(y))
assert -kl_divergence(Delta(y), Normal(loc, scale)) == Normal(loc, scale).log_prob(y)
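
This identity follows because Pyro's Delta has zero entropy by convention (Delta.log_prob(v) == 0 at the point v), so KL(Delta(v) || q) = -q.log_prob(v). A runnable check, registering the entry ourselves in case the (Delta, q) pair is missing from the KL table (the registration below is a sketch, not necessarily how Pyro ships it):

import torch
from torch.distributions import Distribution, Normal, kl_divergence, register_kl
from pyro.distributions import Delta

@register_kl(Delta, Distribution)
def _kl_delta_any(p, q):
    # Point mass: KL(Delta(v) || q) = -q.log_prob(v).
    return -q.log_prob(p.v)

loc, scale, y = torch.tensor(0.5), torch.tensor(2.0), torch.tensor(-1.0)
lhs = -kl_divergence(Delta(y), Normal(loc, scale))
rhs = Normal(loc, scale).log_prob(y)
assert torch.allclose(lhs, rhs)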

If a distribution pair is missing from the analytic kl_divergence() table, we can always fall back to a single Monte Carlo sample, as in the current implementation.
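
That fallback could look something like this (the helper name is hypothetical); the one-sample estimate of -KL(q || p) is log p(x) - log q(x) with x ~ q:

import torch
from torch.distributions import kl_divergence

def neg_kl_or_mc(obs_dist, model_dist):
    """-KL(obs_dist || model_dist), falling back to a one-sample
    Monte Carlo estimate when no analytic KL is registered."""
    try:
        return -kl_divergence(obs_dist, model_dist)
    except NotImplementedError:
        x = (obs_dist.rsample() if obs_dist.has_rsample
             else obs_dist.sample())
        return model_dist.log_prob(x) - obs_dist.log_prob(x)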

Note that ideas like this have been discussed before in the context of KL divergence (#91, #688); this issue aims to support a limited form of kl_divergence() by extending the pyro.sample() primitive.

fritzo commented 6 years ago

@karalets Would this interface enable any of your use cases for analytic KLs in SVI?

martinjankowiak commented 6 years ago

hmm, this is interesting... though it seems a bit strange/unidiomatic to me? how does this interact with trace/replay?

in any case, one still has the fundamental issue that the existence of an analytic KL for a particular pair isn't a sufficient condition for integrating out a given random variable. of course one could decide that that issue is of lesser importance
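
One way to read that concern: even when the KL for the observed site is analytic, downstream sites may need a concrete value of the variable, so it can't simply be integrated out. A sketch under the proposed (hypothetical) semantics, with illustrative values:

import torch
import pyro
import pyro.distributions as dist

loc, scale = torch.tensor(0.0), torch.tensor(2.0)
y_data = torch.tensor(0.3)

def model():
    # Even if -KL(Normal(0, 1) || Normal(loc, scale)) is analytic,
    # what should this statement *return*?
    x = pyro.sample("x", dist.Normal(loc, scale), obs=dist.Normal(0.0, 1.0))
    # A downstream site still needs a concrete x to condition on, so x
    # cannot simply be integrated out here.
    pyro.sample("y", dist.Normal(x, 1.0), obs=y_data)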

eb8680 commented 6 years ago

This is a little strange to me, especially outside the context of SVI, although that doesn't mean there isn't a sound probabilistic interpretation. I guess it's like saying inference is computing a conditional expectation with respect to those random variables, rather than just a posterior distribution?

> analytic KLs in SVI

Isn't the simplest/most obvious SVI use case that we were trying to enable (analytic KL between prior and guide) easier to expose by just allowing users to mark sites for analytic KLs in site["infer"], like we do for enumeration?
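
Concretely, that might look like the following, by analogy with enumeration's infer={"enumerate": "parallel"}; the "kl" key here is hypothetical, not existing API:

import torch
import pyro
import pyro.distributions as dist

def guide():
    loc = pyro.param("loc", torch.zeros(()))
    # Hypothetical flag: ask the ELBO to score this site with an analytic
    # KL against the model's prior instead of a Monte Carlo estimate.
    pyro.sample("z", dist.Normal(loc, 1.0), infer={"kl": "analytic"})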

fritzo commented 6 years ago

> Isn't the simplest/most obvious SVI use case ... easier to expose by just allowing users to mark sites?

That sounds reasonable. I'll leave this issue open for a week in case @karalets wants to comment, and then close it.

eb8680 commented 6 years ago

> I'll leave this issue open for a week in case @karalets wants to comment, and then close it.

I think the use case you proposed is different and interesting enough that we can keep it open for more discussion. We can create a different issue for the simpler use case if you want.

fritzo commented 6 years ago

Another possible use case is to support the non-binarized image trick, where we could:

x_axis = pyro.plate("x_axis", 28, dim=-2)  # e.g. for a 28x28 image
y_axis = pyro.plate("y_axis", 28, dim=-1)
predicted_image = my_nn_predictor(features)  # a [0,1]-valued image
with x_axis, y_axis:
    pyro.sample("obs", Bernoulli(predicted_image), obs=Bernoulli(true_image))
karalets commented 6 years ago

Yeah I like this issue actually.

Generally, this is a different way of looking at loss functions.

For observations X and latents Z, SVI currently minimizes KL(q(z|x) || p(z|x)). With distribution-valued observations, however, one can instead minimize KL(q(x,z) || p(x,z)), an alternative to the traditional ELBO with many cute properties; I used it for some GAN work once upon a time, but it is also useful apart from any GAN application.

The math works out as follows: in the model, we do not observe X but just sample it; in the guide, we do not just condition on X but actually sample it from an empirical distribution.

In the model, then, the X samples from the guide are scored as if they were latent.

This is quite a principled and easy way to implement what Fritz seems to be interested in with the current machinery.

No analytic KLs needed, actually.
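
For concreteness, a minimal sketch of this recipe with the current machinery: a toy Gaussian model where the guide's "empirical distribution" for x just resamples the dataset via a Delta site (all names and values here are illustrative):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

data = torch.randn(100) + 2.0  # hypothetical observed dataset

def model():
    z = pyro.sample("z", dist.Normal(0.0, 1.0))
    # x is *not* observed: guide samples of x are scored here as latent.
    pyro.sample("x", dist.Normal(z, 1.0))

def guide():
    # Sample x from the empirical data distribution (log q(x) = 0 for Delta).
    idx = torch.randint(len(data), ())
    x = pyro.sample("x", dist.Delta(data[idx]))
    loc = pyro.param("loc", torch.zeros(()))
    pyro.sample("z", dist.Normal(loc + x, 1.0))  # toy recognition model

svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
for _ in range(100):
    svi.step()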