eriklindernoren / PyTorch-GAN

PyTorch implementations of Generative Adversarial Networks.
MIT License

WGAN GP detach is necessary? #21

Closed Oktai15 closed 6 years ago

Oktai15 commented 6 years ago

https://github.com/eriklindernoren/PyTorch-GAN/blob/22ce15edd1abeb4f735be11592569720e2dd3018/implementations/wgan_gp/wgan_gp.py#L122

@eriklindernoren I think you should apply .detach() to real_samples and fake_samples. Shouldn't you?

eriklindernoren commented 6 years ago

They are already detached, since they are fed to the function as real_imgs.data and fake_imgs.data.
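
For context, here is a minimal sketch of the gradient-penalty function under discussion (paraphrased from memory; variable names and exact tensor shapes are illustrative rather than copied verbatim from wgan_gp.py). The point above is that the samples arrive already detached via .data, and only the interpolated points are explicitly re-attached to the graph:

    import torch
    import torch.autograd as autograd

    def compute_gradient_penalty(D, real_samples, fake_samples):
        # real_samples / fake_samples are passed in as real_imgs.data / fake_imgs.data,
        # i.e. already detached from the generator's computation graph.
        alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=real_samples.device)
        # Only the interpolated points need gradient tracking, so they are
        # re-attached to the graph explicitly.
        interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
        d_interpolates = D(interpolates)
        grad_outputs = torch.ones_like(d_interpolates)
        gradients = autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=grad_outputs,
            create_graph=True,   # keep the graph so the penalty is differentiable w.r.t. D's weights
            retain_graph=True,
            only_inputs=True,
        )[0]
        gradients = gradients.view(gradients.size(0), -1)
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean()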

Oktai15 commented 6 years ago

@eriklindernoren I am not sure about it. https://github.com/pytorch/pytorch/issues/6990

What do you think about this issue?

eriklindernoren commented 6 years ago

I see no problem in this case (although using .detach() wouldn't hurt). Neither real_samples nor fake_samples is used for the backward pass. That is, the fact that real_samples.detach() has the benefit of having its in-place changes reported by autograd should make no difference, since real_samples is not modified or reused here.
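
As an aside, the difference the linked PyTorch issue is about can be shown with a small standalone snippet (this is general PyTorch behavior, not code from this repo): in-place edits through a .detach()'d view are caught by autograd, while the same edits through .data are not.

    import torch

    x = torch.ones(3, requires_grad=True)
    y = x.exp()          # exp's backward re-uses its output, so autograd saves y

    y.detach().zero_()   # in-place edit through a detached view: the version counter is bumped
    # y.sum().backward() # -> RuntimeError: a variable needed for gradient computation
    #                    #    has been modified by an inplace operation

    x2 = torch.ones(3, requires_grad=True)
    y2 = x2.exp()
    y2.data.zero_()      # same edit through .data: autograd does not notice
    y2.sum().backward()  # runs, but x2.grad is silently wrong (zeros instead of exp(1))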

cwarny commented 5 years ago

Wouldn't it be better to detach the fake data here:

https://github.com/eriklindernoren/PyTorch-GAN/blob/a163b82beff3d01688d8315a3fd39080400e7c01/implementations/wgan_gp/wgan_gp.py#L167

fake_validity = discriminator(fake_imgs.detach())

We don't want to calculate gradients for the generator when doing backpropagation on the critic:

https://github.com/eriklindernoren/PyTorch-GAN/blob/a163b82beff3d01688d8315a3fd39080400e7c01/implementations/wgan_gp/wgan_gp.py#L173

Although it doesn't matter in practice since we later zero out the gradients here:

https://github.com/eriklindernoren/PyTorch-GAN/blob/a163b82beff3d01688d8315a3fd39080400e7c01/implementations/wgan_gp/wgan_gp.py#L176

Not detaching still introduces unnecessary computation. Or is there a reason we don't want to detach that I'm missing?
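
For concreteness, here is a hedged sketch of the critic update being described, with the suggested .detach() in place (illustrative names; latent_dim, lambda_gp, the models, and optimizer_D are assumed to be defined as in the training script):

    optimizer_D.zero_grad()

    z = torch.randn(real_imgs.size(0), latent_dim, device=real_imgs.device)
    fake_imgs = generator(z)

    real_validity = discriminator(real_imgs)
    fake_validity = discriminator(fake_imgs.detach())   # the suggested change
    gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data)

    d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
    d_loss.backward()    # with .detach(), backward never enters the generator's graph;
                         # without it, generator .grad buffers get populated and must be zeroed later
    optimizer_D.step()

Either way the critic parameters receive the same gradients; detaching just skips the extra backward pass through the generator.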

annan-tang commented 1 year ago

@cwarny I agree with you! My answer is similar to yours; see here.