CompVis / stable-diffusion

A latent text-to-image diffusion model
https://ommer-lab.com/research/latent-diffusion-models/

Question about the KL divergence loss #849

Open marctimjen opened 1 month ago

marctimjen commented 1 month ago

Hello

I hope someone can help me understand why the KL is calculated as: 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])

In the DiagonalGaussianDistribution class defined here: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/distributions/distributions.py#L44

I am asking because most VAE loss functions I can find use -1 times this calculation, like this: 0.5 * torch.sum(-torch.pow(self.mean, 2) - self.var + 1.0 + self.logvar, dim=[1, 2, 3])

And I cannot see that we multiply by -1 anywhere in the contperceptual loss, for instance:

https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/losses/contperceptual.py#L83C57-L83C65
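To make the comparison concrete, here is a minimal sketch (the free `mean`/`logvar` tensors are hypothetical stand-ins for `self.mean`/`self.logvar`) showing that the two expressions are exact negatives of one another, and that the form used in distributions.py agrees with KL(q ‖ p) as computed by torch.distributions:

```python
import torch
from torch.distributions import Normal, kl_divergence

# Hypothetical latent statistics in the usual (batch, channel, h, w) layout.
mean = torch.randn(2, 4, 8, 8)
logvar = torch.randn(2, 4, 8, 8)
var = logvar.exp()

# Form used in ldm/modules/distributions/distributions.py: the positive
# KL(q || p), summed per sample, intended to be *minimized* as a loss term.
kl_code = 0.5 * torch.sum(torch.pow(mean, 2) + var - 1.0 - logvar, dim=[1, 2, 3])

# Form most papers write: the *negative* KL, which appears inside the ELBO
# and is *maximized*.
neg_kl_paper = 0.5 * torch.sum(-torch.pow(mean, 2) - var + 1.0 + logvar, dim=[1, 2, 3])

# The two are exact negatives, so minimizing one is maximizing the other.
assert torch.allclose(kl_code, -neg_kl_paper)

# Cross-check the analytic expression against torch.distributions.
q = Normal(mean, (0.5 * logvar).exp())  # std = exp(logvar / 2)
p = Normal(torch.zeros_like(mean), torch.ones_like(mean))
kl_ref = kl_divergence(q, p).sum(dim=[1, 2, 3])
assert torch.allclose(kl_code, kl_ref, atol=1e-5)
```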

Thank you very much in advance :)

marctimjen commented 1 month ago

I found this material, which also has the loss in the form used here:

https://pyimagesearch.com/2023/10/02/a-deep-dive-into-variational-autoencoders-with-pytorch/

My confusion arose because most papers write:

[image: the KL term written in the negative form, ½ Σ_j (1 + ln σ_j² − μ_j² − σ_j²)]

(From Bishop's Deep Learning)

And the original implementation of the VAE:

[image: the corresponding negative-KL term from the original VAE implementation (Kingma & Welling)]
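If it helps anyone else landing here: as far as I can tell, the two conventions are consistent. Papers write the negative KL because it appears inside the ELBO, which is maximized; the code minimizes a loss, so the same KL enters with a positive sign. With $q(z|x) = \mathcal{N}(\mu, \sigma^2)$ and prior $p(z) = \mathcal{N}(0, 1)$:

$$
D_{\mathrm{KL}}\big(q(z|x)\,\|\,p(z)\big) = \frac{1}{2}\sum_j \left(\mu_j^2 + \sigma_j^2 - 1 - \ln \sigma_j^2\right) \ge 0
$$

$$
\mathrm{ELBO} = \mathbb{E}_{q(z|x)}\big[\ln p(x|z)\big] - D_{\mathrm{KL}}\big(q(z|x)\,\|\,p(z)\big)
$$

Maximizing the ELBO is the same as minimizing $-\mathrm{ELBO} = \mathrm{NLL} + D_{\mathrm{KL}}$, which is why contperceptual.py can add the positive KL (scaled by the KL weight) to the reconstruction loss without any extra factor of -1.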