mingyuliutw / UNIT

Unsupervised Image-to-Image Translation

A puzzle about gen_update function #22

Closed ssdutHB closed 6 years ago

ssdutHB commented 6 years ago
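```python
import itertools
import torch
import torch.nn as nn
from torch.autograd import Variable

def gen_update(self, images_a, images_b, hyperparameters):
    self.gen.zero_grad()
    # Reconstructions (x_aa, x_bb), cross-domain translations (x_ba, x_ab), shared latent code
    x_aa, x_ba, x_ab, x_bb, shared = self.gen(images_a, images_b)
    # Translate the translated images back (cycle reconstructions)
    x_bab, shared_bab = self.gen.forward_a2b(x_ba)
    x_aba, shared_aba = self.gen.forward_b2a(x_ab)
    # Discriminator outputs for the fake images (x_ba in domain A, x_ab in domain B)
    outs_a, outs_b = self.dis(x_ba, x_ab)
    # Python 2 code: itertools.izip corresponds to the built-in zip in Python 3
    for it, (out_a, out_b) in enumerate(itertools.izip(outs_a, outs_b)):
        outputs_a = nn.functional.sigmoid(out_a)
        outputs_b = nn.functional.sigmoid(out_b)
        # Target label 1 ("real") for the generated images
        all_ones = Variable(torch.ones((outputs_a.size(0))).cuda(self.gpu))
        if it == 0:
            ad_loss_a = nn.functional.binary_cross_entropy(outputs_a, all_ones)
            ad_loss_b = nn.functional.binary_cross_entropy(outputs_b, all_ones)
        else:
            # Accumulate the adversarial loss over all discriminator outputs
            ad_loss_a += nn.functional.binary_cross_entropy(outputs_a, all_ones)
            ad_loss_b += nn.functional.binary_cross_entropy(outputs_b, all_ones)
    # ... (rest of gen_update not included in this excerpt)
```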

The code above is part of cocogan_trainer.py. I think `all_ones = Variable(torch.ones((outputs_a.size(0))).cuda(self.gpu))` should be `all_zeros = Variable(torch.zeros((outputs_a.size(0))).cuda(self.gpu))`, because this loss is computed when the inputs of the discriminator are fake A and fake B (x_ba and x_ab). Is my understanding right, or am I misunderstanding something?

mingyuliutw commented 6 years ago

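Intuitively, the goal of the generator is to deceive the discriminator. Hence, it should be updated so that the discriminator outputs ones for the generated images. Rather than optimizing the original GAN objective for the generator, i.e. minimizing log(1 - D(G(z))), the release minimizes -log D(G(z)), which is exactly the binary cross-entropy against an all-ones target. Both objectives push D(G(z)) toward 1, but the second one provides much stronger gradients when the discriminator easily rejects the fakes, so training is more stable. This is a common way of implementing the GAN learning algorithm; see Goodfellow et al.'s NIPS 2014 paper for the description.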
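A minimal sketch in plain Python (not code from the UNIT repository) comparing three possible generator losses as functions of the discriminator logit: the "all_zeros" target proposed in the question, the original minimax objective, and the all-ones (non-saturating) target used in gen_update. It assumes a single sigmoid output per fake image and uses a finite-difference slope instead of autograd so it runs without PyTorch; `sigmoid`, `bce`, `zeros_target_gen_loss`, `minimax_gen_loss`, `non_saturating_gen_loss` and `slope` are hypothetical helpers written only for this illustration.

```python
import math


def sigmoid(logit):
    return 1.0 / (1.0 + math.exp(-logit))


def bce(p, target):
    # Per-element binary cross-entropy; nn.functional.binary_cross_entropy
    # averages this quantity over the batch.
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))


def zeros_target_gen_loss(logit):
    # Hypothetical "all_zeros" variant: G minimizes bce(D(G(z)), 0) = -log(1 - D(G(z))).
    return bce(sigmoid(logit), 0)


def minimax_gen_loss(logit):
    # Original GAN objective for G: minimize log(1 - D(G(z))) = -bce(D(G(z)), 0).
    return math.log(1.0 - sigmoid(logit))


def non_saturating_gen_loss(logit):
    # all_ones target as in gen_update: minimize bce(D(G(z)), 1) = -log D(G(z)).
    return bce(sigmoid(logit), 1)


def slope(f, logit, h=1e-5):
    # Central finite difference: d(loss) / d(discriminator logit).
    return (f(logit + h) - f(logit - h)) / (2 * h)


if __name__ == "__main__":
    # A very negative logit means D confidently rejects the fake,
    # which is typical early in training.
    for logit in (-6.0, -3.0, 0.0):
        print(
            f"D(G(z))={sigmoid(logit):.4f}  "
            f"zeros={slope(zeros_target_gen_loss, logit):+.4f}  "
            f"minimax={slope(minimax_gen_loss, logit):+.4f}  "
            f"non-saturating={slope(non_saturating_gen_loss, logit):+.4f}"
        )
```

The slope of the "all_zeros" loss is positive, so gradient descent lowers D(G(z)) and helps the discriminator rather than fooling it; the minimax loss has the right sign, but its slope is about -D(G(z)) and nearly vanishes when D(G(z)) is close to 0, while the slope of the all-ones loss stays close to -1.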

ssdutHB commented 6 years ago

Thank you very much. I get it!