NVlabs / NVAE

The Official PyTorch Implementation of "NVAE: A Deep Hierarchical Variational Autoencoder" (NeurIPS 2020 spotlight paper)
https://arxiv.org/abs/2007.03898

Could you explain kl_balancer in detail? #26

Closed: bfs18 closed this issue 3 years ago

bfs18 commented 3 years ago

https://github.com/NVlabs/NVAE/blob/38eb9977aa6859c6ee037af370071f104c592695/utils.py#L213 The `* total_kl` in this line is redundant because in the next line kl_coeff_i is divided by its mean. Why is mean(abs(kl_i)) used as a weight factor for kl_i? Is `/ alpha_i` equivalent to `* num_group / feature_resolution`?

arash-vahdat commented 3 years ago

Please check section H in DVAE++: https://arxiv.org/pdf/1802.04920.pdf and also section A in NVAE.

The basic idea is to multiply the KL term for each group by a scalar coefficient such that groups with low KL are encouraged to use more latent variables (and hence increase their KL), while groups with high KL are encouraged to reduce it. We set this coefficient proportional to each group's average KL over the batch, divided by the dimension of that group.
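As an illustration of this basic idea, here is a minimal plain-Python sketch (no PyTorch) of the unnormalized coefficient: each group's mean |KL| over the batch, divided by that group's size. The function name `raw_kl_coeffs`, its list-of-lists input and the `group_sizes` argument are hypothetical and not taken from NVAE's utils.py; the absolute value follows the mean(abs(kl_i)) mentioned in the question.

```python
def raw_kl_coeffs(kl_batch, group_sizes):
    """Unnormalized balancing coefficient per latent group (hypothetical helper).

    kl_batch: list of per-sample lists, kl_batch[b][g] = KL of group g for sample b.
    group_sizes: latent dimension of each group (assumed input).
    Returns mean |KL| over the batch for each group, divided by the group's size.
    """
    num_samples = len(kl_batch)
    coeffs = []
    for g, size in enumerate(group_sizes):
        mean_abs_kl = sum(abs(sample[g]) for sample in kl_batch) / num_samples
        coeffs.append(mean_abs_kl / size)
    return coeffs


# Two groups of equal size: group 0 carries more KL than group 1.
kl_batch = [[4.0, 1.0], [6.0, 1.0]]
print(raw_kl_coeffs(kl_batch, [2, 2]))  # [2.5, 0.5]
```

The high-KL group gets the larger coefficient, so its KL is penalized more heavily, while the low-KL group is penalized less and is free to use more of its latent variables.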

The catch is that this coefficient cannot take arbitrary values: we need to normalize it so that it sums to N (N = number of groups). This keeps the KL term properly scaled relative to the reconstruction term. When all the groups are used equally, kl_coeff_i is ~1 for each group.
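A sketch of the normalization step, again in plain Python with a hypothetical helper name (`normalize_kl_coeffs`), not NVAE's own code: the raw coefficients are rescaled so that they sum to the number of groups N, which gives every group a coefficient of 1 when all groups are used equally.

```python
def normalize_kl_coeffs(raw_coeffs):
    """Rescale coefficients so that they sum to the number of groups N (hypothetical helper)."""
    num_groups = len(raw_coeffs)
    total = sum(raw_coeffs)
    return [num_groups * c / total for c in raw_coeffs]


# Equal usage: every group ends up with a coefficient of 1.
print(normalize_kl_coeffs([0.25, 0.25, 0.25]))  # [1.0, 1.0, 1.0]
# Unequal usage: the coefficients still sum to N = 2.
print(normalize_kl_coeffs([3.0, 1.0]))  # [1.5, 0.5]
```

Because the coefficients always sum to N, the total weight on the KL term stays comparable to an unweighted sum of the N group KLs; only its distribution across groups changes.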

The operations you noted are indeed redundant. Basically, we just need to do:

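```python
import torch

# Assumed shapes: kl_all is (batch, num_groups); kl_coeff_i is (1, num_groups), the batch-mean |KL| per group
kl_coeff_i = kl_coeff_i / feature_resolution  # division by latent resolution
kl_coeff_i = num_group * kl_coeff_i / torch.sum(kl_coeff_i, dim=1, keepdim=True)  # normalization such that sum = num_group
kl = torch.sum(kl_all * kl_coeff_i.detach(), dim=1)  # detach: no gradient flows through the coefficients
```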
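To check the redundancy numerically, here is a plain-Python sketch (no PyTorch) that compares both versions for one batch. `original_style` paraphrases the two lines described in the question (divide by `alpha_i`, multiply by `total_kl`, then divide by the mean); `simplified` follows the snippet above. The function names and all numbers are made up for illustration, and `alpha` stands in for `alpha_i` / `feature_resolution`.

```python
def original_style(mean_abs_kl, alpha, total_kl):
    # Divide by alpha, scale by a positive scalar (total_kl), then divide by the mean.
    scaled = [k / a * total_kl for k, a in zip(mean_abs_kl, alpha)]
    mean = sum(scaled) / len(scaled)
    return [s / mean for s in scaled]


def simplified(mean_abs_kl, alpha):
    # Divide by the group size, then normalize so the coefficients sum to N.
    per_dim = [k / a for k, a in zip(mean_abs_kl, alpha)]
    total = sum(per_dim)
    return [len(per_dim) * p / total for p in per_dim]


mean_abs_kl = [5.0, 1.0, 3.0]
alpha = [4.0, 1.0, 2.0]
# total_kl = 9.0 is sum(mean_abs_kl), mirroring how the question describes it.
a = original_style(mean_abs_kl, alpha, total_kl=9.0)
b = simplified(mean_abs_kl, alpha)
print(all(abs(x - y) < 1e-12 for x, y in zip(a, b)))  # True
```

Any positive scalar multiplied in before the mean normalization cancels out, and dividing by the mean is the same as multiplying by N and dividing by the sum, so both versions produce the same coefficients (here roughly [1.0, 0.8, 1.2]).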

I also noticed that on page 14 of the paper we say that the coefficient is set proportional to the size of each group. This is a typo: it should be inversely proportional to the size of each group.
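With the inverse proportionality in place, the coefficient described in this thread can be summarized as follows. The symbols are shorthand introduced here, not the paper's notation: $\bar{k}_l$ is the batch-averaged |KL| of group $l$, $s_l$ is the size of group $l$ (the `feature_resolution` in the snippet), $N$ is the number of groups, and $\operatorname{sg}$ denotes stop-gradient, matching `.detach()`.

```latex
\gamma_l = N \, \frac{\bar{k}_l / s_l}{\sum_{j=1}^{N} \bar{k}_j / s_j},
\qquad
\sum_{l=1}^{N} \gamma_l = N,
\qquad
\mathcal{L}_{\mathrm{KL}} = \sum_{l=1}^{N} \operatorname{sg}[\gamma_l] \, \mathrm{KL}_l
```

For a fixed $\bar{k}_l$, a larger group (larger $s_l$) gets a smaller $\gamma_l$, and if every group has the same $\bar{k}_l / s_l$ then every $\gamma_l = 1$.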

bfs18 commented 3 years ago

Thanks a lot for your detailed explanation. It is really helpful.