tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Is it possible to have a mixture of distributions as the prior of a Bayesian layer? #816

Open nbro opened 4 years ago

nbro commented 4 years ago

The paper Weight Uncertainty in Neural Networks proposes using a mixture of two Gaussians as the prior distribution. By default, the kernel prior of a Convolution2DFlipout layer is a standard Gaussian, and the bias gets no prior at all (bias_prior_fn=None). Is there an easy way to use a mixture of Gaussians as the prior instead?
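For reference, the prior in that paper is a scale mixture of two zero-mean Gaussians, applied independently to each weight $w_j$ with mixing weight $\pi$. The two components differ only in their standard deviations $\sigma_1 > \sigma_2$, with $\sigma_2 \ll 1$:

```latex
P(\mathbf{w}) = \prod_j \Big[ \pi \, \mathcal{N}(w_j \mid 0, \sigma_1^2) + (1 - \pi) \, \mathcal{N}(w_j \mid 0, \sigma_2^2) \Big]
```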

I know TFP provides the class MixtureSameFamily. For example, you can create a mixture of two Gaussians in the following way.

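```python
from tensorflow_probability import distributions as tfd

# Two-component Gaussian mixture with mixing weights 0.4 and 0.6.
prior = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.4, 0.6]),
    components_distribution=tfd.Normal(
        loc=[-1., 1.],    # One mean per component.
        scale=[0.1, 0.5]  # One standard deviation per component.
    ),
)
```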
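To make it concrete what this prior contributes to a loss, here is a dependency-free sketch (plain Python, not the TFP API) of the log-density that `MixtureSameFamily.log_prob` evaluates for this mixture: a log-sum-exp over each component's log mixing weight plus its Gaussian log-density. The helper names and module constants are hypothetical, chosen for illustration.

```python
import math

# Mixture parameters copied from the MixtureSameFamily example above.
WEIGHTS = [0.4, 0.6]
LOCS = [-1.0, 1.0]
SCALES = [0.1, 0.5]


def normal_log_prob(x, loc, scale):
    """Log-density of a univariate Gaussian N(loc, scale**2) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))


def mixture_log_prob(x, weights=WEIGHTS, locs=LOCS, scales=SCALES):
    """log p(x) = logsumexp_k [log w_k + log N(x; loc_k, scale_k**2)]."""
    return logsumexp([math.log(w) + normal_log_prob(x, m, s)
                      for w, m, s in zip(weights, locs, scales)])


# Example: log-density of the prior at w = 0.
print(mixture_log_prob(0.0))
```

Working in log space with log-sum-exp avoids underflow when a weight sits far out in the tail of a narrow component, such as the one with scale 0.1.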

To use a mixture of two Gaussians as the prior of a Bayesian layer, there must be a way of computing the KL divergence between the Gaussian posterior and the mixture-of-Gaussians prior. Judging by the documentation of tfp.distributions.kl_divergence, this is not implemented yet.
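The difficulty becomes visible once the divergence is split into an entropy term and a cross-entropy term. Take a Gaussian posterior $q(w) = \mathcal{N}(w \mid \mu, \sigma^2)$ and a mixture prior $p(w) = \sum_k \pi_k \mathcal{N}(w \mid \mu_k, \sigma_k^2)$. This notation is introduced here for illustration:

```latex
\mathrm{KL}(q \,\|\, p)
  = \underbrace{\mathbb{E}_{q}\big[\log q(w)\big]}_{= -\frac{1}{2}\log(2\pi e \sigma^2)}
  \; - \; \mathbb{E}_{q}\Big[\log \sum_k \pi_k \, \mathcal{N}(w \mid \mu_k, \sigma_k^2)\Big]
```

The first term has a closed form. The second is the expectation of the log of a sum, which has no closed-form expression in general. This is presumably why no KL rule is registered for this pair of distributions, and why sampling or bounding is needed.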

nbro commented 4 years ago

@brianwa84, @jvdillon, @davmre, @jburnim, @SiegeLordEx, @csuter, Any chance that this question will be answered? Or is this already implemented and I missed it?

If you want some inspiration, equations 16 and 17 of the paper Keeping Neural Networks Simple by Minimizing the Description Length of the Weights actually show how to compute the KL divergence between a mixture of Gaussians prior and a Gaussian posterior.

Could you provide an example of how this could be implemented in TensorFlow Probability?

jvdillon commented 4 years ago

Could you use the Monte Carlo approximation of the KL divergence?

As for exact computation, I am not aware of a closed-form expression for this. You could try bounding or approximating it: https://mathoverflow.net/a/308022/10590
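Here is one concrete bound, added as an illustration and not necessarily what the linked answer proposes. Apply Jensen's inequality to $\log \sum_k \phi_k \, \pi_k p_k(w) / \phi_k$ and choose the best weights $\phi_k$. This gives $\mathrm{KL}(q \,\|\, \sum_k \pi_k p_k) \le -\log \sum_k \pi_k \exp(-\mathrm{KL}(q \,\|\, p_k))$. The bound only needs closed-form Gaussian-to-Gaussian KLs, and it is never looser than $\sum_k \pi_k \mathrm{KL}(q \,\|\, p_k)$. The plain-Python sketch below handles a univariate Gaussian posterior. Its function names are hypothetical, and this is not a TFP API.

```python
import math


def kl_normal_normal(mu_q, sigma_q, mu_p, sigma_p):
    """Closed-form KL(N(mu_q, sigma_q**2) || N(mu_p, sigma_p**2))."""
    return (math.log(sigma_p / sigma_q)
            + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2.0 * sigma_p ** 2)
            - 0.5)


def kl_mixture_upper_bound(mu_q, sigma_q, weights, locs, scales):
    """Upper bound on KL(q || sum_k w_k N(loc_k, scale_k**2)), q = N(mu_q, sigma_q**2).

    Computes -log sum_k w_k exp(-KL(q || N_k)) with a stable log-sum-exp.
    """
    terms = [math.log(w) - kl_normal_normal(mu_q, sigma_q, m, s)
             for w, m, s in zip(weights, locs, scales)]
    top = max(terms)
    return -(top + math.log(sum(math.exp(t - top) for t in terms)))


# Posterior N(0, 0.3**2) against the two-component mixture from the first comment.
print(kl_mixture_upper_bound(0.0, 0.3, [0.4, 0.6], [-1.0, 1.0], [0.1, 0.5]))
```

This over-estimates the KL term, so using it in place of the exact KL makes the training loss an upper bound on the negative ELBO.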

nbro commented 4 years ago

@jvdillon Thanks for answering!

How would the MC approximation of the KL divergence work? Do you have a concrete example (i.e. code)?

The papers referenced there look promising judging by their titles, but I don't have time to read them right now.

jvdillon commented 4 years ago

Something like:

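```python
import tensorflow as tf
z = q.sample(T)  # T samples from the posterior q, stacked along axis 0
# Negative ELBO estimate: mean over samples of log q(z) - log p(y | z) - log p(z).
loss = tf.reduce_mean(q.log_prob(z) - likelihood(z).log_prob(y) - prior.log_prob(z), axis=0)
```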
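The snippet above estimates the full negative ELBO, i.e. the KL term minus the expected log-likelihood. The KL part on its own is the average of log q(z) - log p(z) over samples z ~ q. Below is a dependency-free sketch of just that estimator, for a univariate Gaussian posterior and a Gaussian-mixture prior. It is written in plain Python rather than TFP, and its function names are hypothetical.

```python
import math
import random


def normal_log_prob(x, loc, scale):
    """Log-density of N(loc, scale**2) at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def mixture_log_prob(x, weights, locs, scales):
    """log sum_k w_k N(x; loc_k, scale_k**2), computed with log-sum-exp."""
    terms = [math.log(w) + normal_log_prob(x, m, s)
             for w, m, s in zip(weights, locs, scales)]
    top = max(terms)
    return top + math.log(sum(math.exp(t - top) for t in terms))


def mc_kl(mu_q, sigma_q, weights, locs, scales, num_samples=10_000, seed=0):
    """Monte Carlo estimate of KL(q || p) with q = N(mu_q, sigma_q**2).

    Draws z_t ~ q and averages log q(z_t) - log p(z_t). The estimate is
    unbiased but noisy; more samples reduce its variance.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(mu_q, sigma_q)
        total += normal_log_prob(z, mu_q, sigma_q) - mixture_log_prob(z, weights, locs, scales)
    return total / num_samples


# Posterior N(0, 0.3**2) against the mixture prior from the first comment.
print(mc_kl(0.0, 0.3, [0.4, 0.6], [-1.0, 1.0], [0.1, 0.5]))
```

In TFP the analogous quantity would be `q.log_prob(z) - prior.log_prob(z)` with `z = q.sample(...)`. For the Flipout layers, the documented `kernel_divergence_fn` receives the posterior, the prior, and a sample from the posterior. That makes it a natural place to plug in such an estimate, together with a custom `kernel_prior_fn` that returns the mixture. This is a suggestion rather than a tested recipe.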