eriklindernoren / PyTorch-GAN

PyTorch implementations of Generative Adversarial Networks.
MIT License

Use of detach #151

Open Indraa145 opened 3 years ago

Indraa145 commented 3 years ago

Hello, I'm currently learning about PyTorch and GANs, and I want to ask about these particular lines from the WGAN implementation here.

```python
# Configure input
real_imgs = Variable(imgs.type(Tensor))

# ---------------------
#  Train Discriminator
# ---------------------
optimizer_D.zero_grad()

# Sample noise as generator input
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

# Generate a batch of images
fake_imgs = generator(z).detach()
# Adversarial loss
loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))
loss_D.backward()
optimizer_D.step()
```

Why does fake_imgs have .detach() while real_imgs doesn't? What would happen if I put .detach() on real_imgs as well? Would that mess with the discriminator update? Thank you.

chandragupta0001 commented 3 years ago

It is due to the computational graph. When the fake image is generated, the computational graph records every operation from the latent vector to the final fake image. In the discriminator step we don't want to build on top of that graph and backpropagate through the entire G+D graph, only through D. So you call .detach() on the fake image, which returns a tensor that shares the same data but is cut off from the generator's graph, and compute the discriminator loss on that. The real image is plain input data that does not require grad, so there is no graph behind it and it plays no role in backpropagation.
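
To see this effect in isolation, here is a minimal sketch with tiny toy models standing in for the repo's generator and discriminator. The MLP architectures, the sizes, and the critic_step helper are hypothetical, not taken from PyTorch-GAN. It runs the same critic loss with and without .detach() and checks where the gradients end up.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

latent_dim, img_dim, batch = 8, 16, 4  # toy sizes (hypothetical, not from the repo)

# Tiny stand-ins for the repo's generator and discriminator (assumption: small MLPs)
generator = nn.Sequential(nn.Linear(latent_dim, img_dim), nn.Tanh())
discriminator = nn.Sequential(nn.Linear(img_dim, 1))

real_imgs = torch.randn(batch, img_dim)  # data batch: requires_grad=False, no graph
z = torch.randn(batch, latent_dim)


def critic_step(detach_fake):
    """Hypothetical helper: backprop the WGAN critic loss, return G and D weight grads."""
    generator.zero_grad(set_to_none=True)
    discriminator.zero_grad(set_to_none=True)
    fake_imgs = generator(z)
    if detach_fake:
        fake_imgs = fake_imgs.detach()  # cut the graph at the generator output
    loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))
    loss_D.backward()
    return generator[0].weight.grad, discriminator[0].weight.grad.clone()


g_grad_detached, d_grad_detached = critic_step(detach_fake=True)
g_grad_attached, d_grad_attached = critic_step(detach_fake=False)

print(g_grad_detached is None)      # True: backward stopped at fake_imgs
print(g_grad_attached is not None)  # True: gradients also flowed into G
print(torch.allclose(d_grad_detached, d_grad_attached))  # True: D's grads are identical

# detach() does not copy: the result shares storage with the original tensor,
# it just has no grad_fn linking it back to the generator.
fake_full = generator(z)
fake_view = fake_full.detach()
print(fake_view.data_ptr() == fake_full.data_ptr())  # True: same memory
print(fake_full.grad_fn is not None, fake_view.grad_fn is None)  # True True
```

The discriminator's gradients come out the same either way; detaching only skips the unnecessary backward pass through G and keeps stray gradients out of G's .grad buffers.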

Indraa145 commented 3 years ago

Thank you for your response, I see. So in this case it doesn't matter whether I put .detach() on real_imgs or not, right? Whether I write real_imgs = Variable(imgs.type(Tensor)) or real_imgs = Variable(imgs.type(Tensor)).detach(), will it give the same result?

Also, I forgot to mention where the code comes from: it's the discriminator training part of the WGAN implementation.
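
One way to check this is to compare the discriminator's gradients with and without the extra .detach() on the real batch. This is a sketch under stated assumptions: a single nn.Linear stands in for the critic, random tensors stand in for real and (already detached) fake images, and the critic_grads helper is hypothetical, made up for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

img_dim, batch = 16, 4  # toy sizes (hypothetical)
discriminator = nn.Linear(img_dim, 1)  # stand-in critic (assumption)

imgs = torch.randn(batch, img_dim)       # like a DataLoader batch: a plain tensor
fake_imgs = torch.randn(batch, img_dim)  # stands in for an already-detached fake batch


def critic_grads(real):
    """Hypothetical helper: backprop the WGAN critic loss, return copies of D's grads."""
    discriminator.zero_grad(set_to_none=True)
    loss_D = -torch.mean(discriminator(real)) + torch.mean(discriminator(fake_imgs))
    loss_D.backward()
    return [p.grad.clone() for p in discriminator.parameters()]


print(imgs.requires_grad, imgs.grad_fn)  # False None: nothing to detach from

grads_plain = critic_grads(imgs)
grads_detached = critic_grads(imgs.detach())
print(all(torch.allclose(a, b) for a, b in zip(grads_plain, grads_detached)))  # True
```

Because a data batch has requires_grad=False and no grad_fn, calling .detach() on it changes nothing about the gradients. It would only matter if real_imgs had been produced by operations that themselves track gradients.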