researchmm / AOT-GAN-for-Inpainting

[TVCG'2023] AOT-GAN for High-Resolution Image Inpainting (codebase for image inpainting)
https://arxiv.org/abs/2104.01431
Apache License 2.0

Confusion about the my_layer_norm function and the GAN loss function #2

Open SeeU1119 opened 3 years ago

SeeU1119 commented 3 years ago

Hi, thank you for your excellent work; I get good results when running the demo. But while reading the source code, I ran into some questions about the GAN loss function. In the paper, the discriminator loss on the fake image should be self.loss_fn(d_fake, gauss(1 - mask)), but the code just uses gauss(mask). Is there something wrong with my understanding? What's more, the discriminator loss on the real image should be self.loss_fn(d_real, d_real_label) where d_real_label is torch.ones(...), but the code sets it to torch.zeros(...). By the way, could you explain what my_layer_norm does in the AOT block? Thanks.
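For context, here is a minimal sketch of the two label conventions being compared. loss_fn, gauss, and the tensors are illustrative stand-ins, not the repo's exact code:

    import torch

    # Paper convention (as the question reads it): real patches -> 1; the
    # fake image's label is a blurred "realness" map, ~1 outside holes and
    # ~0 inside them (masks == 1 inside holes).
    d_real_loss = loss_fn(d_real, torch.ones_like(d_real))
    d_fake_loss = loss_fn(d_fake, gauss(1 - masks))

    # Code convention (flipped): real patches -> 0; the fake label is the
    # blurred mask itself, ~1 inside holes. The flip is harmless for the
    # discriminator on its own, but the generator's target must then be 0
    # ("real"), not 1, which is exactly what the later comments address.
    d_real_loss = loss_fn(d_real, torch.zeros_like(d_real))
    d_fake_loss = loss_fn(d_fake, gauss(masks))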

zjlinkin commented 3 years ago

I have the same confusion about the GAN loss part... the GAN loss in the code does not seem to perform adversarial training.

ewrfcas commented 3 years ago

The same question

964728623 commented 5 months ago

g_fake_label = torch.ones_like(g_fake).cuda() is wrong for a GAN; g_fake_label = torch.zeros_like(g_fake).cuda() is right. With the ones_like label, the generator is trained to match the discriminator's "fake" target instead of fooling it. The author papers over the bug with parser.add_argument('--adv_weight', type=float, default=0.01, help='loss weight for adversarial loss'), so the adversarial loss is effectively unused when training netG.
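(To make the weighting point concrete: with the quoted default, the adversarial term barely moves the generator objective. The sketch below is illustrative; only adv_weight = 0.01 comes from the quoted argparse line, and rec_loss stands in for the other reconstruction terms.)

    # Hypothetical generator objective: the mis-labeled adversarial term is
    # damped by adv_weight = 0.01, so training mostly follows rec_loss.
    total_gen_loss = rec_loss + 0.01 * gen_loss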

964728623 commented 5 months ago
    # Fake labels: blurred mask (~1 inside holes); real labels: all zeros.
    d_fake_label = gaussian_blur(masks, (self.ksize, self.ksize), (10, 10)).detach().cuda()
    d_real_label = torch.zeros_like(d_real).cuda()
    # g_fake_label = torch.ones_like(g_fake).cuda()  # original (wrong): generator matches the "fake" target
    g_fake_label = torch.zeros_like(g_fake).cuda()   # fixed: generator pushes hole pixels toward the "real" label (0)

    dis_loss = self.loss_fn(d_fake[masks > 0.5], d_fake_label[masks > 0.5]) + self.loss_fn(d_real, d_real_label)
    gen_loss = self.loss_fn(g_fake[masks > 0.5], g_fake_label[masks > 0.5])
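For completeness, a self-contained sketch of the corrected loss under the code's flipped convention (real -> 0, hole pixels of the fake -> ~1). This is an assumption-laden rewrite, not the repo's exact SMGAN class: gaussian_blur is taken from torchvision rather than the repo's helper, MSE stands in for self.loss_fn, and a mask resize is added so the indexing works when the discriminator output is smaller than the mask.

    import torch
    import torch.nn.functional as F
    from torchvision.transforms.functional import gaussian_blur

    def smgan_loss(netD, fake, real, masks, ksize=71):
        # Discriminator sees the detached fake; generator gradients flow
        # through a separate forward pass.
        d_fake = netD(fake.detach())
        d_real = netD(real)
        g_fake = netD(fake)

        # Match the mask to the discriminator's output resolution.
        if masks.shape[-2:] != d_fake.shape[-2:]:
            masks = F.interpolate(masks, size=d_fake.shape[-2:], mode="nearest")

        # Flipped convention: real -> 0, fake hole pixels -> blurred mask (~1).
        d_fake_label = gaussian_blur(masks, [ksize, ksize], [10.0, 10.0]).detach()
        d_real_label = torch.zeros_like(d_real)
        g_fake_label = torch.zeros_like(g_fake)  # generator targets the "real" label

        dis_loss = F.mse_loss(d_fake[masks > 0.5], d_fake_label[masks > 0.5]) \
                 + F.mse_loss(d_real, d_real_label)
        gen_loss = F.mse_loss(g_fake[masks > 0.5], g_fake_label[masks > 0.5])
        return dis_loss, gen_loss

Under this convention the discriminator still learns to flag hole regions, while the generator now pushes those regions toward the "real" label, restoring the adversarial game.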