CoinCheung / pytorch-loss

label-smooth, amsoftmax, partial-fc, focal-loss, triplet-loss, lovasz-softmax. Maybe useful

BCEWithLogitsLoss already combines a Sigmoid layer and BCELoss in one single class, so why use torch.sigmoid again? #10

Closed: sunpeng981712364 closed this issue 3 years ago

sunpeng981712364 commented 4 years ago

version 1: use torch.autograd

import torch
import torch.nn as nn


class FocalLossV1(nn.Module):

    def __init__(self,
                 alpha=0.25,
                 gamma=2,
                 reduction='mean'):
        super(FocalLossV1, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        # reduction='none' so the per-element loss can be re-weighted below
        self.crit = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, logits, label):
        '''
        args:
            logits: tensor of shape (N, ...), raw scores (no sigmoid applied)
            label: tensor of shape (N, ...), binary targets in {0, 1}
        '''
        # compute the loss in fp32 even if the logits arrive as fp16
        logits = logits.float()
        with torch.no_grad():
            # alpha_t: self.alpha for positives, 1 - self.alpha for negatives
            alpha = torch.empty_like(logits).fill_(1 - self.alpha)
            alpha[label == 1] = self.alpha

        # the torch.sigmoid call this issue asks about: it only feeds the
        # focal modulating factor, not the cross-entropy term itself
        probs = torch.sigmoid(logits)
        pt = torch.where(label == 1, probs, 1 - probs)
        # BCEWithLogitsLoss computes -log(pt) stably from the raw logits;
        # the target is cast to match the fp32 logits
        ce_loss = self.crit(logits, label.float())
        loss = alpha * torch.pow(1 - pt, self.gamma) * ce_loss
        if self.reduction == 'mean':
            loss = loss.mean()
        if self.reduction == 'sum':
            loss = loss.sum()
        return loss
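
A minimal usage sketch for the class above; the tensor shapes are illustrative, not from the original post:

import torch

criteria = FocalLossV1(alpha=0.25, gamma=2, reduction='mean')
logits = torch.randn(8, 19, 384, 384, requires_grad=True)   # raw scores, no sigmoid applied
labels = torch.randint(0, 2, (8, 19, 384, 384)).float()     # binary targets in {0, 1}
loss = criteria(logits, labels)
loss.backward()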

class torch.nn.BCEWithLogitsLoss(weight: Optional[torch.Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean', pos_weight: Optional[torch.Tensor] = None)

This loss combines a Sigmoid layer and the BCELoss in one single class. This version is more numerically stable than using a plain Sigmoid followed by a BCELoss as, by combining the operations into one layer, we take advantage of the log-sum-exp trick for numerical stability.
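
A small sketch of the stability point in that docs quote. The hand-rolled BCE line is only for illustration (it is not code from this repo), and the extreme logit value is chosen to force fp32 underflow:

import torch
import torch.nn as nn

logits = torch.tensor([-200.0])   # extreme raw score
target = torch.tensor([1.0])

# naive route: sigmoid underflows to exactly 0.0 in fp32, so the log blows up
p = torch.sigmoid(logits)                                           # tensor([0.])
naive = -(target * torch.log(p) + (1 - target) * torch.log(1 - p))  # tensor([inf])

# fused route: BCEWithLogitsLoss works from the logits directly, so the
# result stays finite (and equals the true value, 200)
fused = nn.BCEWithLogitsLoss(reduction='none')(logits, target)      # tensor([200.])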

BCEWithLogitsLoss already combines a Sigmoid layer and BCELoss in one single class, so why call torch.sigmoid again? Is anything wrong here? Thanks.

CoinCheung commented 4 years ago

Because this is focal loss, which adds a modulating coefficient to the standard BCE loss. That coefficient is built from the sigmoid probability of the input, so we still need to compute it explicitly. Please refer to the focal loss paper for the details.
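
For reference, the focal loss from that paper is

    FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t),  where  p_t = sigmoid(x) for y = 1 and p_t = 1 - sigmoid(x) for y = 0.

In the code above, self.crit(logits, ...) produces the -log(p_t) term directly from the raw logits, which is the numerically sensitive part. The separate torch.sigmoid(logits) call is only used to build the (1 - p_t)^gamma modulating factor (alpha is built from the labels); those coefficients sit outside the log, so computing them from the explicit sigmoid probability is safe.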