dsgiitr / d2l-pytorch

This project reproduces the book Dive Into Deep Learning (https://d2l.ai/), adapting the code from MXNet into PyTorch.
Apache License 2.0

Questions on the Focal Loss computation in the SSD Chapter #117

Open aurotripathy opened 3 years ago

aurotripathy commented 3 years ago

Two questions:

  1. alpha_t is computed in forward() below but never used in the loss.
  2. Should the following be used instead to derive the binary cross-entropy term (BCECrossEntropyLoss)?

    pt = torch.log(p) * target.float() + torch.log(1.0 - p) * (1 - target).float()

import torch
from torch import nn


class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, device="cuda:0", eps=1e-10):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.device = device
        self.eps = eps

    def forward(self, input, target):
        p = torch.sigmoid(input)
        pt = p * target.float() + (1.0 - p) * (1 - target).float()
        alpha_t = (1.0 - self.alpha) * target.float() + self.alpha * (1 - target).float()
        loss = - 1.0 * torch.pow((1 - pt), self.gamma) * torch.log(pt + self.eps)
        return loss.sum()

aurotripathy commented 3 years ago

Maintainers, could you kindly respond or provide a hint?
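
To make both questions concrete, here is a minimal sketch of how alpha_t might enter the loss. It is not the repository's implementation. The class name FocalLossSketch and the sample tensors are hypothetical. The sketch assumes the common convention in which positives are weighted by alpha and negatives by 1 - alpha; the forward() above has these two swapped, and which one was intended is not confirmed here. On the second question: for 0/1 targets, torch.log(p) * target + torch.log(1 - p) * (1 - target) equals log(pt), so its negation is the per-element binary cross-entropy, which the sketch obtains from F.binary_cross_entropy_with_logits instead of adding eps by hand.

```python
import torch
import torch.nn.functional as F
from torch import nn


class FocalLossSketch(nn.Module):
    """Hypothetical variant of the FocalLoss above with alpha_t applied."""

    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha  # None disables the class-balancing weight
        self.gamma = gamma

    def forward(self, logits, target):
        target = target.float()
        # Per-element binary cross-entropy; for 0/1 targets this is -log(pt).
        ce = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
        p = torch.sigmoid(logits)
        # pt is the predicted probability of the true class.
        pt = p * target + (1.0 - p) * (1.0 - target)
        # The modulating factor down-weights well-classified examples.
        loss = (1.0 - pt) ** self.gamma * ce
        if self.alpha is not None:
            # Assumed convention: alpha for positives, 1 - alpha for negatives.
            alpha_t = self.alpha * target + (1.0 - self.alpha) * (1.0 - target)
            loss = alpha_t * loss
        return loss.sum()


# Made-up inputs, only for illustration.
logits = torch.tensor([2.0, -1.0, 0.5])
target = torch.tensor([1, 0, 1])

plain_bce = F.binary_cross_entropy_with_logits(logits, target.float(), reduction="sum")
no_focus = FocalLossSketch(alpha=None, gamma=0.0)(logits, target)
focal = FocalLossSketch()(logits, target)
```

With gamma=0 and alpha=None the sketch reduces to the summed binary cross-entropy, which is an easy sanity check to run before turning the focusing and balancing terms back on.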