heykeetae / Self-Attention-GAN

PyTorch implementation of Self-Attention Generative Adversarial Networks (SAGAN)

I think the implementation of the WGAN-GP loss is wrong #27

Open · civilman628 opened this issue 6 years ago

civilman628 commented 6 years ago

Please refer to:

https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/wgan_gp/wgan_gp.py

If 'wgan-gp' is chosen, d_loss and the gradient penalty are updated as two separate losses in train():

d_loss = d_loss_real + d_loss_fake

d_loss = self.lambda_gp * d_loss_gp
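In the linked reference, by contrast, the penalty is folded into a single discriminator loss before one backward/step pair. Roughly (paraphrased from that file from memory; variable names may differ slightly):

gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data)
d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty

optimizer_D.zero_grad()
d_loss.backward()
optimizer_D.step()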

voletiv commented 6 years ago

The network weights are first updated with d_loss = d_loss_real + d_loss_fake: a backward() and an optimizer step happen there. After that, the code checks whether the gradient penalty should be included, computes d_loss_gp, assigns d_loss = self.lambda_gp * d_loss_gp, and then does a second backward() and optimizer step. This is not necessarily wrong, but it takes two separate optimizer steps per discriminator update where one combined step would do.
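Concretely, the current flow in train() looks roughly like this (a sketch reconstructed from the description above, reusing the solver's own helpers; compute_gradient_penalty stands in for however the penalty is actually computed):

# Step 1: update on the adversarial loss alone
d_loss = d_loss_real + d_loss_fake
self.reset_grad()
d_loss.backward()
self.d_optimizer.step()

# Step 2: a second, separate update on the gradient penalty alone
if self.adv_loss == 'wgan-gp':
    d_loss_gp = compute_gradient_penalty(real_images, real_labels, fake_images)
    d_loss = self.lambda_gp * d_loss_gp
    self.reset_grad()
    d_loss.backward()
    self.d_optimizer.step()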

Ideally it should have been:

d_loss = d_loss_real + d_loss_fake
if self.adv_loss == 'wgan-gp':
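    # Compute the penalty and fold it into the same loss, so a single
    # backward()/step() applies both terms at once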
    d_loss_gp = compute_gradient_penalty(real_images, real_labels, fake_images)
    d_loss += self.lambda_gp * d_loss_gp

# Backward + Optimize
self.reset_grad()
d_loss.backward()
self.d_optimizer.step()
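For completeness, here is a minimal sketch of what compute_gradient_penalty would compute, following the standard WGAN-GP penalty (Gulrajani et al., "Improved Training of Wasserstein GANs"). This is an assumption for illustration, not code from this repo: it takes the discriminator explicitly, drops the labels argument used above, and assumes D returns a single score tensor.

import torch

def compute_gradient_penalty(D, real_images, fake_images):
    # WGAN-GP penalty: E[(||grad_x D(x_hat)||_2 - 1)^2], evaluated at random
    # interpolations x_hat between real and fake samples
    batch_size = real_images.size(0)
    alpha = torch.rand(batch_size, 1, 1, 1, device=real_images.device)
    interpolated = (alpha * real_images + (1 - alpha) * fake_images).requires_grad_(True)
    d_out = D(interpolated)
    grads = torch.autograd.grad(
        outputs=d_out,
        inputs=interpolated,
        grad_outputs=torch.ones_like(d_out),
        create_graph=True,   # keep the graph so the penalty itself is differentiable
        retain_graph=True,
    )[0]
    grads = grads.view(batch_size, -1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()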