nv-tlabs / GSCNN

Gated-Shape CNN for Semantic Segmentation (ICCV 2019)
https://nv-tlabs.github.io/GSCNN/

DualTaskLoss not working #72

Open WaterKnight1998 opened 4 years ago

WaterKnight1998 commented 4 years ago
~/Documents/pro1/seg/utils/loss_gscnn.py in forward(self, inputs, targets)
    344         losses['edge_loss'] = self.edge_weight * 20 * self.bce2d(edgein, edgemask)
    345         losses['att_loss'] = self.att_weight * self.edge_attention(segin, segmask, edgein)
--> 346         losses['dual_loss'] = self.dual_weight * self.dual_task(segin, segmask)
    347 
    348         return losses

~/anaconda3/envs/seg/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    556             result = self._slow_forward(*input, **kwargs)
    557         else:
--> 558             result = self.forward(*input, **kwargs)
    559         for hook in self._forward_hooks.values():
    560             hook_result = hook(self, input, result)

~/Documents/pro1/seg/utils/loss_gscnn.py in forward(self, input_logits, gts, ignore_pixel)
    240         input_logits = torch.where(ignore_mask.view(N, 1, H, W).expand(N, 19, H, W),
    241                                    torch.zeros(N,C,H,W).cuda(),
--> 242                                    input_logits)
    243         gt_semantic_masks = gts.detach()
    244         gt_semantic_masks = torch.where(ignore_mask, torch.zeros(N,H,W).long().cuda(), gt_semantic_masks)

RuntimeError: The size of tensor a (19) must match the size of tensor b (2) at non-singleton dimension 1
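In the loss shown in the traceback, the channel count in expand(N, 19, H, W) is a literal 19 (the Cityscapes class count). With two-channel logits, torch.where cannot broadcast the 19-channel mask against the 2-channel zeros and logits. Below is a minimal CPU sketch of the masking step that takes C from the logits instead. It assumes logits of shape (N, C, H, W), labels of shape (N, H, W) and an ignore label of 255. The name mask_ignored_logits is hypothetical and not part of GSCNN.

```python
import torch


def mask_ignored_logits(input_logits: torch.Tensor, gts: torch.Tensor,
                        ignore_pixel: int = 255) -> torch.Tensor:
    """Zero the logits at ignored pixels without hard-coding the class count.

    Hypothetical helper: input_logits is (N, C, H, W), gts is (N, H, W).
    """
    N, C, H, W = input_logits.shape
    ignore_mask = (gts == ignore_pixel).detach()
    # Expand the mask to C channels (read from the logits), not a fixed 19.
    return torch.where(ignore_mask.view(N, 1, H, W).expand(N, C, H, W),
                       torch.zeros_like(input_logits),
                       input_logits)


logits = torch.randn(2, 2, 4, 4)          # binary segmentation: C = 2
gts = torch.randint(0, 2, (2, 4, 4))
gts[0, 0, 0] = 255                        # mark one pixel as ignored
out = mask_ignored_logits(logits, gts)
print(out.shape)                          # torch.Size([2, 2, 4, 4])
print(out[0, :, 0, 0])                    # tensor([0., 0.])
```

torch.zeros_like also keeps the zeros on the same device as the logits, so no explicit .cuda() call is needed.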
WaterKnight1998 commented 4 years ago

I updated the code, but the shapes of the two tensors passed to F.l1_loss still don't match:

N, C, H, W = input_logits.shape
th = 1e-8  # 1e-10
eps = 1e-10
ignore_mask = (gts == ignore_pixel).detach()
input_logits = torch.where(ignore_mask.view(N, 1, H, W).expand(N, C, H, W),
                           torch.zeros(N,C,H,W).cuda(),
                           input_logits)
gt_semantic_masks = gts.detach()
gt_semantic_masks = torch.where(ignore_mask, torch.zeros(N,H,W).long().cuda(), gt_semantic_masks)
gt_semantic_masks = _one_hot_embedding(gt_semantic_masks, 19).detach()

g = _gumbel_softmax_sample(input_logits.view(N, C, -1), tau=0.5)
g = g.reshape((N, C, H, W))
g = compute_grad_mag(g, cuda=self._cuda)

g_hat = compute_grad_mag(gt_semantic_masks, cuda=self._cuda)

g = g.view(N, -1)
g_hat = g_hat.reshape(N, -1)
loss_ewise = F.l1_loss(g, g_hat, reduction='none')

p_plus_g_mask = (g >= th).detach().float()
loss_p_plus_g = torch.sum(loss_ewise * p_plus_g_mask) / (torch.sum(p_plus_g_mask) + eps)

p_plus_g_hat_mask = (g_hat >= th).detach().float()
loss_p_plus_g_hat = torch.sum(loss_ewise * p_plus_g_hat_mask) / (torch.sum(p_plus_g_hat_mask) + eps)

total_loss = 0.5 * loss_p_plus_g + 0.5 * loss_p_plus_g_hat
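The remaining mismatch can be traced with shape arithmetic. The prediction side g is flattened from (N, C, H, W), so it has C * H * W values per sample. The target side g_hat comes from a one-hot map built with a literal 19, so it has 19 * H * W values per sample. The sketch below assumes that _one_hot_embedding behaves like indexing an identity matrix and permuting to channels-first. It also assumes that compute_grad_mag preserves the (N, K, H, W) shape, so that step is left out. one_hot_embedding is a hypothetical CPU stand-in, not the GSCNN function itself.

```python
import torch


def one_hot_embedding(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical stand-in for _one_hot_embedding: (N, H, W) -> (N, K, H, W)."""
    eye = torch.eye(num_classes, device=labels.device)
    return eye[labels].permute(0, 3, 1, 2)


N, C, H, W = 2, 2, 4, 4
labels = torch.randint(0, C, (N, H, W))

g = torch.rand(N, C, H, W).view(N, -1)                    # C * H * W values per sample
g_hat_19 = one_hot_embedding(labels, 19).reshape(N, -1)   # 19 * H * W values per sample
g_hat_c = one_hot_embedding(labels, C).reshape(N, -1)     # C * H * W values per sample

print(g.shape, g_hat_19.shape, g_hat_c.shape)
# torch.Size([2, 32]) torch.Size([2, 304]) torch.Size([2, 32])
```

Only the version built with C channels lines up with g, so only that one can be passed to F.l1_loss elementwise.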
WaterKnight1998 commented 4 years ago

@ayinaaaaaa I am doing binary segmentation!

ShreyasHavaldar7 commented 4 years ago

Hey @WaterKnight1998, I ran into the same issue in a similar setup. One change I believe you should make is to replace the 19 with 2: you are doing binary segmentation, so you have 2 classes. If I'm right, the 19 corresponds to the 19 classes of the Cityscapes dataset. Replacing
gt_semantic_masks = _one_hot_embedding(gt_semantic_masks, 19).detach() with
gt_semantic_masks = _one_hot_embedding(gt_semantic_masks, 2).detach() works.

This should be enough to resolve the error.
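Instead of swapping one literal for another, the class count can be read from the logits (C in N, C, H, W = input_logits.shape). Then the same loss works for 2 classes or 19. Below is a minimal CPU sketch of the one-hot step plus the balanced L1 term from the code above, under that assumption. softmax stands in for the Gumbel-softmax and compute_grad_mag path, which is not reproduced here. one_hot_embedding and balanced_l1 are hypothetical names, not GSCNN functions.

```python
import torch
import torch.nn.functional as F


def one_hot_embedding(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical stand-in for _one_hot_embedding: (N, H, W) -> (N, K, H, W)."""
    return torch.eye(num_classes, device=labels.device)[labels].permute(0, 3, 1, 2)


def balanced_l1(g: torch.Tensor, g_hat: torch.Tensor,
                th: float = 1e-8, eps: float = 1e-10) -> torch.Tensor:
    """Average |g - g_hat| over pixels active in g and over pixels active in g_hat."""
    N = g.shape[0]
    g = g.reshape(N, -1)
    g_hat = g_hat.reshape(N, -1)
    loss_ewise = F.l1_loss(g, g_hat, reduction='none')   # requires equal shapes
    p_plus_g_mask = (g >= th).detach().float()
    loss_p_plus_g = torch.sum(loss_ewise * p_plus_g_mask) / (torch.sum(p_plus_g_mask) + eps)
    p_plus_g_hat_mask = (g_hat >= th).detach().float()
    loss_p_plus_g_hat = torch.sum(loss_ewise * p_plus_g_hat_mask) / (torch.sum(p_plus_g_hat_mask) + eps)
    return 0.5 * loss_p_plus_g + 0.5 * loss_p_plus_g_hat


N, C, H, W = 2, 2, 4, 4
logits = torch.randn(N, C, H, W)
labels = torch.randint(0, C, (N, H, W))
targets = one_hot_embedding(labels, C)    # class count taken from the logits
pred = torch.softmax(logits, dim=1)       # stand-in for the Gumbel-softmax + grad-mag path
loss = balanced_l1(pred, targets)
print(loss.shape)                         # torch.Size([])
print(balanced_l1(targets, targets))      # tensor(0.)
```

Reading C from the logits means the loss follows whatever number of output channels the model head produces. The trade-off is that the labels must then lie in the range 0 to C - 1, apart from the ignore label.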

reacher-l commented 4 years ago

With gt_semantic_masks = _one_hot_embedding(gt_semantic_masks, 2).detach(), where 2 is your number of classes, it passes for me.