GlassyWing / nvae

An unofficial toy implementation for NVAE 《A Deep Hierarchical Variational Autoencoder》
Apache License 2.0

Minor bug in kl2? #3

Closed: samedii closed this issue 3 years ago

samedii commented 4 years ago

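Judging from the first kl, I don't think this was intended: torch.sum is called without a dim argument, so it sums over every dimension, including the batch dimension, and the torch.mean(loss, dim=0) that follows has nothing left to average. It may not have any significant effect, though, except maybe on M_N?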

def kl_2(delta_mu, delta_log_var, mu, log_var):
    var = torch.exp(log_var)
    delta_var = torch.exp(delta_log_var)

    loss = -0.5 * torch.sum(1 + delta_log_var - delta_mu ** 2 / var - delta_var)
    return torch.mean(loss, dim=0)
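
For reference, the per-element term inside the sum matches the closed-form KL divergence between two Gaussians. This is a sketch under the assumption that the posterior is parameterized relative to the prior, with `var` playing the role of $\sigma^2$, `delta_var` of $\Delta\sigma^2$ (so `delta_log_var` is $\log \Delta\sigma^2$), and `delta_mu` of $\Delta\mu$. The symbols are illustrative and do not come from the repository:

$$
\mathrm{KL}\big(\mathcal{N}(\mu + \Delta\mu,\ \sigma^2 \Delta\sigma^2)\ \|\ \mathcal{N}(\mu,\ \sigma^2)\big)
= \frac{1}{2}\left(\frac{\Delta\mu^2}{\sigma^2} + \Delta\sigma^2 - \log \Delta\sigma^2 - 1\right)
$$

Negating the bracket gives the `-0.5 * (1 + delta_log_var - delta_mu ** 2 / var - delta_var)` form used in the code. Under that reading the per-element formula is fine, and the open question is only which dimensions get summed and which get averaged.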

Alternative:

def relative_kl(delta_mu, delta_log_var, mu, log_var):
    var = torch.exp(log_var)
    delta_var = torch.exp(delta_log_var)

    loss = -0.5 * (
        1 + delta_log_var - delta_mu ** 2 / var - delta_var
    )
    return loss.flatten(start_dim=1).sum(dim=-1).mean(dim=0)
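
To make the difference concrete, here is a minimal sketch that runs both versions on random tensors. The name `kl_2_original`, the tensor shape, and the batch size are assumptions chosen for illustration and are not taken from the repository. If the reading above is right, the original returns the KL summed over the whole batch, so it should be larger than the per-sample mean by a factor of the batch size.

```python
import torch


def kl_2_original(delta_mu, delta_log_var, mu, log_var):
    # As posted: torch.sum without `dim` collapses every dimension, batch included.
    var = torch.exp(log_var)
    delta_var = torch.exp(delta_log_var)
    loss = -0.5 * torch.sum(1 + delta_log_var - delta_mu ** 2 / var - delta_var)
    # `loss` is already a 0-dim tensor here, so this mean changes nothing.
    return torch.mean(loss, dim=0)


def relative_kl(delta_mu, delta_log_var, mu, log_var):
    # Proposed alternative: sum within each sample, then average over the batch.
    var = torch.exp(log_var)
    delta_var = torch.exp(delta_log_var)
    loss = -0.5 * (1 + delta_log_var - delta_mu ** 2 / var - delta_var)
    return loss.flatten(start_dim=1).sum(dim=-1).mean(dim=0)


torch.manual_seed(0)
batch_size = 8                      # illustrative value
shape = (batch_size, 4, 2, 2)       # hypothetical (N, C, H, W) latent shape
delta_mu, delta_log_var, mu, log_var = (torch.randn(shape) for _ in range(4))

summed = kl_2_original(delta_mu, delta_log_var, mu, log_var)
averaged = relative_kl(delta_mu, delta_log_var, mu, log_var)
ratio = (summed / averaged).item()  # expected to be close to batch_size
print(summed.item(), averaged.item(), ratio)
```

So the two versions differ only by a constant factor for a fixed batch size, which a KL weight could absorb, but the original's scale changes whenever the batch size does.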

GlassyWing commented 4 years ago

Thanks for the report. I did indeed leave out the dim argument in torch.sum.