Closed stationkl5 closed 4 months ago
@stationkl5 返信が遅くなり申し訳ありません。
with torch.no_grad()を必要とした意図としては、このステップでは生成器Gの学習のためのパラメータ計算をしている段階であり、識別器Dのパラメータはここでは更新したくないためです。
もしwith torch.no_grad()が存在しない場合、self.netD(fakeAB)を計算する時に勾配が自動的に計算され、識別器Dのパラメータが意図に反して更新されてしまいます。そのためwith torch.no_grad()が必要となります。
@takaaki5564 ご回答ありがとうございます。
生成機のパラメーターを、識別機をだませるパラメータに更新するために lossG_GANを計算して生成機に逆伝搬させたいところですが、 lossG_GAN=criterionGAN(self.netD(self.netG(A),B) を逆伝搬したくても with torch.no_grad(): pred_fake = self.netD(fakeAB) としていると、netDの勾配がないため、netGまで逆伝搬できず、生成機のパラメーターも更新できないのでないでしょうか?
section5_1_pix2pixのPix2Pixクラスの定義の Pix2Pix.train のGのパラメータ更新部のコードの下記部分
with torch.no_grad():部で計算グラフが切れて、GANロスが逆伝搬しない気がしますが 私の認識あっているでしょうか?