mingyuliutw / UNIT

Unsupervised Image-to-Image Translation

Clarification on the KL Divergence term in the Generator loss for SVHN->MNIST model #42

Closed hsm207 closed 6 years ago

hsm207 commented 6 years ago

I have a question about the _compute_kl function in the class COCOGANDAContextTrainer. The following are the relevant parts of the code:

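  def _compute_kl(self, mu, sd):
    # mu, sd: mean and standard deviation of the latent code
    mu_2 = torch.pow(mu, 2)
    sd_2 = torch.pow(sd, 2)
    # Sum over all elements, then divide by the size of the first dimension
    encoding_loss = (mu_2 + sd_2 - torch.log(sd_2)).sum() / mu_2.size(0)
    return encoding_loss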

This function was used in gen_update:

    for i, lt in enumerate(lt_codes):
      encoding_loss += 2 * self._compute_kl(*lt)
    total_loss = hyperparameters['gan_w'] * ad_loss + \
                 hyperparameters['kl_normalized_direct_w'] * encoding_loss + \
                 hyperparameters['ll_normalized_direct_w'] * (ll_loss_a + ll_loss_b)

My question is: how did you derive the formula for the KL divergence term?

I thought it was based on the Auto-Encoding Variational Bayes paper, which contains the following:

[image: excerpt from the paper]

and in Appendix B:

[image: equation from Appendix B]
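
For reference, the Gaussian-case result from Appendix B of Auto-Encoding Variational Bayes is usually written as below. This is a transcription of the standard closed form, not necessarily the exact excerpt shown in the image. Here J is the latent dimensionality, and \mu_j and \sigma_j are the encoder's mean and standard deviation for dimension j.

```latex
-D_{KL}\big(q_\phi(z)\,\|\,p_\theta(z)\big)
  = \frac{1}{2} \sum_{j=1}^{J} \left( 1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2 \right)
```

Equivalently, the KL divergence itself is 1/2 times the sum of \mu_j^2 + \sigma_j^2 - \log\sigma_j^2 - 1, which is the expression to compare against encoding_loss.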

I note the following differences between the code and the paper (Auto-Encoding Variational Bayes):

  1. The KL divergence term is multiplied by 2 instead of 1/2. I guess this does not matter much since it just rescales the loss.

  2. There is no - 1 term in encoding_loss. Did you omit it because a constant offset does not change the location of the optimum anyway?

mingyuliutw commented 6 years ago

Yes, those were my considerations.
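
The two points confirmed above can be checked numerically. The following is a minimal sketch in plain Python rather than PyTorch, so that it runs without dependencies. The helper names compute_kl_repo and kl_closed_form and the sample values are hypothetical, and it assumes the first dimension of mu is the batch. It suggests that the value from _compute_kl equals twice the batch-mean closed-form KL plus the number of latent elements per sample, i.e. a rescaling and a constant shift.

```python
import math


def compute_kl_repo(mu, sd):
    """Plain-Python mirror of _compute_kl: sum(mu^2 + sd^2 - log(sd^2)) / batch size."""
    total = 0.0
    for mu_row, sd_row in zip(mu, sd):
        for m, s in zip(mu_row, sd_row):
            total += m * m + s * s - math.log(s * s)
    return total / len(mu)


def kl_closed_form(mu, sd):
    """Batch mean of KL(N(mu, sd^2) || N(0, 1)), including the 1/2 factor and the -1 term."""
    total = 0.0
    for mu_row, sd_row in zip(mu, sd):
        for m, s in zip(mu_row, sd_row):
            total += 0.5 * (m * m + s * s - math.log(s * s) - 1.0)
    return total / len(mu)


# Hypothetical batch: 2 samples, 3 latent dimensions each.
mu = [[0.5, -1.0, 0.0], [2.0, 0.3, -0.7]]
sd = [[1.0, 0.5, 2.0], [0.8, 1.5, 1.0]]
latent_dim = len(mu[0])

repo = compute_kl_repo(mu, sd)
paper = kl_closed_form(mu, sd)

# The difference between the repo value and 2x the closed-form KL
# is a constant that depends only on the latent size, not on mu or sd.
offset = repo - 2.0 * paper
print(repo, paper, offset)
```

Since gen_update multiplies the result by 2 again, the term entering total_loss is 4 times the closed-form KL plus a constant. The factor can be absorbed into kl_normalized_direct_w, and the constant contributes nothing to the gradients.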