Open jszgz opened 5 years ago
def KL_loss(mu, logvar):
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.mean(KLD_element).mul_(-0.5)
    return KLD
What does this formula mean?
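KL_loss(mu, logvar) is the closed-form KL divergence between the Gaussian N(mu, sigma^2) and the standard normal N(0, 1), with logvar = log(sigma^2):
-0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
Note that the code takes torch.mean over the elements instead of a sum, so its value is this formula divided by the number of elements.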
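
To see why that expression is a KL divergence, here is a minimal sketch in plain Python (no PyTorch) that evaluates the formula and compares it with a brute-force numerical integral of q(x) * log(q(x) / p(x)) for one dimension. The function names `kl_to_standard_normal` and `kl_numeric_1d`, the integration range, and the step count are all assumptions made for illustration; none of them come from the repository.

```python
import math

def kl_to_standard_normal(mu, logvar):
    """Closed form: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2).

    mu and logvar are equal-length lists, with logvar = log(sigma^2).
    """
    return -0.5 * sum(1 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))

def kl_numeric_1d(mu, logvar, lo=-12.0, hi=12.0, steps=20_000):
    """Midpoint Riemann sum of q(x) * log(q(x) / p(x)) over [lo, hi].

    q = N(mu, sigma^2) and p = N(0, 1). The range and step count are
    illustrative choices, not tuned values.
    """
    var = math.exp(logvar)
    dx = (hi - lo) / steps
    total = 0.0
    for i in range(steps):
        x = lo + (i + 0.5) * dx
        log_q = -0.5 * (math.log(2 * math.pi) + logvar + (x - mu) ** 2 / var)
        log_p = -0.5 * (math.log(2 * math.pi) + x * x)
        total += math.exp(log_q) * (log_q - log_p) * dx
    return total

closed = kl_to_standard_normal([0.5], [-1.0])
numeric = kl_numeric_1d(0.5, -1.0)
print(closed, numeric)  # the two values should agree to several decimal places
```

When mu = 0 and logvar = 0 the distribution already is N(0, 1), and the formula gives zero, which is why this term pulls the encoder's output toward a standard normal.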