CoinCheung / pytorch-loss

label-smooth, amsoftmax, partial-fc, focal-loss, triplet-loss, lovasz-softmax. Maybe useful
MIT License

Is your focal loss wrong? It seems a little different from other implementations; can you explain your code? #27

Closed Jack-zz-ze closed 2 years ago

Jack-zz-ze commented 3 years ago

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduce=False)
        else:
            BCE_loss = F.binary_cross_entropy(inputs, targets, reduce=False)
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss

        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss
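
For reference: this version applies a single scalar alpha to every sample, while the formulation in the original focal-loss paper (and the code discussed below) weights positive targets with α and negative targets with 1 − α. A minimal sketch of that per-class weighting (the name alpha_t is just illustrative):

import torch

alpha = 0.25
targets = torch.tensor([1., 0., 1., 0.])
# alpha for positive targets, (1 - alpha) for negative targets
alpha_t = alpha * targets + (1. - alpha) * (1. - targets)
print(alpha_t)  # tensor([0.2500, 0.7500, 0.2500, 0.7500])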
Jack-zz-ze commented 3 years ago

They just use a forward function, while you use both forward and backward. Can you explain it?

CoinCheung commented 3 years ago

Have you compared the outputs of the two implementations?

Jack-zz-ze commented 3 years ago

    def forward(ctx, logits, label, alpha, gamma):
        # logits = logits.float()

        probs = torch.sigmoid(logits)
        coeff = (label - probs).abs().pow(gamma).neg_()
        log_probs = torch.where(logits >= 0,
                F.softplus(logits, -1, 50),
                logits - F.softplus(logits, 1, 50))
        log_1_probs = torch.where(logits >= 0,
                -logits + F.softplus(logits, -1, 50),
                -F.softplus(logits, 1, 50))
        ce_term1 = log_probs.mul(label).mul_(alpha)
        ce_term2 = log_1_probs.mul(1. - label).mul_(1. - alpha)
        ce = ce_term1.add(ce_term2)
        loss = ce * coeff

        ctx.vars = (coeff, probs, ce, label, gamma, alpha)

        return loss

I don't see where pt is computed. Can you write the formula you intend to implement at each step?


CoinCheung commented 3 years ago

If you find my focal loss is wrong, please post example code showing the difference from a correct implementation, and I will see where the problem is in my code.
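
For reference, a minimal comparison sketch (the function names focal_bce and focal_softplus are just illustrative). It re-implements the forward pass quoted above as a plain function and compares it with the BCE-based version from the first comment, with one adjustment: α is applied per class (α for positives, 1 − α for negatives) so that both use the same α convention.

import torch
import torch.nn.functional as F

def focal_bce(logits, targets, alpha=0.25, gamma=2):
    # BCE-based version, with alpha applied per class to match the other convention
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    pt = torch.exp(-bce)
    alpha_t = alpha * targets + (1. - alpha) * (1. - targets)
    return alpha_t * (1. - pt) ** gamma * bce

def focal_softplus(logits, label, alpha=0.25, gamma=2):
    # the forward pass quoted above, written as a plain function
    probs = torch.sigmoid(logits)
    coeff = (label - probs).abs().pow(gamma).neg_()
    log_probs = torch.where(logits >= 0,
            F.softplus(logits, -1, 50),
            logits - F.softplus(logits, 1, 50))
    log_1_probs = torch.where(logits >= 0,
            -logits + F.softplus(logits, -1, 50),
            -F.softplus(logits, 1, 50))
    ce = log_probs.mul(label).mul_(alpha) + log_1_probs.mul(1. - label).mul_(1. - alpha)
    return ce * coeff

logits = torch.randn(8, 4) * 5
targets = (torch.rand(8, 4) > 0.5).float()
# should print True: once the alpha convention is aligned, the two compute the same per-element loss
print(torch.allclose(focal_bce(logits, targets), focal_softplus(logits, targets), atol=1e-6))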

Jack-zz-ze commented 3 years ago

1. focal_loss = -α(1-pt)^γ · log(pt)
2. your code: log_probs = torch.where(logits >= 0, F.softplus(logits, -1, 50), logits - F.softplus(logits, 1, 50))
3. Do you want to compute log(pt)? But F.softplus is (1/β)·log(1 + e^(βx)).
   log(pt) = label·log(p) + (1-label)·log(1-p), but log_probs = -x·log(1 + e^(-x)) + (1-x)·[x - log(1 + e^(x))] is different.

noobgrow commented 2 years ago

I have the same question. Could you please explain this line? log_probs = torch.where(logits >= 0, F.softplus(logits, -1, 50), logits - F.softplus(logits, 1, 50))

AwalkZY commented 2 years ago

Just work through the formula derivation: these implementations are exactly the same, and this version is more numerically stable. (Awesome work!) @noobgrow @Jack-zz-ze
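
For reference, a minimal sketch of that derivation plus a numerical check (the variable names are only illustrative). With β = -1, F.softplus(x, -1, 50) = -log(1 + e^(-x)) = log(σ(x)), and for x < 0, x - F.softplus(x, 1, 50) = x - log(1 + e^x) = log(e^x / (1 + e^x)) = log(σ(x)). Both branches compute log(σ(x)); torch.where just picks whichever is numerically stable for the sign of x. The same reasoning gives log(1 - σ(x)) for log_1_probs, so ce is exactly label·log(p) + (1-label)·log(1-p) weighted by α, i.e. log(pt) up to the α weighting.

import torch
import torch.nn.functional as F

logits = torch.linspace(-15., 15., steps=101, dtype=torch.float64)

# numerically stable log(sigmoid(x)), as in the quoted forward pass
log_probs = torch.where(logits >= 0,
        F.softplus(logits, -1, 50),
        logits - F.softplus(logits, 1, 50))
# numerically stable log(1 - sigmoid(x))
log_1_probs = torch.where(logits >= 0,
        -logits + F.softplus(logits, -1, 50),
        -F.softplus(logits, 1, 50))

# compare against the naive formulas (fine in float64 over this moderate range;
# for large |x| in float32 the naive log(1 - sigmoid(x)) can hit -inf)
print(torch.allclose(log_probs, torch.log(torch.sigmoid(logits))))         # expected: True
print(torch.allclose(log_1_probs, torch.log(1. - torch.sigmoid(logits))))  # expected: True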

CoinCheung commented 2 years ago

Closing this because it is not active anymore.