nbro opened this issue 4 years ago
@jvdillon might be able to provide more context.
I don't think these should be different. I think the way this arose is that DenseVariational is a more recently written layer that conforms with some of the other Distribution Layers, while DenseReparameterization, DenseLocalReparameterization and friends were written further back, so they may not conform in the same way. The layer itself should use the reparameterization trick (as we do automatically in most TFP code).
The action items here would be to consolidate this list of layers, as well as maybe update some of the older layers (like DenseLocalReparameterization).
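For concreteness, here is a minimal sketch (mine, not taken from the TFP layers) of what "automatically in most TFP code" refers to: samples from tfd.Normal are fully reparameterized, so gradients flow from a sampled value back to the distribution's parameters.

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

loc = tf.Variable(0.0)
raw_scale = tf.Variable(0.0)

with tf.GradientTape() as tape:
    dist = tfd.Normal(loc=loc, scale=tf.nn.softplus(raw_scale))
    sample = dist.sample()  # internally loc + scale * eps, with eps ~ N(0, 1)
    loss = tf.square(sample)

print(dist.reparameterization_type)            # FULLY_REPARAMETERIZED
print(tape.gradient(loss, [loc, raw_scale]))   # both gradients are non-None
```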
@srvasude If you look at the source code of these classes, you will see that both DenseReparameterization and DenseLocalReparameterization derive from a base class called _DenseVariational. They only override the method _apply_variational_kernel.
I haven't yet looked at the details of _DenseVariational and DenseVariational.
@nbro did you ever figure out the difference between these two? I'm a little confused about them too.
@hartikainen DenseReparameterization implements the forward pass by simply sampling from the posterior, without any reparametrization, i.e. it directly samples from the posterior distribution and uses this sample to compute the output of the layer. I don't know why it's called DenseReparameterization, given that it's not really doing any re-parametrization trick, as far as I understand.
DenseLocalReparameterization performs a slightly different sampling operation. You can find the details here.
In any case, this is the only difference between the two, i.e. how they sample the kernel during the forward pass (i.e. when call is called).
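To make that concrete, here is a simplified sketch of the two sampling strategies. This is not the actual TFP source (the real layers route this through _apply_variational_kernel); it just assumes a factorized Normal posterior with parameters loc and scale of shape [in_dim, out_dim]:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def apply_kernel_reparameterization(inputs, loc, scale):
    # DenseReparameterization-style: draw one kernel sample from the
    # posterior, then use it like an ordinary weight matrix.
    kernel = tfd.Normal(loc=loc, scale=scale).sample()
    return tf.matmul(inputs, kernel)

def apply_kernel_local_reparameterization(inputs, loc, scale):
    # DenseLocalReparameterization-style: sample the pre-activations from
    # the Gaussian they follow under the factorized posterior, which gives
    # lower-variance gradient estimates.
    out_loc = tf.matmul(inputs, loc)
    out_scale = tf.sqrt(tf.matmul(tf.square(inputs), tf.square(scale)))
    return tfd.Normal(loc=out_loc, scale=out_scale).sample()
```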
Thanks @nbro, I think that makes sense! Do I understand correctly that DenseVariational (not _DenseVariational) works the same way as DenseReparameterization, as hinted by @srvasude above?
@hartikainen Oops, I had forgotten about DenseVariational.
DenseVariational is slightly different from the others, in terms of API and functionality. Here are the differences:

- DenseVariational allows you to weight the KL divergence by specifying this weight as a parameter in the constructor of the class (i.e. __init__), while all other Bayesian layers do not allow you to do this, unless you override the function that computes the KL divergence.
- DenseVariational does not allow you to treat the bias in the same way as the kernel, which means that you can't really compute the KL divergence between the prior and posterior biases. By default, DenseReparameterization doesn't do this anyway, but you can make it do so.
- In DenseVariational, the functions that build the posterior and the prior are not provided by default; you have to supply them yourself. This is really weird.

In general, it seems to me that DenseVariational was created to fulfill the needs of 1-2 programmers who were tired of weighting the KL divergence by overriding the KL divergence function. I don't really understand why they didn't simply redesign all the APIs and improve the classes. They are quite limited and can definitely be improved.
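For completeness, the usage pattern DenseVariational expects looks roughly like this (a sketch based on the TFP regression tutorial; posterior_mean_field, prior_trainable and num_train_examples are names chosen here for illustration). The KL weight is passed in the constructor, and the posterior and prior are supplied as callables:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def posterior_mean_field(kernel_size, bias_size=0, dtype=None):
    # Trainable mean-field Normal posterior over the flattened kernel + bias.
    n = kernel_size + bias_size
    return tf.keras.Sequential([
        tfp.layers.VariableLayer(2 * n, dtype=dtype),
        tfp.layers.DistributionLambda(lambda t: tfd.Independent(
            tfd.Normal(loc=t[..., :n],
                       scale=1e-5 + tf.nn.softplus(0.05 * t[..., n:])),
            reinterpreted_batch_ndims=1)),
    ])

def prior_trainable(kernel_size, bias_size=0, dtype=None):
    # Prior with a trainable location and fixed unit scale.
    n = kernel_size + bias_size
    return tf.keras.Sequential([
        tfp.layers.VariableLayer(n, dtype=dtype),
        tfp.layers.DistributionLambda(lambda t: tfd.Independent(
            tfd.Normal(loc=t, scale=1.0),
            reinterpreted_batch_ndims=1)),
    ])

num_train_examples = 1000  # placeholder: size of your training set

layer = tfp.layers.DenseVariational(
    units=1,
    make_posterior_fn=posterior_mean_field,
    make_prior_fn=prior_trainable,
    # The constructor-level KL weighting discussed above: the KL term is
    # added to the layer's losses scaled by this factor.
    kl_weight=1.0 / num_train_examples,
)
```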
Awesome, now it makes a lot more sense. Feels like these classes should be documented better, but for now, hopefully others who are confused can just find their way to your answers. Thanks again @nbro!
Does it mean that, in the case of tfp.layers.DenseVariational, you directly sample from the prior?
Looking at the code, they sample directly from the posterior, so probably it was only for explanation purposes.
Just an add-on: I think tfp.layers.DenseVariational was trying to assign uncertain weights as in Weight Uncertainty in Neural Networks. To write a pure Bayesian NN, we may refer to this tutorial.
Just found out that both DenseFlipout and DenseReparameterization belong to the older layers API, as discussed in https://github.com/tensorflow/probability/issues/359.
The (global) re-parametrization trick is described in the paper Auto-Encoding Variational Bayes. The local re-parametrization trick is described in Variational Dropout and the Local Reparameterization Trick. The flipout estimator is described in Flipout: Efficient Pseudo-Independent Weight Perturbations on Mini-Batches.
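As a rough sketch of the three estimators (my summary, not quoted from the papers), for an input row x and a factorized Gaussian posterior N(mu_ij, sigma_ij^2) over each kernel entry W_ij:

```latex
% Global reparameterization: sample the kernel once, then compute the
% pre-activation deterministically.
W = \mu + \sigma \odot \epsilon, \quad \epsilon_{ij} \sim \mathcal{N}(0, 1),
\qquad b_j = \textstyle\sum_i x_i W_{ij}.

% Local reparameterization: sample the pre-activation directly from the
% Gaussian it follows under the factorized posterior (lower-variance
% gradients, since the noise is drawn per example rather than per weight).
b_j \sim \mathcal{N}\!\Big(\textstyle\sum_i x_i \mu_{ij},\ \textstyle\sum_i x_i^2 \sigma_{ij}^2\Big).

% Flipout (roughly): share one base perturbation \Delta\hat{W} = \sigma \odot \epsilon
% across the mini-batch, and decorrelate examples with random sign vectors
% s (input-sized) and r (output-sized):
b = \mu^{\top} x + r \odot \big(\Delta\hat{W}^{\top}(x \odot s)\big),
\qquad s_i, r_j \in \{-1, +1\}.
```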
The documentation of tfp.layers.DenseVariational says [...]
How is this different than variational inference with the re-parametrization trick? Does it mean that, in the case of tfp.layers.DenseVariational, you directly sample from the prior? Shouldn't this create high variance during training? Moreover, shouldn't you actually sample from the posterior (rather than the prior)? Finally, what does "likelihood" in Y ~ Likelihood(M) mean?
Can you point me to the paper or implementation that you based your implementation on?