pankSM opened this issue 4 years ago
I think that self.err_g_lat should appear only in the generator. Please help me correct my understanding if I'm wrong.
```python
def backward_g(self):
    """ Backpropagate netg """
    self.err_g_adv = self.opt.w_adv * self.l_adv(self.pred_fake, self.real_label)
    self.err_g_con = self.opt.w_con * self.l_con(self.fake, self.input)

    self.err_g = self.err_g_adv + self.err_g_con  # + self.err_g_lat
    self.err_g.backward(retain_graph=True)
```
```python
def backward_d(self):
    # Fake
    pred_fake, _ = self.netd(self.fake.detach())
    self.err_d_fake = self.l_adv(pred_fake, self.fake_label)

    # Real
    # pred_real, feat_real = self.netd(self.input)
    self.err_d_real = self.l_adv(self.pred_real, self.real_label)

    # Combine losses.
    self.err_g_lat = self.opt.w_lat * self.l_lat(self.feat_fake, self.feat_real)
    self.err_d = self.err_d_real + self.err_d_fake + self.err_g_lat
    self.err_d.backward(retain_graph=True)
```
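For context on why this can error out: self.feat_fake comes from netd(self.fake), so it is still attached to netg's graph, and if an optimizer step on netg runs between the two backward passes, its weights get modified in place before backward_d revisits that graph. Below is a minimal standalone sketch of that failure mode, using toy nn.Linear stand-ins with hypothetical names (not the repo's code); recent PyTorch versions raise a RuntimeError here:

```python
import torch
import torch.nn as nn

# Toy stand-ins for netg / netd (hypothetical, only to reproduce the failure mode).
netg = nn.Linear(4, 4)
netd = nn.Linear(4, 4)
opt_g = torch.optim.SGD(netg.parameters(), lr=0.1)

x = torch.randn(2, 4, requires_grad=True)
fake = netg(x)
feat_fake = netd(fake)   # still attached to netg's graph through `fake`
feat_real = netd(x)      # attached to netd only

# Generator-style step: backward, then an in-place weight update.
err_g = feat_fake.mean()
err_g.backward(retain_graph=True)
opt_g.step()             # modifies netg's weights in place

# A latent loss built from the non-detached feat_fake revisits netg's graph,
# whose saved weights were just changed in place:
err_d_lat = (feat_fake - feat_real).pow(2).mean()
# err_d_lat.backward()   # RuntimeError: "... modified by an inplace operation"

# Detaching feat_fake cuts the link to netg; gradients flow into netd only.
err_d_lat = (feat_fake.detach() - feat_real).pow(2).mean()
err_d_lat.backward()
```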
In skip-ganomaly, self.err_g_lat is computed in backward_d. If I delete self.err_g_lat, the program runs normally; if I don't change it, it raises an error. Please help me!
It seems that the version of torch or torchvision is wrong.
I add "detach()" to "self.err_g_lat", and it run normaly.
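Presumably that change amounts to detaching the fake features where the latent loss is built in backward_d, so the loss stops reaching back into netg's graph. A sketch of the edited line, assuming the backward_d shown above:

```python
# Latent loss now trains netd only; feat_fake is cut off from netg's graph.
self.err_g_lat = self.opt.w_lat * self.l_lat(self.feat_fake.detach(), self.feat_real)
```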
It works:

```python
with torch.no_grad():
    self.pred_fake, self.feat_fake = self.netd(self.fake)
```
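One caveat: if this is the forward call that also feeds backward_g, wrapping it in torch.no_grad() detaches self.pred_fake as well, so self.err_g_adv no longer propagates gradients into netg and the adversarial term stops training the generator. Detaching only self.feat_fake inside the latent loss, as in the snippet above, avoids that side effect.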