christiancosgrove / pytorch-spectral-normalization-gan

Paper by Miyato et al. https://openreview.net/forum?id=B1QRgziT-
MIT License

About the G loss. #14

Closed: zeakey closed this issue 5 years ago

zeakey commented 5 years ago

I'm confused about the G loss.

In the WassersteinGAN code (https://github.com/martinarjovsky/WassersteinGAN/blob/f81eafd2aa41e93698f203732f8f395abc70be02/main.py#L212), the author uses

errG = netD(fake)

where fake = netG(z).

However, in your implementation, the G loss is

gen_loss = -discriminator(generator(z)).mean().

Theoretically, I believe the G loss should be -D(G(z)), because G is expected to learn to 'cheat' D.
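To see why minimizing -D(G(z)) lets G 'cheat' D, here is a minimal pure-Python sketch. It assumes a fixed 1-D linear critic f(x) = w*x + b and a hypothetical shift generator G(z) = theta + z; neither is the repo's actual network. Gradient descent on the generator loss -mean(f(G(z))) moves theta in the direction that raises the critic's score on the fake samples.

```python
# Minimal sketch, assuming a fixed 1-D linear critic f(x) = w*x + b and a
# hypothetical shift generator G(z) = theta + z (illustrative, not the repo's models).

def critic(x, w, b):
    """Linear critic score f(x) = w*x + b."""
    return w * x + b

def gen_loss(theta, zs, w, b):
    """Generator loss -mean(f(G(z))), the form used in this repo."""
    fakes = [theta + z for z in zs]
    return -sum(critic(x, w, b) for x in fakes) / len(fakes)

def gen_loss_grad(w):
    """d/dtheta of -mean(w*(theta + z) + b) is simply -w."""
    return -w

def train_generator(theta, w, lr=0.1, steps=10):
    """Plain gradient descent on the generator loss; the critic stays fixed."""
    for _ in range(steps):
        theta -= lr * gen_loss_grad(w)
    return theta

zs = [-0.5, 0.0, 0.5]
w, b = 2.0, 0.0   # this critic scores larger x as "more real"
theta0 = 0.0
theta1 = train_generator(theta0, w)

print(theta1)  # theta has moved toward larger critic scores
print(gen_loss(theta1, zs, w, b) < gen_loss(theta0, zs, w, b))  # True
```

Because the loss is the negated critic score, descending it is the same as ascending D(G(z)), which is exactly the "cheating" direction.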

christiancosgrove commented 5 years ago

I think these are the same, except for a choice of sign. In that WGAN code, we have the discriminator loss errD = errD_real - errD_fake and the generator loss errG = netD(fake).

In this code, by contrast, we use disc_loss = -discriminator(data).mean() + discriminator(generator(z)).mean() and gen_loss = -discriminator(generator(z)).mean().

The sign is flipped in both loss functions, but the overall effect is the same.

It's important to note that in WGAN we don't use a cross-entropy loss: both losses are linear in the critic's output, so negating both of them is the same as swapping the critic f for -f. Therefore, the losses are invariant to a sign flip as long as we perform it consistently on the discriminator and generator loss functions.
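The linearity point can be checked numerically. This small sketch uses made-up critic scores (hypothetical values, not from either codebase): for the Wasserstein-style generator loss, negating the loss equals evaluating it on the negated critic, while for the non-saturating cross-entropy loss -log sigmoid(logit) the two differ.

```python
import math

def mean(xs):
    return sum(xs) / len(xs)

def wasserstein_gen_loss(scores_fake):
    """Linear in the critic output: -mean(f(G(z)))."""
    return -mean(scores_fake)

def bce_gen_loss(logits_fake):
    """Non-saturating cross-entropy generator loss: -mean(log sigmoid(logit)).

    Uses -log sigmoid(l) = log(1 + exp(-l)).
    """
    return mean([math.log1p(math.exp(-l)) for l in logits_fake])

scores = [0.8, -1.3, 2.0]
negated = [-s for s in scores]

# Linear loss: flipping the sign of the loss == flipping the sign of the critic.
print(math.isclose(-wasserstein_gen_loss(scores), wasserstein_gen_loss(negated)))  # True

# Cross-entropy: negating the loss is NOT the same as negating the logits.
print(math.isclose(-bce_gen_loss(scores), bce_gen_loss(negated)))  # False
```

With cross-entropy, the equivalent change would be to negate the logits and also swap the real/fake labels, so the sign convention is tied to the label convention rather than being a free choice on the loss.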

zeakey commented 5 years ago

Yes, I understand what you mean. Intuitively, though, I find your formulation more straightforward: f(x) is larger for real images and smaller for generated images.
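As a toy illustration of that intuition, this sketch trains a 1-D linear critic f(x) = w*x with this repo's disc_loss convention on hypothetical real and fake samples. It uses WGAN-style weight clipping as a simple stand-in for the spectral normalization this repo actually uses, purely to keep the linear critic bounded. After training, the critic gives higher mean scores to the real samples, and disc_loss is negative.

```python
# Minimal sketch, assuming a 1-D linear critic f(x) = w * x whose weight is
# clipped to [-1, 1] (WGAN-style clipping, a stand-in for spectral normalization).

def mean(xs):
    return sum(xs) / len(xs)

def disc_loss(w, real, fake):
    """This repo's convention: -mean f(real) + mean f(fake)."""
    return -mean([w * x for x in real]) + mean([w * x for x in fake])

def disc_loss_grad(real, fake):
    """d/dw of the loss above: -mean(real) + mean(fake)."""
    return -mean(real) + mean(fake)

real = [2.0, 2.5, 3.0]  # hypothetical "real" samples
fake = [0.0, 0.5, 1.0]  # hypothetical generated samples
w, lr = -0.5, 0.1       # start with a critic that ranks fake above real

for _ in range(20):
    w -= lr * disc_loss_grad(real, fake)
    w = max(-1.0, min(1.0, w))  # keep the critic bounded

print(w)  # 1.0
print(mean([w * x for x in real]) > mean([w * x for x in fake]))  # True
```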