lucidrains / lightweight-gan

Implementation of 'lightweight' GAN, proposed in ICLR 2021, in PyTorch. High-resolution image generation that can be trained within a day or two
MIT License
1.62k stars · 220 forks

loss implementation differs from paper #128

Open · maximeraafat opened this issue 2 years ago

maximeraafat commented 2 years ago

Hi,

Thanks for this amazing implementation! I have a question concerning the loss implementation, as it seems to differ from the original equations. The screenshot below shows the GAN loss as presented in the paper:

[screenshot: the paper's GAN loss equations]

This makes sense to me, since D is assumed to output values between 0 and 1 (0 = fake, 1 = real).
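Under that reading (D constrained to [0, 1], e.g. by a final sigmoid), the equations have the shape of the classic binary cross-entropy GAN loss. A purely illustrative PyTorch sketch, not the paper's actual code:

```python
import torch

# Illustrative only: the classic GAN loss, assuming D outputs a
# probability in (0, 1), e.g. via a final sigmoid.
def d_loss(real_probs, fake_probs):
    # D maximizes log D(x) + log(1 - D(G(z)))
    return -(torch.log(real_probs) + torch.log(1 - fake_probs)).mean()

def g_loss(fake_probs):
    # non-saturating form: G maximizes log D(G(z))
    return -torch.log(fake_probs).mean()
```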

Now, the way the authors implement this in the code provided in the supplementary materials of the paper is as follows (the colors match the ones in the picture above):

[screenshots: the authors' reference code for the discriminator loss on real images, the discriminator loss on fakes, and the generator loss]

Except for the somewhat odd randomness involved (already explained in https://github.com/lucidrains/lightweight-gan/issues/11), their implementation is a one-to-one match with the paper's equations.


The way it is implemented in this repo, however, is quite different, and I do not understand why.

[screenshot: the loss implementation in this repo]

Let's start with the discriminator loss:

For the generator loss:

This implementation seems meaningful and yields coherent results (as the provided examples demonstrate). It also seems to me that D is not limited to outputs between 0 and 1 but can return any real value (I might be wrong). I am just wondering why this choice was made. Could you perhaps elaborate on why you decided to implement the loss differently from the original paper?
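For reference, a loss that works with unbounded D outputs is the standard hinge GAN loss. A generic sketch of that formulation (the repo's exact sign convention may differ):

```python
import torch
import torch.nn.functional as F

# Generic hinge GAN loss sketch; D outputs unbounded real-valued logits.
def d_hinge_loss(real_logits, fake_logits):
    # push real logits above +1 and fake logits below -1
    return (F.relu(1 - real_logits) + F.relu(1 + fake_logits)).mean()

def g_hinge_loss(fake_logits):
    # G simply maximizes D's score on fakes
    return -fake_logits.mean()
```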

iScriptLex commented 2 years ago

I think it was just taken from some other article. You can see elements of WGAN-GP in this code, such as a simplified implementation of the gradient penalty. This code also supports multiple losses (the user can choose a dual contrastive loss instead of the hinge loss), so the loss may have been implemented this way so that a single training loop works with several loss functions.
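A simplified gradient penalty of this kind could be sketched roughly as follows (an illustrative example in the spirit of WGAN-GP/R1, not the repo's exact code; the toy "discriminator" below is a stand-in):

```python
import torch

# Illustrative simplified gradient penalty: penalize the squared norm
# of dD(x)/dx on (here, real) images, as in R1 / WGAN-GP variants.
def gradient_penalty(images, d_outputs, weight=10.0):
    grads, = torch.autograd.grad(
        outputs=d_outputs.sum(), inputs=images, create_graph=True
    )
    grads = grads.reshape(grads.shape[0], -1)
    return weight * (grads.norm(2, dim=1) ** 2).mean()

# usage sketch with a toy stand-in for D(images)
images = torch.randn(4, 3, 8, 8, requires_grad=True)
d_out = images.mean(dim=(1, 2, 3))
gp = gradient_penalty(images, d_out)
```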