wassname / attentive-neural-processes

implementing "recurrent attentive neural processes" to forecast power usage (w. LSTM baseline, MCDropout)
Apache License 2.0

Fixing the error in kl_loss_var function #5

Closed orisenbazuru closed 2 years ago

orisenbazuru commented 2 years ago

https://github.com/3springs/attentive-neural-processes/blob/016272a077a19bc51d145d1ad99d910477458876/neural_processes/utils.py#L167

There is an issue with the computation of the KL divergence in the kl_loss_var function: the extra opening parenthesis groups var_ratio_log.exp() with the squared mean difference, so the variance-ratio term is also divided by log_var_prior.exp(). I think the fix is to remove the second ( before var_ratio_log.exp(). The updated function would look like:

def kl_loss_var(prior_mu, log_var_prior, post_mu, log_var_post):
    # KL( N(post_mu, exp(log_var_post)) || N(prior_mu, exp(log_var_prior)) ), elementwise
    var_ratio_log = log_var_post - log_var_prior
    kl_div = (
        var_ratio_log.exp()
        + ((post_mu - prior_mu) ** 2) / log_var_prior.exp()
        - 1.0
        - var_ratio_log
    )
    kl_div = 0.5 * kl_div
    return kl_div
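
As a quick sanity check (a minimal sketch, assuming the log_var_* arguments are log-variances, so the standard deviation is exp(0.5 * log_var)), the corrected closed form agrees with PyTorch's built-in KL:

import torch

# hypothetical tensors, just to exercise the formula
prior_mu, log_var_prior = torch.randn(8), torch.randn(8)
post_mu, log_var_post = torch.randn(8), torch.randn(8)

prior = torch.distributions.Normal(prior_mu, (0.5 * log_var_prior).exp())
post = torch.distributions.Normal(post_mu, (0.5 * log_var_post).exp())

assert torch.allclose(
    kl_loss_var(prior_mu, log_var_prior, post_mu, log_var_post),
    torch.distributions.kl_divergence(post, prior),
    atol=1e-5,
)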

Alternatively, using torch.distributions.kl_divergence(z_post_dist, z_prior_dist), where

z_prior_dist = torch.distributions.normal.Normal(mu_c, sigma_c)  # mu_c, sigma_c: mean and standard deviation computed from the context points
z_post_dist = torch.distributions.normal.Normal(mu_t, sigma_t)   # mu_t, sigma_t: mean and standard deviation computed from the target points

would do the job.
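
For example, keeping the same log-variance signature as kl_loss_var, a sketch of an equivalent replacement (kl_loss_var_dist is a hypothetical name) could be:

import torch

def kl_loss_var_dist(prior_mu, log_var_prior, post_mu, log_var_post):
    # convert log-variance to standard deviation: sigma = exp(0.5 * log_var)
    prior = torch.distributions.Normal(prior_mu, (0.5 * log_var_prior).exp())
    post = torch.distributions.Normal(post_mu, (0.5 * log_var_post).exp())
    # elementwise KL(posterior || prior); reduce over latent dimensions as needed
    return torch.distributions.kl_divergence(post, prior)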

P.S. Thank you for this open-source implementation! 👍

wassname commented 2 years ago

The best part about sharing open-source stuff is seeing improvements like this. Good catch, thanks!