Open civilman628 opened 6 years ago
The network weights are first updated with `d_loss = d_loss_real + d_loss_fake`: a `backward()` and an `optimizer.step()` happen there. After this, the code checks whether the gradient penalty has to be included, computes `d_loss_gp`, assigns `d_loss = self.lambda_gp * d_loss_gp`, and then does a second `backward()` and `optimizer.step()`. This is not necessarily wrong, but:
Ideally it should have been:

```python
d_loss = d_loss_real + d_loss_fake
if self.adv_loss == 'wgan-gp':
    d_loss_gp = compute_gradient_penalty(real_images, real_labels, fake_images)
    d_loss += self.lambda_gp * d_loss_gp

# Backward + Optimize
self.reset_grad()
d_loss.backward()
self.d_optimizer.step()
```
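For completeness, here is a minimal sketch of what `compute_gradient_penalty` typically looks like for WGAN-GP (a standalone function taking the critic explicitly; the signature and the `critic` argument are my assumption, not this repo's exact method, which also takes labels):

```python
import torch

def compute_gradient_penalty(critic, real, fake):
    # Gradient penalty on random interpolates between real and fake samples.
    # NOTE: standalone sketch; the repo's own method additionally takes labels.
    batch = real.size(0)
    alpha = torch.rand(batch, *([1] * (real.dim() - 1)), device=real.device)
    interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    d_interpolates = critic(interpolates)
    # Gradient of critic outputs w.r.t. the interpolated inputs.
    grads = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]
    grads = grads.view(batch, -1)
    # Penalize deviation of the gradient norm from 1.
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()
```

Because `create_graph=True`, the penalty itself is differentiable, so it can be added to `d_loss` before the single `backward()` above.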
Please refer to:
https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/wgan_gp/wgan_gp.py
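For intuition on why two separate optimizer steps are not equivalent to one step on the summed loss, here is a toy sketch with plain SGD and hypothetical quadratic loss terms (the second gradient is evaluated at the already-updated weights):

```python
# Hypothetical losses: loss_adv(w) = w**2 and loss_gp(w) = w**2,
# so their gradients are 2*w each. Learning rate and w0 are arbitrary.
lr = 0.1
w0 = 1.0

# Current train(): step on d_loss, then a second step on the penalty term.
w = w0
w -= lr * (2 * w)          # backward()/step() on d_loss_real + d_loss_fake
w -= lr * (2 * w)          # second backward()/step(), gradient taken at new w
two_step = w

# Suggested scheme: a single step on the combined loss.
w = w0
w -= lr * (2 * w + 2 * w)  # one backward()/step() on the summed loss
one_step = w

print(two_step, one_step)  # the two schemes give different weights
```

With real optimizers (e.g. Adam and its moment estimates) the two schemes diverge even further, which is why summing the terms before a single `backward()` is the cleaner formulation.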
If 'wgan-gp' is chosen, `d_loss` and the gradient penalty are updated separately in `train()`:

```python
d_loss = d_loss_real + d_loss_fake
# ...and later the penalty replaces d_loss for a second update:
d_loss = self.lambda_gp * d_loss_gp
```