AntixK / PyTorch-VAE

A Collection of Variational Autoencoders (VAE) in PyTorch.

Mistake in Vanilla VAE loss #69

Closed Midren closed 1 year ago

Midren commented 1 year ago

The comment describing the KL divergence between the latent distribution and the standard normal distribution is correct: https://github.com/AntixK/PyTorch-VAE/blob/a6896b944c918dd7030e7d795a8c13e5c6345ec7/models/vanilla_vae.py#L129, but there is probably a mistake in the Python code: https://github.com/AntixK/PyTorch-VAE/blob/a6896b944c918dd7030e7d795a8c13e5c6345ec7/models/vanilla_vae.py#L143.

The log_var variable is used for log sigma, and log_var.exp() is used for sigma^2.

I suppose that log_var should be log_sigma, and log_var.exp() should be changed to log_sigma.exp()^2.
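For context, the loss term in question computes the KL roughly along these lines (a sketch of the standard closed form, not quoted verbatim from the repo; mu and log_var stand in for the encoder outputs):

import torch

mu, log_var = torch.randn(8, 16), torch.randn(8, 16)   # illustrative encoder outputs
# closed-form KL( N(mu, sigma^2) || N(0, I) ), summed over latent dims, averaged over the batch
kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)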

NicholasKX commented 1 year ago

I think there is no mistake.

Elkortya commented 1 year ago

I think there is no mistake either; the comment is consistent with the code.

Since log_var is linked to std via

std = torch.exp(0.5 * log_var)   # std = exp(log_var / 2)
eps = torch.randn_like(std)      # eps ~ N(0, I)
z = eps * std + mu               # reparameterization: z ~ N(mu, std^2)

we can see that $\text{logvar} = 2 \log \sigma = \log \sigma^2$.
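A quick way to sanity-check that identity (a minimal sketch; the tensor values are arbitrary):

import torch

log_var = torch.randn(4)            # arbitrary log-variance values
std = torch.exp(0.5 * log_var)      # same mapping as in reparameterize()
# squaring std recovers exp(log_var), i.e. log_var really is log(sigma^2)
print(torch.allclose(std ** 2, log_var.exp()))  # True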

Starting from the definition of KL divergence in the comment, $\Sigma (\log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2})$, we get $\Sigma (-\log \sigma + \frac{\sigma^2}{2} + \frac{\mu^2}{2} - \frac{1}{2})$

$-\frac{1}{2} [\Sigma (2 \log \sigma - \sigma^2 - \mu^2 +1)]$

$-\frac{1}{2} [\Sigma (\log \sigma^2 - \sigma^2 - \mu^2 + 1)]$

and replacing with log_var, $-\frac{1}{2} [\Sigma (\text{logvar} - e^{\text{logvar}} - \mu^2 + 1)]$

which is what is written in the code.
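The same equivalence can also be checked numerically (a minimal sketch; mu and log_var are just random example tensors):

import torch

mu = torch.randn(8, 16)
log_var = torch.randn(8, 16)
sigma = torch.exp(0.5 * log_var)

# KL written in terms of sigma, following the formula in the comment
kl_sigma = torch.sum(torch.log(1 / sigma) + (sigma ** 2 + mu ** 2) / 2 - 0.5, dim=1)
# KL written in terms of log_var, as in the loss code
kl_log_var = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1)

print(torch.allclose(kl_sigma, kl_log_var))  # True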

Midren commented 1 year ago

I see now, thank you for your explanation!