tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Why does only DenseVariational provide the kl_weight parameter? #691

Open nbro opened 4 years ago

nbro commented 4 years ago

The DenseVariational layer provides a kl_weight parameter, whose documentation says:

kl_weight: Amount by which to scale the KL divergence loss between prior and posterior.

In the paper "Weight Uncertainty in Neural Networks", equation 9,

$$\mathcal{F}_i^{\pi}(\mathcal{D}_i, \theta) = \pi_i \, \mathrm{KL}\big[q(\mathbf{w} \mid \theta) \,\|\, P(\mathbf{w})\big] - \mathbb{E}_{q(\mathbf{w} \mid \theta)}\big[\log P(\mathcal{D}_i \mid \mathbf{w})\big],$$

includes a scaling factor pi_i that weights the complexity cost (the KL term) relative to the likelihood cost on each mini-batch. The authors of that paper say there are different ways of choosing this scaling factor, and they show two: uniform scaling, pi_i = 1/M, and the geometric schedule pi_i = 2^(M-i) / (2^M - 1), where M is the number of mini-batches in an epoch and i is the index of the i-th mini-batch. In both cases, the scaling factor depends on the number of mini-batches.
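
For concreteness, here is a small sketch of the two schedules (the value of M and the variable names are mine, not from the paper). The point is that both schedules distribute a total weight of 1 over an epoch, so the full KL term is counted exactly once per epoch:

```python
import numpy as np

M = 10  # number of mini-batches in an epoch (arbitrary example value)

# Uniform scaling: every mini-batch gets the same share of the KL term.
pi_uniform = np.full(M, 1.0 / M)

# Geometric scaling from the paper: early mini-batches get most of the weight.
pi_geometric = np.array([2.0 ** (M - i) / (2.0 ** M - 1.0) for i in range(1, M + 1)])

# Both schedules sum to 1, so over a full epoch the KL term is counted once.
assert np.isclose(pi_uniform.sum(), 1.0)
assert np.isclose(pi_geometric.sum(), 1.0)
```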

The need for this scaling factor comes from the assumption that the KL term in the ELBO is computed over all training instances, while, in mini-batch stochastic gradient descent, we do not use all training instances, but only a subset of them. Therefore, given that the KL term is added to the loss at every step of gradient descent (so at every mini-batch), we need to scale it down so that, over an epoch, it is not counted M times (once per mini-batch) but only once. In issue https://github.com/tensorflow/probability/issues/651, I asked how the KL term should be scaled, given that some people suggest scaling it by the total number of training instances.
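
To make the question concrete, this is roughly how I understand kl_weight would be used. It is only a sketch based on the docstring quoted above and on the posterior/prior functions from the TFP regression example; num_train_examples, num_batches and the choice of 1/num_train_examples below are my own illustrative assumptions, not an official recommendation:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

num_train_examples = 1000              # illustrative value
num_batches = num_train_examples // 32 # illustrative value

def posterior_mean_field(kernel_size, bias_size=0, dtype=None):
    """Mean-field normal posterior over kernel and bias."""
    n = kernel_size + bias_size
    return tf.keras.Sequential([
        tfp.layers.VariableLayer(2 * n, dtype=dtype),
        tfp.layers.DistributionLambda(lambda t: tfd.Independent(
            tfd.Normal(loc=t[..., :n],
                       scale=1e-5 + tf.nn.softplus(t[..., n:])),
            reinterpreted_batch_ndims=1)),
    ])

def prior_trainable(kernel_size, bias_size=0, dtype=None):
    """Trainable normal prior, as in the TFP regression example."""
    n = kernel_size + bias_size
    return tf.keras.Sequential([
        tfp.layers.VariableLayer(n, dtype=dtype),
        tfp.layers.DistributionLambda(lambda t: tfd.Independent(
            tfd.Normal(loc=t, scale=1.0),
            reinterpreted_batch_ndims=1)),
    ])

# As far as I can tell from the docstring, kl_weight just scales the KL term
# that the layer adds to the model losses. If Keras averages the negative
# log-likelihood over examples, 1/num_train_examples seems to be the intended
# choice; with a summed per-batch likelihood, 1/num_batches (i.e. uniform
# pi_i = 1/M) would play the role of pi in equation 9.
layer = tfp.layers.DenseVariational(
    units=1,
    make_posterior_fn=posterior_mean_field,
    make_prior_fn=prior_trainable,
    kl_weight=1.0 / num_train_examples)
```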

Questions

  1. Does the parameter kl_weight of DenseVariational correspond to pi in equation 9 of the paper mentioned above? In other words, is it the number that multiplies the KL term at each training iteration (i.e. for each mini-batch)? If not, what does it represent?

  2. Why does only DenseVariational provide a way of specifying kl_weight? Why don't other variational layers, such as tfp.layers.DenseReparameterization, tfp.layers.DenseLocalReparameterization, etc., also provide such a parameter? I know I can pass the function kernel_divergence_fn, so it is possible to scale the KL divergence term in those layers as well, but why not also expose a kl_weight parameter, which I assume (for this second question) corresponds to pi in equation 9 of the paper mentioned above?

bgroenks96 commented 4 years ago

Agreed. It's a bit clunky to have to redefine the KL divergence function to apply the weight.
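
For reference, the workaround looks something like this (just a sketch; num_batches stands for whatever scaling one decides to use):

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

num_batches = 100  # e.g. number of mini-batches per epoch (your choice of scaling)

# Workaround: override the default kernel_divergence_fn so the KL term is
# scaled before being added to the layer's losses.
layer = tfp.layers.DenseReparameterization(
    units=64,
    activation='relu',
    kernel_divergence_fn=lambda q, p, _: tfd.kl_divergence(q, p) / num_batches)
```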

nbro commented 4 years ago

Apparently, DenseVariational was written after all the other Bayesian layers (at least, according to one of the authors). See https://github.com/tensorflow/probability/issues/727#issuecomment-617710304 and https://github.com/tensorflow/probability/issues/409#issuecomment-492870964.