eriklindernoren / PyTorch-GAN

PyTorch implementations of Generative Adversarial Networks.
MIT License

Bug in WGAN-GP ? #136

Open pfeatherstone opened 3 years ago

pfeatherstone commented 3 years ago

Should the following line https://github.com/eriklindernoren/PyTorch-GAN/blob/a163b82beff3d01688d8315a3fd39080400e7c01/implementations/wgan_gp/wgan_gp.py#L162 be this instead:

fake_imgs = generator(z).detach()

?

ALLinLLM commented 3 years ago

[Image: the WGAN-GP training algorithm (Algorithm 1) from the paper]
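Since the embedded screenshot did not survive, here for reference is the critic objective that the algorithm in the WGAN-GP paper (Gulrajani et al., 2017) optimizes, where $\tilde{x}$ is a generator sample, $x$ a real sample, and $\hat{x}$ a random interpolation between the two:

$$L = \mathbb{E}_{\tilde{x}\sim\mathbb{P}_g}\big[D(\tilde{x})\big] \;-\; \mathbb{E}_{x\sim\mathbb{P}_r}\big[D(x)\big] \;+\; \lambda\,\mathbb{E}_{\hat{x}\sim\mathbb{P}_{\hat{x}}}\big[\big(\lVert\nabla_{\hat{x}} D(\hat{x})\rVert_2 - 1\big)^2\big]$$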

The algorithm from the WGAN-GP paper is shown above. There are two ways to calculate G and D iteratively in PyTorch:

1. Update D with the fake batch detached:

       # update D
       fake = G(z).detach()   # detach so backward() does NOT compute gradients for G
       loss_d = D(fake).mean() - D(real).mean()
       loss_d.backward()
       optimizerD.step()      # update the weights of D

   Then update G (regenerate the fake batch without detach so gradients can reach G):

       loss_g = -D(G(z)).mean()
       loss_g.backward()
       optimizerG.step()      # update the weights of G

2. Update D without detaching, reusing the same graph for the G step:

       # update D
       fake = G(z)            # no detach(), so backward() also computes gradients for G
       loss_d = D(fake).mean() - D(real).mean()
       loss_d.backward(retain_graph=True)  # otherwise the graph is freed and loss_g.backward() would fail
       optimizerD.step()      # update the weights of D

   Then update G:

       loss_g = -D(fake).mean()
       loss_g.backward()
       optimizerG.step()      # update the weights of G

   A self-contained, runnable sketch of both schemes follows below.
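Here is a minimal runnable sketch of both schemes, using toy linear modules in place of the real networks so it runs on its own; only the names G, D, optimizerG, and optimizerD come from the pseudocode above, everything else is made up for illustration:

    import torch
    import torch.nn as nn

    torch.manual_seed(0)

    # Toy stand-ins for the generator and critic, just to make the schemes runnable.
    G = nn.Linear(8, 4)   # "generator": noise (dim 8) -> "image" (dim 4)
    D = nn.Linear(4, 1)   # "critic":    "image" (dim 4) -> score
    optimizerG = torch.optim.Adam(G.parameters(), lr=1e-3)
    optimizerD = torch.optim.Adam(D.parameters(), lr=1e-3)

    z = torch.randn(16, 8)
    real = torch.randn(16, 4)

    # ----- Scheme 1: detach the fake batch for the critic step -----
    optimizerD.zero_grad()
    fake = G(z).detach()                   # no gradients flow back into G here
    loss_d = D(fake).mean() - D(real).mean()
    loss_d.backward()
    optimizerD.step()

    optimizerG.zero_grad()
    loss_g = -D(G(z)).mean()               # fresh forward pass so gradients reach G
    loss_g.backward()
    optimizerG.step()

    # ----- Scheme 2: keep the graph and reuse the fake batch for the G step -----
    optimizerD.zero_grad()
    fake = G(z)                            # no detach: loss_d.backward() also fills G's grads
    loss_d = D(fake).mean() - D(real).mean()
    loss_d.backward(retain_graph=True)     # keep the graph so loss_g.backward() can reuse it
    optimizerD.step()                      # only D's parameters are updated

    optimizerG.zero_grad()                 # throw away the unused G grads from the critic step
    loss_g = -D(fake).mean()
    loss_g.backward()
    optimizerG.step()

Both schemes produce the same critic update; scheme 1 just avoids the extra backward pass through G, which is the point raised in the comment below.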

annan-tang commented 1 year ago

Hi, it may be too late to answer this question, but in my opinion it isn't a bug; it only slows down the computation. The original WGAN-GP code uses separate optimizers for G and D. When you train D on a fake image from G without detaching that tensor, loss_d.backward() computes gradients for both D and G, even though the G gradients are not needed in the D training stage. However, in the D stage the source code only updates D's parameters via optimizerD.step(), and at the beginning of the G stage it first clears the gradients left over from the D stage with zero_grad(). So it is not a bug; it just adds computational overhead to D training. You can add .detach() to make the code more efficient.
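A small sketch of that point (toy modules again; the names are illustrative, not the repo's): without .detach(), loss_d.backward() fills the generator's .grad buffers, but optimizerD.step() never touches G, and the subsequent zero_grad() discards those gradients before the G step, so the result is unchanged and only the extra backward pass through G is wasted.

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    G = nn.Linear(8, 4)                    # toy generator
    D = nn.Linear(4, 1)                    # toy critic
    optimizerD = torch.optim.Adam(D.parameters(), lr=1e-3)
    optimizerG = torch.optim.Adam(G.parameters(), lr=1e-3)

    z, real = torch.randn(16, 8), torch.randn(16, 4)

    # Critic step WITHOUT detach: gradients are also computed for G.
    loss_d = D(G(z)).mean() - D(real).mean()
    loss_d.backward()
    print(G.weight.grad is not None)       # True -> wasted backward work through G
    optimizerD.step()                      # only D's parameters change

    # Generator step: the leftover G grads are cleared before they can affect anything.
    optimizerG.zero_grad()
    print(G.weight.grad)                   # None or zeros, depending on the PyTorch version
    loss_g = -D(G(z)).mean()
    loss_g.backward()
    optimizerG.step()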