tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

KL-reweighting not done correctly in BNN code ? #719

Open malharjajoo opened 4 years ago

malharjajoo commented 4 years ago

Hi,

It seems that there is some discrepancy in how the KL-reweighting is done during stochastic/batched training.

In section 3.4 of this paper: Weight Uncertainty in Neural Networks, the KL divergence is re-weighted using the number of batches, but in the BNN example code here, the KL divergence is re-weighted using the batch size.

Does this seem like a bug?

Thanks
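For context on the two schemes being contrasted: section 3.4 of Blundell et al. (2015) actually proposes per-minibatch KL weights π_i that sum to 1 over an epoch, and gives both a uniform scheme (each of the M minibatches carries KL / M) and a geometric one. A minimal pure-Python sketch of both (the function names here are illustrative, not from the TFP example code):

```python
# Minibatch KL-weighting schemes from Blundell et al. (2015), Sec. 3.4.
# M is the number of minibatches per epoch; the weights pi_i must sum to 1
# so the KL term is counted exactly once per pass over the data.

def uniform_weights(num_batches):
    """Uniform scheme: every minibatch carries KL / M."""
    return [1.0 / num_batches] * num_batches

def geometric_weights(num_batches):
    """pi_i = 2^(M-i) / (2^M - 1), i = 1..M: the first minibatches
    carry most of the KL penalty."""
    M = num_batches
    return [2 ** (M - i) / (2 ** M - 1) for i in range(1, M + 1)]

print(sum(uniform_weights(8)))    # 1.0
print(sum(geometric_weights(8)))  # 1.0
```

Both schemes distribute the same total KL across an epoch; they differ only in how it is apportioned across minibatches.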

nbro commented 4 years ago

@malharjajoo This question has been asked before in this issue tracker. See, for example, the issue I had opened: https://github.com/tensorflow/probability/issues/651. In the paper you mention, the authors actually mention more than one way of weighting the KL term of the ELBO.

KL divergence is re-weighted using the batch size

Actually, in that example, the KL divergence is divided by the total number of training examples (and not the batch size).
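To see why dividing the KL by the total number of training examples N is the conventional choice, note that it makes the summed per-example losses over one epoch recover the full negative ELBO with the KL counted exactly once. A sketch with illustrative numbers (N, batch size, and the KL value are assumptions for the sketch, not taken from the example):

```python
# If each minibatch loss is  mean_NLL + KL / N  (N = total training
# examples, not the batch size), then the KL contributions summed over
# one epoch add up to the full KL exactly once.

N = 60000          # assumed training-set size (e.g. MNIST)
batch_size = 128
kl = 42.0          # stand-in value for KL(q || p) over all weights

num_batches = N / batch_size            # fractional is fine for the sketch
per_batch_kl = (kl / N) * batch_size    # KL carried by one minibatch
total_kl = per_batch_kl * num_batches   # accumulated over one epoch

print(total_kl)  # 42.0
```

This is equivalent to the uniform π_i = 1/M weighting in the paper, just expressed per example instead of per batch.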

malharjajoo commented 4 years ago

@nbro, thanks for that. I can see that the KL re-weighting is unclear to others as well.

Also, I think the KL is being divided by the batch size, because that code builds a batched data pipeline using the tf.data API, and when I stepped through the line of code I mentioned earlier with a debugger, mnist_data.train.num_examples turns out to be 128 (the default batch size in their code).

nbro commented 4 years ago

because that code builds a batched data pipeline using the tf.data API, and when I stepped through the line of code I mentioned earlier with a debugger, mnist_data.train.num_examples turns out to be 128 (the default batch size in their code).

I haven't checked it (or maybe I checked it and no longer remember, since it was a few weeks ago), but are you sure? I recommend you double-check, because you might be looking at another variable. mnist_data.train.num_examples is a property of train, which is a member of the object mnist_data returned by mnist.read_data_sets.

https://github.com/tensorflow/probability/blob/563cdd7bb46e3ce0bd069c713c581e7409f26de1/tensorflow_probability/examples/bayesian_neural_network.py#L224

mnist_data.train.num_examples shouldn't be a placeholder or something filled in during training or when the model is compiled (but I could be wrong, given that I haven't checked the implementation). What makes you think the input pipeline is modifying mnist_data.train.num_examples? In any case, I don't exclude the possibility that it is modified under the hood for some reason I'm not aware of.
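Whichever value num_examples actually holds, the distinction matters numerically: dividing the KL by the batch size instead of the dataset size would inflate the KL penalty by a factor of N / batch_size, so the prior would dominate the loss. A quick sketch with assumed numbers (60,000 is the usual MNIST training-set size, not a value read from the repo):

```python
# Effect of dividing the KL by the batch size vs. the dataset size.
# The ratio of the two scalings is N / batch_size.

N = 60000        # assumed MNIST training-set size
batch_size = 128

correct_scale = 1.0 / N             # KL / num_examples, as intended
suspected_scale = 1.0 / batch_size  # KL / batch_size, as the issue suspects

inflation = suspected_scale / correct_scale
print(round(inflation, 2))  # 468.75 -> KL over-weighted by ~469x
```

So if num_examples really were 128 at runtime, the model would be heavily over-regularized, which would likely be visible as underfitting; that could be one empirical way to tell the two cases apart.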