GlassyWing / nvae

An unofficial toy implementation for NVAE 《A Deep Hierarchical Variational Autoencoder》
Apache License 2.0
108 stars 21 forks

hierarchical KL loss #17

Open ycyang18 opened 11 months ago

ycyang18 commented 11 months ago

Hi! Very impressive work, thanks for sharing! I have a question regarding the hierarchical KL loss. In the original paper, the hierarchical KL loss is stated as:

$$\sum_{l=1}^{L} \mathrm{KL}\big(q(z_l \mid x, z_{l-1}) \,\|\, p(z_l \mid z_{l-1})\big)$$

which is taken between the encoder and the decoder.
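
To make the sum concrete, here is a minimal Python sketch. It assumes each latent group is a diagonal Gaussian described by a mean and a log-variance per dimension. `kl_diag_gaussians` and `hierarchical_kl` are hypothetical helpers, not functions from this repo. The sketch evaluates the sum for one sample of the lower-level latents, not the expectation over $q$ that appears in the paper's objective.

```python
import math
from typing import List, Sequence, Tuple

# One latent group: (per-dimension means, per-dimension log-variances).
Gaussian = Tuple[Sequence[float], Sequence[float]]


def kl_diag_gaussians(mu_q, log_var_q, mu_p, log_var_p) -> float:
    """KL(N(mu_q, exp(log_var_q)) || N(mu_p, exp(log_var_p))), summed over dimensions."""
    total = 0.0
    for mq, lq, mp, lp in zip(mu_q, log_var_q, mu_p, log_var_p):
        var_q, var_p = math.exp(lq), math.exp(lp)
        # Closed form for two univariate Gaussians, written with log-variances.
        total += 0.5 * (lp - lq + (var_q + (mq - mp) ** 2) / var_p - 1.0)
    return total


def hierarchical_kl(posteriors: List[Gaussian], priors: List[Gaussian]) -> float:
    """Sum of the per-level KL terms, one for each latent group l = 1..L."""
    return sum(kl_diag_gaussians(*q, *p) for q, p in zip(posteriors, priors))


# Two levels: level 1 has q = N(1, 1) vs p = N(0, 1); level 2 has q equal to p.
posteriors = [([1.0], [0.0]), ([0.5], [0.3])]
priors = [([0.0], [0.0]), ([0.5], [0.3])]
print(hierarchical_kl(posteriors, priors))
```

Each level contributes its own KL term. A level whose posterior matches its prior contributes zero, so here only level 1 adds $\tfrac{1}{2}$.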

I am wondering why you modeled the KL loss between p(z_l | x, z_{l-1}) and p(z_l | z_{l-1}), which both seem to come from the decoder:

```python
# parameters predicted from the decoder state alone
mu, log_var = self.condition_z[i](decoder_out).chunk(2, dim=1)
# parameters predicted from xs[i] concatenated with the decoder state
delta_mu, delta_log_var = self.condition_xz[i](torch.cat([xs[i], decoder_out], dim=1)).chunk(2, dim=1)
kl_losses.append(kl_2(delta_mu, delta_log_var, mu, log_var))
```
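
The call `kl_2(delta_mu, delta_log_var, mu, log_var)` appears to take offsets rather than the posterior's own parameters. The NVAE paper parameterizes each posterior relative to its prior: with $p(z_l \mid \cdot) = \mathcal{N}(\mu, \sigma^2)$, the posterior is $q(z_l \mid \cdot) = \mathcal{N}(\mu + \Delta\mu, \sigma^2 \Delta\sigma^2)$, giving $\mathrm{KL} = \tfrac{1}{2}\big(\Delta\mu^2/\sigma^2 + \Delta\sigma^2 - \log \Delta\sigma^2 - 1\big)$. The sketch below assumes `kl_2` computes this residual form. The issue does not show its body, and `kl_residual` and `kl_general` are hypothetical names. The check rebuilds $q$'s mean and log-variance from the offsets and confirms that the residual form equals the ordinary Gaussian KL.

```python
import math


def kl_residual(delta_mu, delta_log_var, mu, log_var):
    """KL(q || p) with q = N(mu + delta_mu, var * exp(delta_log_var)) and p = N(mu, var).

    Argument order mirrors the kl_2 call in the issue (an assumption about kl_2).
    The prior mean `mu` cancels out of the closed form, so it is accepted but unused.
    """
    total = 0.0
    for dm, dlv, lv in zip(delta_mu, delta_log_var, log_var):
        # 0.5 * (delta_mu^2 / var + delta_var - log(delta_var) - 1)
        total += 0.5 * (dm ** 2 / math.exp(lv) + math.exp(dlv) - dlv - 1.0)
    return total


def kl_general(mu_q, log_var_q, mu_p, log_var_p):
    """Ordinary KL between diagonal Gaussians, summed over dimensions."""
    total = 0.0
    for mq, lq, mp, lp in zip(mu_q, log_var_q, mu_p, log_var_p):
        total += 0.5 * (lp - lq + (math.exp(lq) + (mq - mp) ** 2) / math.exp(lp) - 1.0)
    return total


delta_mu, delta_log_var = [0.7, -0.2], [0.4, -0.1]
mu, log_var = [0.3, 1.0], [-0.5, 0.2]
# Rebuild q's own parameters from the prior plus the offsets.
mu_q = [m + d for m, d in zip(mu, delta_mu)]
log_var_q = [lv + d for lv, d in zip(log_var, delta_log_var)]
print(kl_residual(delta_mu, delta_log_var, mu, log_var))
print(kl_general(mu_q, log_var_q, mu, log_var))
```

Both calls print the same value up to floating-point rounding. Zero offsets give a KL of exactly zero, whatever the prior.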

Please let me know if I have misunderstood anything. Thanks a lot in advance! :)