eriklindernoren / PyTorch-GAN

PyTorch implementations of Generative Adversarial Networks.
MIT License
16.17k stars 4.05k forks

Gradient of a Discriminator in optimizing a Generator #179

Open YoojLee opened 2 years ago

YoojLee commented 2 years ago
    # -----------------
    #  Train Generator
    # -----------------

    optimizer_G.zero_grad()

    # Sample noise and labels as generator input
    z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
    gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))

    # Generate a batch of images
    gen_imgs = generator(z, gen_labels)

    # Loss measures generator's ability to fool the discriminator
    validity = discriminator(gen_imgs, gen_labels)
    g_loss = adversarial_loss(validity, valid)

    g_loss.backward()
    optimizer_G.step()

    # ---------------------
    #  Train Discriminator
    # ---------------------

    optimizer_D.zero_grad()

    # Loss for real images
    validity_real = discriminator(real_imgs, labels)
    d_real_loss = adversarial_loss(validity_real, valid)

    # Loss for fake images
    validity_fake = discriminator(gen_imgs.detach(), gen_labels)
    d_fake_loss = adversarial_loss(validity_fake, fake)

In the code above, the generated images are detached from the computational graph when optimizing D, because gradients for G should not be computed during the D update. I expected that, in the same way, the discriminator should be isolated from G when optimizing G (for example, by setting the requires_grad attribute of D's parameters to False, or something similar).

However, the discriminator does not appear to be detached from the generator in your code, which confuses me. Am I misunderstanding something? I will wait for your reply. Thanks a lot!
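One way to look at this question is a toy scalar version of the chain rule. The sketch below is only an illustration and is not code from this repository: `w_g`, `w_d`, and every helper function are hypothetical stand-ins for the generator weights, the discriminator weights, and the binary cross-entropy loss against the `valid` target. It suggests that the gradient reaching the generator has to pass through the discriminator's forward computation, so cutting D out of the graph would leave G with no training signal.

```python
import math

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def generator(w_g, z):
    # Toy generator (assumption): a single weight times the latent value.
    return w_g * z

def discriminator(w_d, x):
    # Toy discriminator (assumption): probability that x is real.
    return sigmoid(w_d * x)

def g_loss(w_g, w_d, z):
    # Generator loss against the "valid" target: -log D(G(z)).
    return -math.log(discriminator(w_d, generator(w_g, z)))

def g_loss_grads(w_g, w_d, z):
    x = generator(w_g, z)
    p = discriminator(w_d, x)
    dloss_dlogit = p - 1.0               # d/dt of -log(sigmoid(t))
    grad_w_g = dloss_dlogit * w_d * z    # flows through D (factor w_d) into G
    grad_w_d = dloss_dlogit * x          # also computed, but unused by the G update
    return grad_w_g, grad_w_d

w_g, w_d, z = 0.5, -1.2, 2.0
grad_w_g, grad_w_d = g_loss_grads(w_g, w_d, z)

# Central finite difference as an independent check of grad_w_g.
eps = 1e-6
numeric = (g_loss(w_g + eps, w_d, z) - g_loss(w_g - eps, w_d, z)) / (2 * eps)

# If D were cut out of the graph, the factor d(logit)/dx would be
# replaced by 0 and the generator would receive no gradient at all.
grad_w_g_if_d_detached = (discriminator(w_d, generator(w_g, z)) - 1.0) * 0.0 * z

print(grad_w_g, numeric, grad_w_d, grad_w_g_if_d_detached)
```

Applied to the snippet above, `g_loss.backward()` does accumulate gradients in D's parameters as a side effect, but `optimizer_G.step()` only updates the generator, and `optimizer_D.zero_grad()` clears those gradients before the discriminator update. Setting `requires_grad` to False on D's parameters during the generator step should therefore be an optional way to skip some computation, not something correctness depends on, since gradients still flow through D to its input.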

baicaiPCX commented 2 years ago

Hello, Xiaopi has received your email!!!

ahmedemam576 commented 1 year ago

@YoojLee did you figure it out? I have the same question.

YoojLee commented 1 year ago

@ahmedemam576 Well, not clearly, but I guess it might depend on which dataset you use. Still not sure, though.