princeton-vl / CornerNet

BSD 3-Clause "New" or "Revised" License

Some question about _neg_loss #63

Open wangxiaodong1021 opened 5 years ago

wangxiaodong1021 commented 5 years ago

Hello, I have some confusion about the loss function (_neg_loss).

    import torch

    def _neg_loss(preds, gt):
        # Masks for positive (exactly 1) and negative (< 1) heatmap locations
        pos_inds = gt.eq(1)
        neg_inds = gt.lt(1)

        # gt holds an unnormalized Gaussian around each corner, so negative
        # locations near a corner get a penalty weight below 1
        neg_weights = torch.pow(1 - gt[neg_inds], 4)

        loss = 0
        for pred in preds:
            pos_pred = pred[pos_inds]
            neg_pred = pred[neg_inds]

            pos_loss = torch.log(pos_pred) * torch.pow(1 - pos_pred, 2)
            neg_loss = torch.log(1 - neg_pred) * torch.pow(neg_pred, 2) * neg_weights

            num_pos  = pos_inds.float().sum()
            pos_loss = pos_loss.sum()
            neg_loss = neg_loss.sum()

            if pos_pred.nelement() == 0:
                loss = loss - neg_loss
            else:
                loss = loss - (pos_loss + neg_loss) / num_pos
        return loss

I found that neg_weights is always 1, so neg_loss = torch.log(1 - neg_pred) * torch.pow(neg_pred, 2) * neg_weights is equivalent to neg_loss = torch.log(1 - neg_pred) * torch.pow(neg_pred, 2).

Is my understanding correct?

heilaw commented 5 years ago

The elements in neg_weights are not always 1. The weights at locations close to the corners are less than 1, while the rest are 1. Since there are usually only a few objects in an image, most of the elements in neg_weights are 1. When you print neg_weights, PyTorch only shows a small part of the tensor, and those elements are very likely to be 1.
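You can check this directly with a small sketch. The heatmap below is a hypothetical 8x8 grid with a single corner at (3, 3) and an unnormalized Gaussian placed around it (the shape and sigma here are illustrative, not CornerNet's exact radius computation):

```python
import torch

# Hypothetical 8x8 ground-truth heatmap with one corner at (3, 3).
# CornerNet places an unnormalized 2D Gaussian around each corner, so
# locations near a corner get values in (0, 1) rather than exactly 0.
size, cx, cy, sigma = 8, 3, 3, 1.0
ys, xs = torch.meshgrid(torch.arange(size), torch.arange(size), indexing="ij")
gt = torch.exp(-((xs - cx) ** 2 + (ys - cy) ** 2).float() / (2 * sigma ** 2))
gt[cy, cx] = 1.0  # the corner itself is the only positive

neg_inds = gt.lt(1)
neg_weights = torch.pow(1 - gt[neg_inds], 4)

# Weights right next to the corner are well below 1 ...
print(neg_weights.min().item())
# ... but most of the map is far from the corner, so most weights are near 1,
# which is what a truncated tensor printout tends to show.
print((neg_weights > 0.9).float().mean().item())
```

Printing the full tensor (e.g. after `torch.set_printoptions(profile="full")`) makes the sub-1 band around each corner visible.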