**Open** · malharjajoo opened this issue 4 years ago
@malharjajoo This question has been asked before in this issue tracker. See, for example, the issue I had opened: https://github.com/tensorflow/probability/issues/651. In the paper you mention, the authors actually mention more than one way of weighting the KL term of the ELBO.
> KL divergence is re-weighted using size of batch
Actually, in that example, the KL divergence is divided by the total number of training examples (and not the batch size).
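To make the distinction concrete, here is a minimal sketch (plain Python, with made-up numbers, not the actual example code) of why the KL term is divided by the total number of training examples `N` rather than the batch size: each example then carries `kl / N`, so summing the per-batch losses over a full epoch counts the KL exactly once.

```python
# Hypothetical values for illustration only.
N = 60000            # total number of training examples (e.g. MNIST train set)
B = 100              # batch size
num_batches = N // B # 600 batches per epoch

kl = 7.5             # made-up KL(q(w) || p(w)) for the weight posterior
nll_per_batch = 2.0  # assume a constant summed negative log-likelihood per batch

# Per-batch loss: summed NLL over the batch plus the KL scaled by B / N,
# i.e. kl / N per example.
epoch_loss = sum(nll_per_batch + (B / N) * kl for _ in range(num_batches))

total_nll = num_batches * nll_per_batch
# Over the epoch, the KL contributes exactly once to the total loss.
print(abs(epoch_loss - (total_nll + kl)) < 1e-9)  # True
```

If the KL were instead divided by the batch size `B`, it would be counted `num_batches * B / B = num_batches` times per epoch, which is a different objective.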
@nbro, thanks for that. I can see that the KL re-weighting part is actually unclear, even for others.
Also, I think the KL is being divided by the batch size, because that code builds a batched data pipeline using the tf.data API. When I stepped through the line of code I mentioned earlier with a debugger, `mnist_data.train.num_examples` turned out to be 128 (the default batch size in their code).
I haven't checked it (or maybe I checked it and don't remember anymore, given it was a few weeks ago), but are you sure? I recommend you double-check, because maybe you're looking at another variable. `mnist_data.train.num_examples` is a property of `train`, which is a member of the object `mnist_data`, which is returned by `mnist.read_data_sets`. `mnist_data.train.num_examples` shouldn't be a placeholder or something that is filled during training or when the model is compiled (but I could be wrong, given that I haven't checked the implementation). What makes you think that the input pipeline is modifying `mnist_data.train.num_examples`? Anyway, I don't exclude the possibility that it is modified under the hood for some reason I am not aware of.
Hi,
It seems that there is a discrepancy in how the KL re-weighting is done during stochastic/batched training.

In section 3.4 of this paper: Weight Uncertainty in Neural Networks, the KL divergence is re-weighted using the number of batches, but in the BNN example code here, the KL divergence is re-weighted using the batch size.

Does this seem like a bug?
Thanks
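For what it's worth, the two conventions can coincide. A minimal sketch (plain Python, made-up KL value): with uniform minibatch weights π_i = 1/M, as allowed by section 3.4 of the paper, dividing the KL by the number of batches `M` once per batch is equivalent to dividing it by the dataset size `N` once per example, since `N = M * B`.

```python
import math

# Hypothetical values for illustration only.
N = 60000   # dataset size
B = 100     # batch size
M = N // B  # number of batches (600)

kl = 12.0   # made-up KL(q(w) || p(w)) value

per_batch_weighting = kl / M          # paper: each minibatch carries KL / M
per_example_weighting = B * (kl / N)  # example code: KL / N per example, summed over a batch

print(math.isclose(per_batch_weighting, per_example_weighting))  # True when N = M * B
```

The two only diverge if the per-batch loss averages (rather than sums) the likelihood term over the batch, or if `N` is not an exact multiple of `B`.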