[Closed] aishgupta closed this issue 4 years ago
I had a few questions about this too. Is the 0.5 there because the out batch size is 2x the in batch size?
We take an average, so that loss term doesn't depend on the batch size. The weight 0.5 is just the hyperparameter we chose early in experimentation and we stuck with it.
Ah right, that makes sense. Thanks!
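For context, the surrounding training step in `oe_scratch.py` looks roughly like the sketch below. Only the OE line quoted later in this thread is verbatim; the in-distribution cross-entropy line and the variable names around it are assumptions for illustration.

```python
# x: logits for the concatenated batch [in-distribution; outliers],
# so x[len(in_set[0]):] selects the outlier rows (assumed layout).
loss = F.cross_entropy(x[:len(in_set[0])], target)  # standard loss on in-dist data

# OE term: pushes outlier softmax outputs toward uniform. The trailing .mean()
# makes the term independent of the outlier batch size, and 0.5 is the fixed
# weight discussed above.
loss += 0.5 * -(x[len(in_set[0]):].mean(1) - torch.logsumexp(x[len(in_set[0]):], dim=1)).mean()
```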
Hi,
I am not able to understand the loss function in `oe_scratch.py` in the MNIST folder. It does not look like it is minimising the KL divergence between the uniform distribution and the model's softmax distribution.
The loss mentioned in the file is: `loss += 0.5 * -(x[len(in_set[0]):].mean(1) - torch.logsumexp(x[len(in_set[0]):], dim=1)).mean()`
Can you please help me understand what it is trying to do?
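For anyone landing here later: expanding the expression shows it is exactly the cross-entropy between the uniform distribution and the model's softmax output, which differs from KL(uniform ‖ softmax) only by the constant log K, so minimising one minimises the other. A minimal sketch checking this numerically (the batch size 8 and K = 10 are just illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
K = 10                                # number of classes (MNIST)
x = torch.randn(8, K)                 # stand-in logits for an outlier batch

# The per-example term from oe_scratch.py:
oe_term = -(x.mean(1) - torch.logsumexp(x, dim=1))

# Cross-entropy between the uniform distribution U and p = softmax(x):
#   H(U, p) = -(1/K) * sum_k log p_k = -(mean_k x_k - logsumexp(x))
ce_to_uniform = -F.log_softmax(x, dim=1).mean(1)

print(torch.allclose(oe_term, ce_to_uniform))   # True

# KL(U || p) = H(U, p) - log(K): a constant offset, so the gradients
# of the two objectives are identical.
```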