HeyLynne / FocalLoss_for_multiclass

Focal loss for multiple class classification

Some modifications to focal loss #4

Open Kunlei-Hong opened 1 year ago

Kunlei-Hong commented 1 year ago

import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, gamma=2, alpha=1.0, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.size_average = size_average
        self.elipson = 1e-6

    def forward(self, logits, labels, num_classes):
        # One-hot mask selecting the true class of each sample
        label_onehot = F.one_hot(labels, num_classes=num_classes)
        log_p = F.log_softmax(logits, dim=-1)
        # Log-probability of the true class (negated cross-entropy), offset by epsilon
        ce_loss = (log_p * label_onehot).sum(1) + self.elipson
        p = F.softmax(logits, dim=1)
        # Probability of the true class, offset by epsilon
        pt = (label_onehot * p).sum(1) + self.elipson
        sub_pt = 1 - pt
        # Focal loss: -alpha * (1 - pt)^gamma * log(pt)
        l = -self.alpha * (sub_pt**self.gamma) * ce_loss
        if self.size_average:
            return l.mean()
        else:
            return l.sum()

It seems like it should look like this.
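One quick way to test a change like this is to check that the loss reduces to plain cross-entropy when gamma=0 and alpha=1, since the (1 - pt)^gamma factor then equals 1. Below is a minimal sketch of that check. The standalone `focal_loss` function and its `reduction` argument are hypothetical names used for illustration only, not part of this repository. It is also a variant rather than the code above: it reads the true-class log-probability with `gather` instead of a one-hot mask, and it drops the `elipson` offsets on the assumption that `log_softmax` is already numerically stable.

```python
import torch
import torch.nn.functional as F


def focal_loss(logits, labels, gamma=2.0, alpha=1.0, reduction="mean"):
    """Hypothetical helper: -alpha * (1 - p_t)**gamma * log(p_t) per sample."""
    log_p = F.log_softmax(logits, dim=-1)
    # Pick out the log-probability of the true class for each sample.
    log_pt = log_p.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    # log_pt <= 0, so pt stays within (0, 1] and 1 - pt is never negative.
    pt = log_pt.exp()
    loss = -alpha * (1.0 - pt) ** gamma * log_pt
    return loss.mean() if reduction == "mean" else loss.sum()


torch.manual_seed(0)
logits = torch.randn(8, 5)          # 8 samples, 5 classes
labels = torch.randint(0, 5, (8,))  # integer class indices

# With gamma=0 and alpha=1 the focal loss should match cross-entropy.
fl0 = focal_loss(logits, labels, gamma=0.0)
ce = F.cross_entropy(logits, labels)
print(torch.allclose(fl0, ce, atol=1e-6))
```

With gamma > 0 every per-sample term is scaled by a factor of at most 1, so the mean focal loss should never exceed the mean cross-entropy on the same batch.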

HeyLynne commented 1 year ago

Thx I'll test it