kingofspace0wzz / wae-rnf-lm

PyTorch implementation of our NAACL 2019 paper "Riemannian Normalizing Flow on Variational Wasserstein Autoencoder for Text Modeling": https://arxiv.org/abs/1904.02399
MIT License

Query about the KL divergence calculation with RNF #3

Closed: vikigenius closed this issue 5 years ago

vikigenius commented 5 years ago

Hi, I have been reading the paper and looking at the code, and I don't understand how the KL divergence is calculated in the loss term.

The loss calculation has a KL term with the flow: `(q_z.log_prob(z0).sum() - p_z.log_prob(z).sum())`

Can you explain why `z0` is used for the posterior and `z` is used for the prior? The `flow_kld` function seems to use `z` for both; how is that different?

kingofspace0wzz commented 5 years ago

Hi @vikigenius , this is a good question. Thanks for asking.

The naming is a little misleading. `flow_kld` measures the KL divergence between q(z_T|x) and p(z_T), which is not the KL term you listed from the loss function. The KL you mentioned is based on q(z_0|x) and p(z_T). Using `flow_kld` in the loss function would not be correct. To see why, here is a complete derivation:
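A sketch of the standard normalizing-flow ELBO (as in Rezende and Mohamed, 2015; the notation below is illustrative and not tied to the repository's variable names):

```math
\begin{aligned}
\mathcal{L}(x) &= \mathbb{E}_{q_T(z_T \mid x)}\big[\log p(x \mid z_T) + \log p(z_T) - \log q_T(z_T \mid x)\big] \\
&= \mathbb{E}_{q_0(z_0 \mid x)}\Big[\log p(x \mid z_T) + \log p(z_T) - \log q_0(z_0 \mid x) + \sum_{t=1}^{T} \log \Big|\det \frac{\partial f_t}{\partial z_{t-1}}\Big|\Big],
\end{aligned}
```

where $z_T = f_T \circ \cdots \circ f_1(z_0)$ and the second line uses the change-of-variables identity $\log q_T(z_T \mid x) = \log q_0(z_0 \mid x) - \sum_t \log |\det \partial f_t / \partial z_{t-1}|$. Grouping terms, the KL-like part of the objective is $\mathbb{E}_{q_0}[\log q_0(z_0 \mid x) - \log p(z_T)]$ minus the log-Jacobian sum, so the posterior density is evaluated at $z_0$ while the prior density is evaluated at $z_T$.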

In short, once you apply a normalizing flow and include the sum of log-Jacobian terms in your loss function, the posterior in the KL term should be q(z_0|x) rather than q(z_T|x). The same derivation can be found in other normalizing-flow papers.
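For concreteness, here is a minimal, self-contained PyTorch sketch of how that KL-like term is assembled. The tensors and the toy affine flow below are illustrative stand-ins, not the repository's actual encoder or RNF module:

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

# Toy posterior parameters for a batch of 4 latent vectors of size 8,
# standing in for the encoder output q(z_0|x) (illustrative only).
mu, logvar = torch.randn(4, 8), torch.randn(4, 8)
q_z0 = Normal(mu, (0.5 * logvar).exp())
z0 = q_z0.rsample()

# A trivial invertible "flow" z_T = a * z_0 + b with an analytic log-Jacobian,
# standing in for the RNF transformation (illustrative only).
a, b = torch.tensor(1.5), torch.tensor(0.3)
zT = a * z0 + b
log_det = torch.log(a.abs()) * z0.numel()  # sum of log|det df/dz| over all elements

# Standard normal prior over z_T.
p_zT = Normal(torch.zeros_like(zT), torch.ones_like(zT))

# KL-like term of the flow ELBO: posterior density evaluated at z_0,
# prior density evaluated at z_T, minus the accumulated log-Jacobian term.
kl_term = q_z0.log_prob(z0).sum() - p_zT.log_prob(zT).sum() - log_det
print(kl_term.item())
```

The key point is that `q_z0.log_prob` is evaluated at `z0` while `p_zT.log_prob` is evaluated at `zT`, with the log-Jacobian sum subtracted, which matches the loss term quoted in the question.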

Let me know if this answers your question. Thanks again!

vikigenius commented 5 years ago

Thanks for the derivation; it makes sense now.