Open JianLiu91 opened 7 years ago
I think this function will compute the right KL divergence, assuming diagonal covariances:

```python
import torch

def kl_divergence(mu_infer, log_var_infer, mu_prior, log_var_prior):
    # Closed-form KL( N(mu_infer, var_infer) || N(mu_prior, var_prior) ) for diagonal covariances
    all_vals = (1.0 + log_var_infer - log_var_prior - torch.exp(log_var_infer - log_var_prior)
                - (mu_infer - mu_prior) ** 2 / torch.exp(log_var_prior))
    kl_for_batches = torch.sum(-0.5 * all_vals, dim=1)  # one KL value per example in the batch
    return torch.mean(kl_for_batches)                    # averaged over the batch
```
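One way to sanity-check it is to compare against `torch.distributions` (a rough sketch, assuming the `kl_divergence` function above is defined; the tensor shapes are arbitrary):

```python
import torch
from torch.distributions import Normal, kl_divergence as torch_kl

# Toy tensors: batch of 4 examples, 3 latent dimensions (shapes chosen arbitrarily).
mu_q, log_var_q = torch.randn(4, 3), torch.randn(4, 3)
mu_p, log_var_p = torch.randn(4, 3), torch.randn(4, 3)

ours = kl_divergence(mu_q, log_var_q, mu_p, log_var_p)

# Reference: analytic KL from torch.distributions, summed over latent dims, averaged over the batch.
q = Normal(mu_q, (0.5 * log_var_q).exp())  # std = exp(log_var / 2)
p = Normal(mu_p, (0.5 * log_var_p).exp())
ref = torch_kl(q, p).sum(dim=1).mean()

print(ours.item(), ref.item())  # the two values should agree up to floating-point error
```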
Is the implementation correct?
The computation of the KL divergence in your code is the KL divergence between a Gaussian and a standard Gaussian,

$$
D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, I)\big) = -\frac{1}{2}\sum_{j}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right).
$$

However, according to the paper, the KL divergence should be between two non-standard Gaussians.
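For reference, the closed-form expression the function above implements, i.e. the KL divergence between two diagonal Gaussians $q = \mathcal{N}(\mu_1, \mathrm{diag}\,\sigma_1^2)$ and $p = \mathcal{N}(\mu_2, \mathrm{diag}\,\sigma_2^2)$, is

$$
D_{\mathrm{KL}}(q \,\|\, p) = \frac{1}{2}\sum_{j}\left(\log\frac{\sigma_{2,j}^2}{\sigma_{1,j}^2} + \frac{\sigma_{1,j}^2 + (\mu_{1,j} - \mu_{2,j})^2}{\sigma_{2,j}^2} - 1\right),
$$

which reduces to the standard-Gaussian expression above when $\mu_2 = 0$ and $\sigma_2^2 = 1$.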