zfjsail / gae-pytorch

Graph Auto-Encoder in PyTorch
MIT License
417 stars, 79 forks

Some question of KLD #3

Open CXX1113 opened 5 years ago

CXX1113 commented 5 years ago

KLD = -0.5 / n_nodes * torch.mean(torch.sum(1 + 2 * logvar - mu.pow(2) - logvar.exp().pow(2), 1))

Either the "/ n_nodes" factor should be removed, or torch.mean should be changed to torch.sum.
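Written out, the point is that the expression divides by the number of nodes twice. The sketch below assumes that logvar actually stores log σ (which is what the 2 * logvar and logvar.exp().pow(2) terms suggest), that N stands for n_nodes, and that D is the latent dimension; the symbols N, D, and KL_i are introduced here for illustration only.

```latex
% Per-node KL divergence to a standard normal prior
\mathrm{KL}_i
  = \mathrm{KL}\!\left(\mathcal{N}\big(\mu_i, \operatorname{diag}(\sigma_i^2)\big) \,\big\|\, \mathcal{N}(0, I)\right)
  = -\frac{1}{2} \sum_{j=1}^{D} \left(1 + 2\log\sigma_{ij} - \mu_{ij}^2 - \sigma_{ij}^2\right)

% What the quoted line computes: torch.mean contributes one 1/N, the explicit "/ n_nodes" another
\mathrm{KLD}_{\text{code}}
  = \frac{1}{N} \cdot \frac{1}{N} \sum_{i=1}^{N} \mathrm{KL}_i
  = \frac{1}{N^2} \sum_{i=1}^{N} \mathrm{KL}_i

% The plain per-node average, for comparison
\overline{\mathrm{KL}} = \frac{1}{N} \sum_{i=1}^{N} \mathrm{KL}_i
```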

AllenWu18 commented 5 years ago

Then should it be

KLD = -0.5 * torch.mean(torch.sum(1 + 2 * logvar - mu.pow(2) - logvar.exp().pow(2), 1))

or

KLD = -0.5 / n_nodes * torch.sum(torch.sum(1 + 2 * logvar - mu.pow(2) - logvar.exp().pow(2), 1))?
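The two candidate expressions give the same number, and the original one is smaller by a further factor of n_nodes. The following is a minimal numeric check in plain Python, so it runs without PyTorch. The nested lists stand in for the mu and logvar tensors, the helper per_node_terms is hypothetical, and the random inputs are made up for illustration.

```python
import math
import random


def per_node_terms(mu, log_std):
    """Row sums of 1 + 2*s - m^2 - exp(s)^2, mirroring torch.sum(..., 1)."""
    return [
        sum(1 + 2 * s - m ** 2 - math.exp(s) ** 2 for m, s in zip(mu_row, s_row))
        for mu_row, s_row in zip(mu, log_std)
    ]


random.seed(0)
n_nodes, latent_dim = 5, 3  # illustrative sizes, not taken from the repository
mu = [[random.gauss(0, 1) for _ in range(latent_dim)] for _ in range(n_nodes)]
log_std = [[random.gauss(0, 0.1) for _ in range(latent_dim)] for _ in range(n_nodes)]

rows = per_node_terms(mu, log_std)

# -0.5 * torch.mean(torch.sum(..., 1))
kld_mean = -0.5 * (sum(rows) / n_nodes)
# -0.5 / n_nodes * torch.sum(torch.sum(..., 1))
kld_sum = -0.5 / n_nodes * sum(rows)
# -0.5 / n_nodes * torch.mean(torch.sum(..., 1)), the expression from the first comment
kld_original = -0.5 / n_nodes * (sum(rows) / n_nodes)

print(kld_mean, kld_sum, kld_original)
```

Both proposed forms are the per-node average of the KL term, while the expression from the first comment equals that average divided by n_nodes once more.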

YH-UtMSB commented 5 years ago

@alanlisten @AllenWu18 you are both right about the KLD. The author has clarified that the extra 1 / n_nodes factor serves as a rescaling parameter (like \beta in beta-VAE) that weakens the regularization from the KLD. See this issue: Loss function in optimizer.py #20.
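One way to read the rescaling explanation is that the extra factor plays the role of the weight on the KL term in a beta-VAE-style objective. The sketch below is only an illustration of that reading: the helper weighted_loss, the graph size, and the loss values are all hypothetical and do not come from the repository.

```python
def weighted_loss(recon_loss, kl_per_node_mean, beta):
    """Hypothetical helper: reconstruction term plus a beta-weighted KL term."""
    return recon_loss + beta * kl_per_node_mean


n_nodes = 1000           # made-up graph size
recon_loss = 0.70        # made-up reconstruction loss
kl_per_node_mean = 1.5   # made-up per-node average KL

# beta = 1 keeps the full KL penalty; beta = 1 / n_nodes shrinks it, as in the quoted line
loss_full_kl = weighted_loss(recon_loss, kl_per_node_mean, beta=1.0)
loss_rescaled_kl = weighted_loss(recon_loss, kl_per_node_mean, beta=1.0 / n_nodes)

print(loss_full_kl, loss_rescaled_kl)
```

With beta = 1 / n_nodes the KL term contributes far less to the total, so the reconstruction term dominates the gradient.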

Dzhilin commented 3 years ago

Could you please tell me what "norm" means in the loss function? Looking forward to your reply!