There is an issue with the computation of the KL divergence when using the `kl_loss_var` function. I think the fix is to remove the second `(` before `var_ratio_log.exp()`; the update would look something like the sketch below.
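As a minimal sketch only (I'm assuming the argument names, the log-variance parameterisation, and the final `.mean()` reduction; the actual function in `utils.py` may differ), the corrected closed-form Gaussian KL would read:

```python
import torch

def kl_loss_var(prior_mu, log_var_prior, post_mu, log_var_post):
    # KL( N(post_mu, exp(log_var_post)) || N(prior_mu, exp(log_var_prior)) ),
    # written via the log variance ratio. With the extra "(" removed, every
    # term sits inside the single 0.5 * (...) sum exactly once.
    var_ratio_log = log_var_post - log_var_prior
    kl_div = 0.5 * (
        var_ratio_log.exp()                                # sigma_post^2 / sigma_prior^2
        + (post_mu - prior_mu) ** 2 / log_var_prior.exp()  # (mu_post - mu_prior)^2 / sigma_prior^2
        - 1.0
        - var_ratio_log                                    # log(sigma_prior^2 / sigma_post^2)
    )
    return kl_div.mean()  # reduction is an assumption; keep whatever the repo uses
```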
Alternatively, using `torch.distributions.kl_divergence(z_post_dist, z_prior_dist)` would do the job, where:

z_prior_dist = torch.distributions.normal.Normal(mu_c, sigma_c)  # mu_c, sigma_c are the mean and standard deviation computed from the contexts
z_post_dist = torch.distributions.normal.Normal(mu_t, sigma_t)  # mu_t, sigma_t are the mean and standard deviation computed from the targets
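For completeness, here is a small self-contained check that the closed form agrees with the built-in (the parameter values are made up for illustration, and `kl_loss_var` refers to the sketch above, not necessarily the repo's exact function):

```python
import torch

# Made-up context ("prior") and target ("posterior") parameters, batch of 8.
mu_c, sigma_c = torch.zeros(8), torch.ones(8)
mu_t, sigma_t = torch.randn(8), torch.rand(8) + 0.5

z_prior_dist = torch.distributions.normal.Normal(mu_c, sigma_c)
z_post_dist = torch.distributions.normal.Normal(mu_t, sigma_t)
kl_builtin = torch.distributions.kl_divergence(z_post_dist, z_prior_dist).mean()

# Same quantity via the corrected closed form (log variances from the std devs).
kl_manual = kl_loss_var(mu_c, (sigma_c ** 2).log(), mu_t, (sigma_t ** 2).log())

assert torch.allclose(kl_builtin, kl_manual)
```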
PS. thank you for this open-source implementation! 👍
https://github.com/3springs/attentive-neural-processes/blob/016272a077a19bc51d145d1ad99d910477458876/neural_processes/utils.py#L167