WaterKnight1998 opened 4 years ago:
I updated the code; however, the tensors passed to F.l1_loss are not matching:
N, C, H, W = input_logits.shape
th = 1e-8 # 1e-10
eps = 1e-10
ignore_mask = (gts == ignore_pixel).detach()
input_logits = torch.where(ignore_mask.view(N, 1, H, W).expand(N, C, H, W),
torch.zeros(N,C,H,W).cuda(),
input_logits)
gt_semantic_masks = gts.detach()
gt_semantic_masks = torch.where(ignore_mask, torch.zeros(N,H,W).long().cuda(), gt_semantic_masks)
gt_semantic_masks = _one_hot_embedding(gt_semantic_masks, 19).detach()
g = _gumbel_softmax_sample(input_logits.view(N, C, -1), tau=0.5)
g = g.reshape((N, C, H, W))
g = compute_grad_mag(g, cuda=self._cuda)
g_hat = compute_grad_mag(gt_semantic_masks, cuda=self._cuda)
g = g.view(N, -1)
g_hat = g_hat.reshape(N, -1)
loss_ewise = F.l1_loss(g, g_hat, reduction='none')  # 'reduce' is deprecated; 'reduction' alone is enough
p_plus_g_mask = (g >= th).detach().float()
loss_p_plus_g = torch.sum(loss_ewise * p_plus_g_mask) / (torch.sum(p_plus_g_mask) + eps)
p_plus_g_hat_mask = (g_hat >= th).detach().float()
loss_p_plus_g_hat = torch.sum(loss_ewise * p_plus_g_hat_mask) / (torch.sum(p_plus_g_hat_mask) + eps)
total_loss = 0.5 * loss_p_plus_g + 0.5 * loss_p_plus_g_hat
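The snippet depends on the helper _one_hot_embedding, which must produce a tensor whose channel count matches the logits. A minimal CPU-only sketch of such a helper (the name and exact layout are assumptions, not the repo's implementation) makes the class-count dependency explicit:

```python
import torch

def one_hot_embedding(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical stand-in for _one_hot_embedding: maps integer class
    labels of shape (N, H, W) to a float one-hot tensor (N, C, H, W)."""
    # each label index selects its row from the identity matrix
    eye = torch.eye(num_classes, device=labels.device)
    # eye[labels] has shape (N, H, W, C); permute to (N, C, H, W)
    # so it lines up with the logits layout
    return eye[labels].permute(0, 3, 1, 2).contiguous()

# toy binary-segmentation ground truth: 1 image, 2x2 pixels, classes {0, 1}
gts = torch.tensor([[[0, 1], [1, 0]]])
onehot = one_hot_embedding(gts, num_classes=2)
print(onehot.shape)  # torch.Size([1, 2, 2, 2])
```

Whatever the real helper looks like, its `num_classes` argument fixes the channel dimension C, so it has to agree with the C of `input_logits`.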
@ayinaaaaaa I am doing binary segmentation!
Hey @WaterKnight1998,
I was facing the same issue in a similar situation. One change I believe you should make is to replace the 19 with 2, since you are dealing with binary segmentation and thus two classes. If I assume correctly, the 19 corresponds to the 19 classes of the Cityscapes dataset. Replacing
gt_semantic_masks = _one_hot_embedding(gt_semantic_masks, 19).detach()
with
gt_semantic_masks = _one_hot_embedding(gt_semantic_masks, 2).detach()
works.
This should be enough to resolve the error.
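To see why the wrong class count breaks F.l1_loss, here is a toy reproduction (the tensor names mirror the snippet above, but the random values are just placeholders): with a 19-class one-hot against 2-channel logits, the flattened tensors disagree in size and cannot be broadcast, while 2 classes line up cleanly.

```python
import torch
import torch.nn.functional as F

# toy shapes for binary segmentation: C = 2 logit channels
N, H, W = 1, 4, 4

# g derives from the 2-channel prediction, so flattening gives N x (2*H*W)
g = torch.rand(N, 2 * H * W)

# if _one_hot_embedding is still called with 19 classes, g_hat flattens
# to N x (19*H*W) and F.l1_loss cannot align the two tensors
g_hat_19 = torch.rand(N, 19 * H * W)
try:
    F.l1_loss(g, g_hat_19, reduction='none')
    mismatch = False
except RuntimeError:
    mismatch = True
print('19-class one-hot mismatches:', mismatch)  # True

# with 2 classes the shapes agree and the elementwise loss is well-defined
g_hat_2 = torch.rand(N, 2 * H * W)
loss_ewise = F.l1_loss(g, g_hat_2, reduction='none')
print(loss_ewise.shape)  # torch.Size([1, 32])
```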
gt_semantic_masks = _one_hot_embedding(gt_semantic_masks, 2).detach() — here the 2 should be your number of classes; I used this and it works for me.