znxlwm / UGATIT-pytorch

Official PyTorch implementation of U-GAT-IT: Unsupervised Generative Attentional Networks with Adaptive Layer-Instance Normalization for Image-to-Image Translation

How to understand the CAM loss for the generator including fake_B2A and fake_A2A? #39

Open qiliux7 opened 4 years ago

qiliux7 commented 4 years ago
```python
G_cam_loss_A = self.BCE_loss(fake_B2A_cam_logit, torch.ones_like(fake_B2A_cam_logit).to(self.device)) \
             + self.BCE_loss(fake_A2A_cam_logit, torch.zeros_like(fake_A2A_cam_logit).to(self.device))
G_cam_loss_B = self.BCE_loss(fake_A2B_cam_logit, torch.ones_like(fake_A2B_cam_logit).to(self.device)) \
             + self.BCE_loss(fake_B2B_cam_logit, torch.zeros_like(fake_B2B_cam_logit).to(self.device))
```

For the generator, what is the CAM auxiliary classifier supposed to classify? Should the identity-mapped (real-looking) image, e.g. fake_A2A, be classified as 0, and the translated (fake) image, e.g. fake_B2A, be classified as 1?


For the discriminator, we can see that real images are classified as 1 and fake images as 0. Therefore, I am confused about what the CAM loss is doing for the generator.
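To make the targets concrete, here is a minimal, self-contained sketch of what G_cam_loss_A asks the generator's auxiliary classifier to do. The toy tensors and the `cam_classifier` module are hypothetical (not the repo's actual generator), and I use `BCEWithLogitsLoss` so the snippet runs standalone:

```python
import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()

# Hypothetical auxiliary classifier inside G_B2A: global-average-pool the encoder
# feature map, then one linear layer producing a single "which domain?" logit.
cam_classifier = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(256, 1))

feat_from_B = torch.randn(4, 256, 64, 64)  # encoder features when real_B is fed in (B2A translation path)
feat_from_A = torch.randn(4, 256, 64, 64)  # encoder features when real_A is fed in (A2A identity path)

fake_B2A_cam_logit = cam_classifier(feat_from_B)
fake_A2A_cam_logit = cam_classifier(feat_from_A)

# The generator CAM loss trains this classifier to output 1 = "the input came from
# domain B (source, still needs translating)" and 0 = "the input is already from
# domain A (target)", rather than real-vs-fake as in the discriminator.
G_cam_loss_A = bce(fake_B2A_cam_logit, torch.ones_like(fake_B2A_cam_logit)) \
             + bce(fake_A2A_cam_logit, torch.zeros_like(fake_A2A_cam_logit))
print(G_cam_loss_A.item())
```

So, under this reading, the generator's CAM is a source-vs-target domain classifier on the encoder features, not a real-vs-fake classifier.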

annihi1ation commented 4 years ago

I've read the paper. According to it, it seems like they want to tune the feature maps that feed the decoding part to look more like features of images from the target domain rather than the source domain.
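Here is a simplified sketch of how that works mechanically: the same linear layer that produces the CAM logit also supplies per-channel weights that re-scale the encoder feature map before it reaches the decoder. This is only the GAP branch; the repo's ResnetGenerator, as far as I can tell, also has a GMP branch and a 1x1 conv that I omit here, and the class name is made up for illustration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyCAMAttention(nn.Module):
    """Simplified GAP-only CAM attention in the UGATIT style (illustrative only)."""
    def __init__(self, channels=256):
        super().__init__()
        self.gap_fc = nn.Linear(channels, 1, bias=False)  # auxiliary domain classifier

    def forward(self, x):                        # x: encoder feature map (N, C, H, W)
        gap = F.adaptive_avg_pool2d(x, 1)        # (N, C, 1, 1)
        cam_logit = self.gap_fc(gap.flatten(1))  # "did this input come from the source domain?" logit
        w = self.gap_fc.weight                   # (1, C): channel importances for that decision
        x_att = x * w.unsqueeze(2).unsqueeze(3)  # emphasize channels the classifier relies on
        return x_att, cam_logit                  # x_att feeds the decoder; cam_logit feeds G_cam_loss

feat = torch.randn(2, 256, 64, 64)
x_att, cam_logit = TinyCAMAttention()(feat)
print(x_att.shape, cam_logit.shape)
```

Because the loss above trains `cam_logit` to separate source-domain inputs (target 1) from target-domain inputs (target 0), the channel weights `w` end up emphasizing exactly the features that differ between the two domains, which is what steers the decoder toward the target domain.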