It is just a bias-correction term that accounts for minibatch sampling: when small batch sizes are used, the KLD value can have a large variance. But it should work without that kld_weight term too.
Related question: https://github.com/AntixK/PyTorch-VAE/issues/40
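For reference, here is a minimal sketch of how such a weighted loss is typically assembled. The function and argument names are illustrative assumptions loosely following vanilla_vae.py, not a verbatim copy of the repo's code:

```python
import torch
import torch.nn.functional as F

def vae_loss(recons, x, mu, log_var, kld_weight):
    # Reconstruction term, averaged over the minibatch
    recons_loss = F.mse_loss(recons, x)

    # Analytic KL divergence between q(z|x) = N(mu, diag(sigma^2))
    # and the standard-normal prior p(z) = N(0, I), averaged over the batch
    kld_loss = torch.mean(
        -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1)
    )

    # kld_weight (= batch_size / num_train_images in this repo's setup)
    # rescales the KL term; setting it to 1.0 recovers the unweighted
    # ELBO of the original VAE paper.
    return recons_loss + kld_weight * kld_loss
```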
Hi @AntixK
Many thanks for this great effort.
Based on my understanding so far, the original VAE paper does not discuss weighting the KL-divergence loss. Later, beta-VAE and many other papers made the case for weighting the KL divergence (essentially treating it as a hyperparameter).
In your implementations, I see that you consistently use
kld_weight = kwargs['M_N'] = batch_size / num_of_images

Is it the norm to select the weight for the KL-divergence loss as the ratio of the batch size to the number of images?
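For concreteness, a quick sketch of what that ratio works out to; the batch and dataset sizes below are purely illustrative:

```python
# Purely illustrative numbers, not taken from the repo
batch_size = 144           # M: samples per minibatch
num_train_images = 60000   # N: size of the training set (e.g. MNIST)

# kld_weight = M / N, as passed via kwargs['M_N']
kld_weight = batch_size / num_train_images
print(kld_weight)  # 0.0024
```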
Since no weighting was done in the original VAE paper, is it okay to use it in vanilla_vae.py?
Regards,
Kapil