ayukat1016 / gan_sample

MIT License
78 stars 25 forks source link

section5_1_pix2pixのGのGAN学習コード #26

Closed stationkl5 closed 4 months ago

stationkl5 commented 11 months ago

section5_1_pix2pixのPix2Pixクラスの定義の Pix2Pix.train のGのパラメータ更新部のコードの下記部分

    # Generator
    # 評価フェーズなので勾配は計算しない
    # 識別器Dに生成画像を入力
    with torch.no_grad():
        pred_fake = self.netD(fakeAB)

    # 生成器GのGAN損失を算出
    lossG_GAN = self.criterionGAN(pred_fake, True)

with torch.no_grad():部で計算グラフが切れて、GANロスが逆伝搬しない気がしますが 私の認識あっているでしょうか?

takaaki5564 commented 11 months ago

@stationkl5 返信が遅くなり申し訳ありません。

with torch.no_grad()を必要とした意図としては、このステップでは生成器Gの学習のためのパラメータ計算をしている段階であり、識別器Dのパラメータはここでは更新したくないためです。

もしwith torch.no_grad()が存在しない場合、self.netD(fakeAB)を計算する時に勾配が自動的に計算され、識別器Dのパラメータが意図に反して更新されてしまいます。そのためwith torch.no_grad()が必要となります。

stationkl5 commented 9 months ago

@takaaki5564 ご回答ありがとうございます。

生成機のパラメーターを、識別機をだませるパラメータに更新するために lossG_GANを計算して生成機に逆伝搬させたいところですが、  lossG_GAN=criterionGAN(self.netD(self.netG(A),B) を逆伝搬したくても  with torch.no_grad():   pred_fake = self.netD(fakeAB) としていると、netDの勾配がないため、netGまで逆伝搬できず、生成機のパラメーターも更新できないのでないでしょうか?