XiaoyanQian opened this issue 1 year ago
Hi, could you help me figure out what p0_z should be for the KL loss in the following code?
kl = dist.kl_divergence(q_z, self.model.p0_z).sum(dim=-1)
loss = kl.mean()
How can I get p0_z? Any thoughts?
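One common setup, assuming this is a VAE-style model where p0_z is the prior over the latent code z, is a fixed standard normal N(0, I) with the same latent dimensionality as the posterior q_z. The sketch below builds such a prior with torch.distributions and plugs it into the two lines from the question. `z_dim`, `batch_size`, `mean_z` and `logstd_z` are hypothetical stand-ins for the model's latent size and encoder outputs, and whether the original code uses exactly N(0, I) is an assumption worth checking against the repository's config.

```python
import torch
import torch.distributions as dist

torch.manual_seed(0)

# Assumption: p0_z is a standard normal prior N(0, I) over a latent of size z_dim.
# This is a common VAE choice, not confirmed for the codebase in question.
z_dim = 8  # hypothetical latent size
device = torch.device("cpu")

p0_z = dist.Normal(
    torch.zeros(z_dim, device=device),  # prior mean
    torch.ones(z_dim, device=device),   # prior standard deviation
)

# Hypothetical encoder outputs for a batch of 4 samples: per-dimension
# mean and log-std of the approximate posterior q(z | x).
batch_size = 4
mean_z = torch.randn(batch_size, z_dim, device=device)
logstd_z = 0.1 * torch.randn(batch_size, z_dim, device=device)
q_z = dist.Normal(mean_z, torch.exp(logstd_z))

# KL(q_z || p0_z) is computed per latent dimension (p0_z broadcasts over the
# batch), summed over latent dimensions, then averaged over the batch.
kl = dist.kl_divergence(q_z, p0_z).sum(dim=-1)  # shape: (batch_size,)
loss = kl.mean()
print(kl.shape, loss.item())
```

With a standard normal prior this matches the closed form KL = 0.5 * sum(sigma^2 + mu^2 - 1 - log sigma^2) per sample. In the question's code, `self.model.p0_z` would simply be such a distribution object created once and stored on the model, so it does not depend on the input.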