AntixK / PyTorch-VAE

A Collection of Variational Autoencoders (VAE) in PyTorch.
Apache License 2.0
6.52k stars, 1.06k forks

KLD Weight #56

Open abyildirim opened 2 years ago

abyildirim commented 2 years ago

Hi,

In the VAE paper (https://arxiv.org/pdf/1312.6114.pdf), the VAE loss function has no additional weight parameter for the KLD loss:

[Image: the VAE loss function from the paper]

However, in the implementation of the Vanilla VAE model, the loss function is written as below:

loss = recons_loss + kld_weight * kld_loss

When I set "kld_weight" to 1, my model could not learn to reconstruct the images. If I understand correctly, "kld_weight" scales down the KLD loss to balance it against the reconstruction loss. However, as I mentioned, no such weight appears in the VAE paper. Could anyone please explain why this parameter is used and why it is set to 0.00025 by default?
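For reference, here is a minimal, dependency-free sketch of how a loss of the form `recons_loss + kld_weight * kld_loss` can be assembled for one sample with a diagonal Gaussian posterior. The helper names are hypothetical. The reductions (MSE averaged over input elements, KL summed over latent dimensions) are an assumption based on the description in the replies. The actual repo code works on PyTorch tensors and may differ in detail.

```python
import math

def mse_mean(x, x_hat):
    """Reconstruction loss: squared error averaged over all input elements."""
    return sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)

def gaussian_kld(mu, log_var):
    """Closed-form KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over latent dims."""
    return -0.5 * sum(1 + lv - m ** 2 - math.exp(lv) for m, lv in zip(mu, log_var))

def vae_loss(x, x_hat, mu, log_var, kld_weight=0.00025):
    """Hypothetical helper: recons_loss + kld_weight * kld_loss."""
    recons_loss = mse_mean(x, x_hat)
    kld_loss = gaussian_kld(mu, log_var)
    return recons_loss + kld_weight * kld_loss

# Toy values: a 4-element "image" and a 2-dim latent posterior
x = [0.2, 0.8, 0.5, 0.1]
x_hat = [0.25, 0.7, 0.5, 0.2]
mu = [0.5, -0.5]
log_var = [0.0, 0.0]
print(vae_loss(x, x_hat, mu, log_var))
```

Note the mismatch in reductions: the reconstruction term is a per-element average, while the KL term is a sum over latent dimensions.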

wonjunior commented 2 years ago

It is defined in equation 8 of the paper.

abyildirim commented 2 years ago

In Equation 8, I see that the MSE loss is also scaled by N/M. However, only the KLD loss is scaled in the code. Shouldn't we scale both of them according to the equation, @WonJunior?
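For reference, Equation 8 of the paper is the minibatch estimator of the full-dataset lower bound. Here $\mathbf{X}^M$ is a minibatch of $M$ datapoints drawn from a dataset $\mathbf{X}$ of $N$ datapoints. The per-datapoint bound is Equation 7. The transcription below is worth double-checking against the PDF. As written, the factor $N/M$ multiplies the whole per-datapoint bound, that is, both the KL and the reconstruction terms, so on its own it does not change their relative weight. Also note that $N$ and $M$ here are the dataset and minibatch sizes, not the input and latent dimensionalities used in the Beta-VAE normalization discussed below.

```latex
\begin{aligned}
\mathcal{L}(\theta,\phi;\mathbf{X}) &\simeq \widetilde{\mathcal{L}}^{M}(\theta,\phi;\mathbf{X}^{M})
  = \frac{N}{M}\sum_{i=1}^{M}\widetilde{\mathcal{L}}(\theta,\phi;\mathbf{x}^{(i)}), \\
\widetilde{\mathcal{L}}^{B}(\theta,\phi;\mathbf{x}^{(i)}) &= -D_{KL}\big(q_{\phi}(\mathbf{z}\mid\mathbf{x}^{(i)}) \,\|\, p_{\theta}(\mathbf{z})\big)
  + \frac{1}{L}\sum_{l=1}^{L}\log p_{\theta}(\mathbf{x}^{(i)}\mid\mathbf{z}^{(i,l)}).
\end{aligned}
```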


dorazhang93 commented 2 years ago

Hi, I was also confused about kld_weight here, but I think I found the proper interpretation in the Beta-VAE paper. Given that the reconstruction loss is averaged over pixels and the KLD loss is averaged over latent dimensions, M there is the dimensionality of z and N is the dimensionality of the input (for images, W*H). In this implementation, the KLD loss is computed as a sum over all latent dimensions, so kld_weight is effectively 1/N = 1/4096 ≈ 0.000244, which is roughly 0.00025 (for 64 × 64 inputs).
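The rescaling argument above can be checked numerically. The sketch below uses toy numbers and a hypothetical helper name, and it treats the summed squared error as the reconstruction term. It starts from an un-normalized objective `recon_sum + beta * kl_sum` and divides it by the input dimensionality N. This equals `recon_mean + (beta / N) * kl_sum`, which in turn equals `recon_mean + (beta * M / N) * kl_mean`, the normalized β described in the Beta-VAE paper (as I understand it). With β = 1 and a 64 × 64 input, the weight on the summed KL comes out to 1/4096.

```python
def normalized_objectives(sq_errors, kl_per_dim, beta=1.0):
    """Return three algebraically equivalent forms of a per-input-dimension VAE objective."""
    n = len(sq_errors)   # input dimensionality N
    m = len(kl_per_dim)  # latent dimensionality M
    recon_sum, kl_sum = sum(sq_errors), sum(kl_per_dim)
    recon_mean, kl_mean = recon_sum / n, kl_sum / m

    full_over_n = (recon_sum + beta * kl_sum) / n        # original objective divided by N
    summed_kl = recon_mean + (beta / n) * kl_sum         # weight beta/N on the summed KL
    averaged_kl = recon_mean + (beta * m / n) * kl_mean  # beta_norm = beta*M/N on the mean KL
    return full_over_n, summed_kl, averaged_kl

# Toy example: a 64x64 single-channel "image" and a 128-dim latent (arbitrary values)
sq_errors = [0.01] * (64 * 64)
kl_per_dim = [0.5] * 128
a, b, c = normalized_objectives(sq_errors, kl_per_dim)
print(a, b, c)
print("kld_weight = 1/N =", 1 / (64 * 64))
```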

angelusualle commented 1 year ago

> N is the dimensionality of input (for images, W*H)

Couldn't it be W * H * channels? Another part of that doc says:

> over the individual pixels $x_n$
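To make the difference concrete, here is the 1/N weight under both readings. The sketch assumes 64 × 64 RGB inputs. The 3 channels are an illustrative assumption, since the thread only states W*H = 4096. Which N is appropriate depends on how the reconstruction loss is reduced. If it is averaged over every element of the tensor, including channels, then the matching weight on a summed KL would be 1/(W·H·C). The helper name is hypothetical.

```python
def kld_weight_for(width, height, channels=1, beta=1.0):
    """Weight on a summed KL term when the reconstruction loss is averaged over N elements."""
    n = width * height * channels
    return beta / n

w_hw = kld_weight_for(64, 64)               # N = W*H
w_hwc = kld_weight_for(64, 64, channels=3)  # N = W*H*C
print(f"W*H:   {w_hw:.6g}")
print(f"W*H*C: {w_hwc:.6g}")
```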

bkkm78 commented 6 months ago

This weight is needed when you use L2 loss as the reconstruction loss. L2 loss (aka MSE) means that you're assuming a Gaussian $p_{\theta}(x|z)$, for which you need to specify a $\sigma$ for the Gaussian distribution as a hyperparameter. This is where the relative weight between the reconstruction loss and the KL divergence comes from. If you instead assume a Bernoulli distribution and thus apply a (per pixel per channel) binary cross-entropy loss, this relative weight is not necessary.
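The link between σ and the relative weight can be made explicit. Consider a factorized Gaussian decoder with fixed σ over N input elements. Then $-\log p_{\theta}(x|z) = \frac{1}{2\sigma^2}\|x-\hat{x}\|^2 + \frac{N}{2}\log(2\pi\sigma^2)$. So, up to constants, rescaling the negative ELBO by $2\sigma^2/N$ gives the mean MSE plus a KL weight of $2\sigma^2/N$. The sketch below checks the first identity numerically and inverts the second. The helper names are hypothetical. The σ it prints for the repo's default of 0.00025 is only what that weight would correspond to under this assumption with N = 4096; it is not a value the repo states.

```python
import math

def gaussian_nll(x, x_hat, sigma):
    """-log p(x|z) for a factorized Gaussian decoder with fixed standard deviation sigma."""
    return sum(
        (a - b) ** 2 / (2 * sigma ** 2) + 0.5 * math.log(2 * math.pi * sigma ** 2)
        for a, b in zip(x, x_hat)
    )

def kld_weight_from_sigma(sigma, n):
    """KL weight, relative to a mean-MSE reconstruction loss, implied by a fixed sigma."""
    return 2 * sigma ** 2 / n

def sigma_from_kld_weight(weight, n):
    """Inverse of kld_weight_from_sigma."""
    return math.sqrt(weight * n / 2)

x, x_hat, sigma = [0.2, 0.8, 0.5], [0.3, 0.6, 0.5], 0.5
sse = sum((a - b) ** 2 for a, b in zip(x, x_hat))
const = len(x) / 2 * math.log(2 * math.pi * sigma ** 2)
# The Gaussian NLL is the sum of squared errors scaled by 1/(2*sigma^2), plus a constant
print(gaussian_nll(x, x_hat, sigma), sse / (2 * sigma ** 2) + const)
print(sigma_from_kld_weight(0.00025, 64 * 64))
```

Under this reading, a smaller kld_weight corresponds to assuming a smaller decoder σ. This is a sharper likelihood that puts more emphasis on reconstruction.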

You can refer to section 2.4.3 of Carl Doersch's "Tutorial on Variational Autoencoders" for more details.