jxtps opened 2 years ago
I guess this would kind of mimic how the D optimization uses both the fake & the real data to have the gradients cancel out where they're the same ( https://github.com/cszn/KAIR/blob/master/models/model_gan.py#L249 ):
```python
# real
pred_d_real = self.netD(self.H)                  # 1) real data
l_d_real = self.D_lossfn(pred_d_real, True)
l_d_real.backward()
# fake
pred_d_fake = self.netD(self.E.detach().clone()) # 2) fake data, detach to avoid BP to G
l_d_fake = self.D_lossfn(pred_d_fake, False)
l_d_fake.backward()
```
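That cancellation can be checked in isolation: where a fake pixel matches the real one, the two D losses pull D's output in opposite directions, and the net gradient vanishes exactly when D outputs 0.5. A minimal sketch (standalone, not code from the repo):

```python
import torch
import torch.nn.functional as F

# A single "pixel" that appears identically in the real and fake batches.
logit = torch.zeros(1, requires_grad=True)  # sigmoid(0) = 0.5

# The two halves of the D step: real target 1, fake target 0.
l_real = F.binary_cross_entropy_with_logits(logit, torch.ones(1))
l_fake = F.binary_cross_entropy_with_logits(logit, torch.zeros(1))
(l_real + l_fake).backward()

# d/dz = (sigmoid(z) - 1) + sigmoid(z) = 2*sigmoid(z) - 1, which is 0 at z = 0
print(logit.grad)
```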
In model_gan.py, the G step computes `pred_g_fake = self.netD(self.E)` and trains it against the real label. But when D & G have converged, `pred_g_fake` is likely to be close to midway between fake & real. This means that even for pixels / examples where G has done a great job, there will still be gradients yanking it around.

This is somewhat complicated by the fact that some GAN losses use `BCEWithLogitsLoss` (i.e. they sigmoid the input) and `lsgan` uses `MSELoss` (which doesn't sigmoid the input), so the meaning of the input differs between them.

Anyway, something like:
This is just meant as an illustration, not a pull request. It obviously does very different things for `gan` (it cuts at the midpoint, but doesn't stretch, because it's unclear how to stretch prior to the sigmoid) vs `lsgan` (where it stretches [0, 0.5] => [0, 1]; it gets more complicated if the real label != 1 or the fake label != 0). But it should enable the gradients to "calm down" when G has done a good job and D is maximally confused.
I guess what I'm saying is that while the objective for D should indeed be to output 0 or 1 (ignoring the sigmoid), the goal for G is less to have D output 1 and more to have it output >=0.5.
Thoughts?