joe1chief closed this issue 5 years ago
def hard_cross_entropy(output, target, alpha=3.0):
    """Cross-entropy with online hard negative mining (OHEM-style).

    Class index 0 is treated as background. Every positive (non-background)
    sample's loss is kept; only the ``alpha * Np`` hardest (highest-loss)
    negatives are kept, where ``Np`` is the number of positives.

    Args:
        output: class logits, shape ``(N, C)`` (anything ``F.cross_entropy``
            accepts together with ``target``).
        target: ground-truth class indices; 0 marks background/negative.
        alpha: ratio of mined negatives to positives (default 3.0).

    Returns:
        Scalar tensor: sum of kept losses divided by the number of kept
        samples. Returns a zero tensor for an all-background batch with no
        positives (instead of a silent NaN/inf).
    """
    # Per-sample losses. `reduce=False` is deprecated; `reduction='none'`
    # is the supported equivalent.
    mtx = F.cross_entropy(output, target, reduction='none')
    bg = (target == 0)  # background mask
    neg = mtx[bg]
    # BUG FIX: the original `mtx[1 - bg]` raises on bool masks in modern
    # PyTorch — logical NOT is the correct mask inversion.
    pos = mtx[~bg]
    Np, Nn = pos.numel(), neg.numel()
    # Keep at most alpha*Np hardest negatives. Cast to int once so the
    # number of elements summed matches the number divided by below.
    k = int(min(Np * alpha, Nn))
    pos_sum = pos.sum()
    if k > 0:
        hard_neg, _ = torch.topk(neg, k)
        neg_sum = hard_neg.sum()
    else:
        neg_sum = 0.0
    # Guard: no positives and no mined negatives -> avoid division by zero.
    if Np + k == 0:
        return mtx.sum() * 0.0
    return (pos_sum + neg_sum) / (Np + k)
Please refer to our paper for more details.