tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

What is the difference between DenseVariational and DenseReparameterization? #727

Open nbro opened 4 years ago

nbro commented 4 years ago

The (global) re-parametrization trick is described in the paper Auto-Encoding Variational Bayes. The local re-parametrization trick is described in Variational Dropout and the Local Reparameterization Trick. The flipout estimator is described in Flipout: Efficient Pseudo-Independent Weight Perturbations on Mini-Batches.

The documentation of tfp.layers.DenseVariational says:

This layer uses variational inference to fit a "surrogate" posterior to the distribution over both the kernel matrix and the bias terms

This layer fits the "weights posterior" according to the following generative process:

[K, b] ~ Prior()
M = matmul(X, K) + b
Y ~ Likelihood(M)
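To make the quoted generative process concrete, here is a minimal NumPy sketch that draws the kernel and bias once from a prior, computes M, and then draws outputs from a likelihood centered at M. The standard-normal prior, the Gaussian likelihood with noise_scale = 0.1, and the shapes are assumptions for illustration only, since the documentation leaves Prior and Likelihood abstract. In this reading, Likelihood(M) is simply the conditional distribution of the targets Y given the layer output M.

```python
import numpy as np

rng = np.random.default_rng(0)

n, d_in, d_out = 5, 3, 2
X = rng.normal(size=(n, d_in))  # inputs

# [K, b] ~ Prior(): assumed independent standard normals for illustration.
K = rng.normal(size=(d_in, d_out))
b = rng.normal(size=(d_out,))

# M = matmul(X, K) + b
M = X @ K + b

# Y ~ Likelihood(M): assumed Gaussian noise around M with a fixed scale.
noise_scale = 0.1
Y = rng.normal(loc=M, scale=noise_scale)

print(M.shape, Y.shape)  # (5, 2) (5, 2)
```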

How is this different from variational inference with the re-parametrization trick? Does it mean that, in the case of tfp.layers.DenseVariational, you directly sample from the prior? Wouldn't this lead to high variance during training? Moreover, shouldn't you actually sample from the posterior (rather than the prior)? Finally, what does "likelihood" in Y ~ Likelihood(M) mean?

Can you point me to the paper or implementation that you based your implementation on?

srvasude commented 4 years ago

@jvdillon, who might be able to provide more context.

I don't think these should be different. I think this arose because DenseVariational is a newer layer that conforms to some of the other Distribution Layers, while DenseReparameterization, DenseLocalReparameterization, and friends were written earlier and so may not conform in the same way. The layer itself should use the reparameterization trick (as we do automatically in most TFP code).

The action items here would be to consolidate this list of layers and perhaps update some of the older layers (like DenseLocalReparameterization).
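For readers unfamiliar with the trick referred to here, a minimal NumPy sketch of the (global) reparameterization trick from Auto-Encoding Variational Bayes: a weight sample is written as w = mu + sigma * eps with eps ~ N(0, 1), so Monte Carlo gradients of an expectation can flow through mu and sigma via the chain rule. The toy objective f(w) = w**2 and all numbers are assumptions chosen so the exact gradients are known; this is not TFP's code.

```python
import numpy as np

rng = np.random.default_rng(0)

# Variational parameters of a single scalar weight w ~ N(mu, sigma^2).
mu, sigma = 1.5, 0.5

# Reparameterization: w = mu + sigma * eps, with eps ~ N(0, 1) independent
# of (mu, sigma), so w is a differentiable function of the parameters.
eps = rng.standard_normal(200_000)
w = mu + sigma * eps

# Toy objective f(w) = w**2, so E[f(w)] = mu**2 + sigma**2 and the exact
# gradients are d/dmu = 2*mu and d/dsigma = 2*sigma.
df_dw = 2.0 * w

# Pathwise (reparameterization) gradient estimates: dw/dmu = 1, dw/dsigma = eps.
grad_mu = np.mean(df_dw * 1.0)
grad_sigma = np.mean(df_dw * eps)

print(grad_mu, 2 * mu)        # estimate vs. exact 3.0
print(grad_sigma, 2 * sigma)  # estimate vs. exact 1.0
```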

nbro commented 4 years ago

@srvasude If you look at the source code of these classes, you will see that both DenseReparameterization and DenseLocalReparameterization derive from a base class called _DenseVariational. They only override the method _apply_variational_kernel.

I haven't yet looked at the details of _DenseVariational and DenseVariational.

hartikainen commented 4 years ago

@nbro did you ever figure out the difference between these two? I'm a little confused about them too.

nbro commented 4 years ago

@hartikainen DenseReparameterization implements the forward pass by simply drawing a sample of the kernel from the posterior and using that sample to compute the output of the layer, without any reparametrization. I don't know why it's called DenseReparameterization, given that, as far as I understand, it doesn't really apply the re-parametrization trick.

DenseLocalReparameterization performs a slightly different sampling operation. You can find the details here.

In any case, this is the only difference between the two: how they sample the kernel during the forward pass (i.e., when call is invoked).
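A simplified, hypothetical mock of the structure described in this thread may help: a shared base class whose forward pass delegates only the kernel step to _apply_variational_kernel, with two subclasses that differ solely in how they sample. The class names with a Sketch suffix, the factorized Gaussian posterior, and the omission of bias and KL terms are all assumptions; this is not the TFP source. The reparameterized variant draws one kernel for the whole batch, while the local variant samples the pre-activations from their Gaussian marginal (the idea in Variational Dropout and the Local Reparameterization Trick), so both should produce the same output distribution for a single input.

```python
import numpy as np


class _DenseVariationalSketch:
    """Hypothetical, simplified stand-in for the shared base class (no bias, no KL)."""

    def __init__(self, kernel_loc, kernel_scale, seed=0):
        self.kernel_loc = kernel_loc      # posterior mean of each kernel entry
        self.kernel_scale = kernel_scale  # posterior std of each kernel entry
        self.rng = np.random.default_rng(seed)

    def call(self, inputs):
        # The forward pass is shared; only the kernel step differs per subclass.
        return self._apply_variational_kernel(inputs)

    def _apply_variational_kernel(self, inputs):
        raise NotImplementedError


class DenseReparameterizationSketch(_DenseVariationalSketch):
    def _apply_variational_kernel(self, inputs):
        # Draw one kernel sample (loc + scale * eps), shared by the whole batch.
        eps = self.rng.standard_normal(self.kernel_loc.shape)
        return inputs @ (self.kernel_loc + self.kernel_scale * eps)


class DenseLocalReparameterizationSketch(_DenseVariationalSketch):
    def _apply_variational_kernel(self, inputs):
        # Sample the pre-activations from their Gaussian marginal instead,
        # which gives each example in the batch its own noise.
        mean = inputs @ self.kernel_loc
        std = np.sqrt((inputs ** 2) @ (self.kernel_scale ** 2))
        return mean + std * self.rng.standard_normal(mean.shape)


x = np.array([[1.0, 2.0, -1.0]])
loc = np.array([[0.5], [-0.2], [0.1]])
scale = np.array([[0.3], [0.1], [0.2]])

rep = DenseReparameterizationSketch(loc, scale, seed=1)
lrep = DenseLocalReparameterizationSketch(loc, scale, seed=2)
a = np.array([rep.call(x)[0, 0] for _ in range(20_000)])
b = np.array([lrep.call(x)[0, 0] for _ in range(20_000)])

# Analytic output distribution for this x: mean 0.0, variance 0.17.
print(a.mean(), a.var())
print(b.mean(), b.var())
```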

hartikainen commented 4 years ago

Thanks @nbro, I think that makes sense! Do I understand correctly that the DenseVariational (not _DenseVariational) works the same way as DenseReparameterization, as hinted by @srvasude above?

nbro commented 4 years ago

@hartikainen Oops, I had forgotten about DenseVariational.

DenseVariational is slightly different from the others, in terms of API and functionality. Here are the differences

In general, it seems to me that DenseVariational was created to meet the needs of one or two programmers who were tired of weighting the KL divergence by overriding the KL divergence function. I don't really understand why they didn't simply redesign all the APIs and improve the classes. They are quite limited and can definitely be improved.
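For context on "weighting the KL divergence": the training loss is the negative ELBO, and when the likelihood term is averaged over a mini-batch, the KL term is commonly scaled by 1/N, with N the number of training examples. As far as I can tell from the TFP API docs, DenseVariational exposes this through a kl_weight argument, whereas the older layers expect a custom kernel_divergence_fn that applies the scaling. The NumPy sketch below only illustrates the arithmetic; the helper names, the standard-normal prior, the dataset size, and the NLL values are assumptions.

```python
import numpy as np


def kl_normal(mu_q, sigma_q, mu_p=0.0, sigma_p=1.0):
    """Closed-form KL(N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2)), summed over weights."""
    return np.sum(
        np.log(sigma_p / sigma_q)
        + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2.0 * sigma_p ** 2)
        - 0.5
    )


def negative_elbo(nll_per_example, mu_q, sigma_q, kl_weight):
    # Mean negative log-likelihood over the batch plus a weighted KL penalty.
    return np.mean(nll_per_example) + kl_weight * kl_normal(mu_q, sigma_q)


num_train_examples = 60_000         # assumed dataset size N
mu_q = np.array([0.1, -0.3])        # hypothetical posterior means
sigma_q = np.array([0.5, 0.2])      # hypothetical posterior stds
nll = np.array([0.7, 1.2, 0.4])     # hypothetical per-example NLL values

# Scaling the KL by 1/N keeps the per-example loss consistent with the full ELBO / N.
loss = negative_elbo(nll, mu_q, sigma_q, kl_weight=1.0 / num_train_examples)
print(loss)
```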

hartikainen commented 4 years ago

Awesome, now it makes a lot more sense. It feels like these classes should be better documented, but for now, hopefully others who are confused can find their way to your answers. Thanks again @nbro!

cserpell commented 4 years ago

Does it mean that, in the case of tfp.layers.DenseVariational, you directly sample from the prior?

Looking at the code, they sample directly from the posterior, so probably it was only for explanation purposes.

huckiyang commented 3 years ago

Just an add-on: I think tfp.layers.DenseVariational was trying to assign uncertain weights as in Weight Uncertainty in Neural Networks. To write a pure Bayesian NN, we may refer to this tutorial.

I just found out that both DenseFlipout and DenseReparameterization belong to the older layers API, as discussed in https://github.com/tensorflow/probability/issues/359
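Since DenseFlipout comes up here, a minimal NumPy sketch of the Flipout estimator (from the Flipout paper cited in the question) may help: one weight perturbation is sampled for the whole batch and then decorrelated across examples by multiplying with per-example random sign vectors, without materializing a kernel per example. The shapes, the Gaussian posterior, and the scale of 0.1 are assumptions; this is not DenseFlipout's source.

```python
import numpy as np

rng = np.random.default_rng(0)

batch, d_in, d_out = 4, 3, 2
x = rng.normal(size=(batch, d_in))

# Assumed factorized Gaussian posterior over the kernel.
kernel_loc = rng.normal(size=(d_in, d_out))
kernel_scale = np.full((d_in, d_out), 0.1)

# One shared perturbation for the whole batch.
delta = kernel_scale * rng.standard_normal((d_in, d_out))

# Per-example random sign vectors with entries in {-1, +1}.
sign_in = rng.choice([-1.0, 1.0], size=(batch, d_in))
sign_out = rng.choice([-1.0, 1.0], size=(batch, d_out))

# Flipout: example n effectively uses kernel_loc + delta * outer(sign_in[n], sign_out[n]),
# but only batched matrix products are needed.
outputs = x @ kernel_loc + ((x * sign_in) @ delta) * sign_out

# Sanity check against explicitly built per-example kernels.
explicit = np.stack([
    x[n] @ (kernel_loc + delta * np.outer(sign_in[n], sign_out[n]))
    for n in range(batch)
])
print(np.allclose(outputs, explicit))
```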