nbro opened this issue 4 years ago
Agreed. It's a bit clunky to have to redefine the KL divergence function to apply the weight.
Apparently, `DenseVariational` was written later than all the other Bayesian layers (at least, according to one of the authors). See https://github.com/tensorflow/probability/issues/727#issuecomment-617710304 and https://github.com/tensorflow/probability/issues/409#issuecomment-492870964.
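For the layers that do not expose a `kl_weight` argument, the workaround mentioned above (redefining the divergence function to apply the weight) looks roughly like the following sketch; the dataset size and batch size are made-up numbers, and the scaling by the number of mini-batches is only one of the possible choices discussed in this issue:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

num_train_examples = 60000  # hypothetical dataset size
batch_size = 128
num_batches = num_train_examples // batch_size  # M in the paper's notation

# DenseReparameterization has no kl_weight argument, so the KL term can only
# be scaled by overriding kernel_divergence_fn (and bias_divergence_fn, if the
# bias is also stochastic) with a scaled version of the default.
scaled_kl_fn = lambda q, p, _: tfd.kl_divergence(q, p) / num_batches

layer = tfp.layers.DenseReparameterization(
    units=10,
    activation='relu',
    kernel_divergence_fn=scaled_kl_fn,
)
```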
The `DenseVariational` layer provides the `kl_weight` parameter, whose documentation is

> Amount by which to scale the KL divergence loss between prior and posterior.

In the paper "Weight Uncertainty in Neural Networks", equation 9 includes a scaling factor (`pi`) that weights the complexity cost (the KL term) relative to the likelihood cost on each mini-batch. The authors of that paper say that there are different ways of choosing this scaling factor. They show two: `pi_i = 1/M` (uniform scaling) and `pi_i = 2^(M-i) / (2^M - 1)`, where `M` is the number of mini-batches in an epoch and `i` is the index of the `i`-th mini-batch. In both cases, the scaling factor depends on the number of mini-batches.

The need for this scaling factor comes from the assumption that the KL term in the ELBO is computed over all training instances, while, in mini-batch stochastic gradient descent, we only use a subset of them. Given that the KL term is added to the loss at every step of gradient descent (so at every mini-batch), we need to scale the KL term in proportion to the number of mini-batches. In https://github.com/tensorflow/probability/issues/651, I wonder how we should scale the KL term, given that some people suggest scaling it by the total number of training instances instead.
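As a small numerical check of the two schemes (with a hypothetical `M`), both choices of `pi_i` sum to 1 over an epoch, so both distribute the full KL term across the mini-batches of one epoch, just with different per-batch weights:

```python
import numpy as np

M = 8  # number of mini-batches in an epoch (hypothetical)
i = np.arange(1, M + 1)  # mini-batch index, 1-based as in the paper

pi_uniform = np.full(M, 1.0 / M)               # pi_i = 1/M
pi_blundell = 2.0 ** (M - i) / (2.0 ** M - 1)  # pi_i = 2^(M-i) / (2^M - 1)

print(pi_uniform.sum())   # 1.0
print(pi_blundell.sum())  # 1.0
# The second scheme front-loads the KL penalty on the first mini-batches:
print(pi_blundell[:3])    # largest weights at i = 1, 2, 3
```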
Questions
1. Does the parameter `kl_weight` of `DenseVariational` correspond to `pi` in equation 9 of the paper mentioned above? In other words, is it the number that multiplies the KL term at each training iteration (i.e. each mini-batch)? If not, what does it represent?
2. Why does only `DenseVariational` provide a way of specifying `kl_weight`? Why don't other variational layers, such as `tfp.layers.DenseReparameterization`, `tfp.layers.DenseLocalReparameterization`, etc., also provide such a parameter? I know I can provide the function `kernel_divergence_fn`, so there is a way to scale the KL divergence term in these cases too, but why don't you provide the possibility to specify the parameter `kl_weight`, which I assume, in this second question, corresponds to `pi` in equation 9 of the paper mentioned above?
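For reference, this is roughly what using `kl_weight` with `DenseVariational` looks like. The posterior/prior constructors follow the pattern from the TFP documentation, and the value `1.0 / num_batches` passed below is only my assumption (i.e. it presumes the answer to question 1 is "yes"), which is exactly what I am asking to confirm:

```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Mean-field Gaussian posterior, as in the TFP documentation examples.
def posterior_mean_field(kernel_size, bias_size=0, dtype=None):
    n = kernel_size + bias_size
    c = np.log(np.expm1(1.0))
    return tf.keras.Sequential([
        tfp.layers.VariableLayer(2 * n, dtype=dtype),
        tfp.layers.DistributionLambda(lambda t: tfd.Independent(
            tfd.Normal(loc=t[..., :n],
                       scale=1e-5 + tf.nn.softplus(c + t[..., n:])),
            reinterpreted_batch_ndims=1)),
    ])

# Fixed standard-normal prior (non-trainable).
def prior_fixed(kernel_size, bias_size=0, dtype=None):
    n = kernel_size + bias_size
    return lambda _: tfd.Independent(
        tfd.Normal(loc=tf.zeros(n, dtype=dtype), scale=1.0),
        reinterpreted_batch_ndims=1)

num_batches = 100  # hypothetical M, the number of mini-batches per epoch

layer = tfp.layers.DenseVariational(
    units=1,
    make_posterior_fn=posterior_mean_field,
    make_prior_fn=prior_fixed,
    # The value passed here multiplies the KL term added to the loss; whether
    # it should be 1/M (mini-batches) or 1/N (training examples) is the point
    # of this issue.
    kl_weight=1.0 / num_batches,
)
```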