Closed: Jack-zz-ze closed this issue 2 years ago.
Other implementations just use a forward function, but you implement both forward and backward. Can you explain it?
Have you compared the outputs of the two implementations?
```python
def forward(ctx, logits, label, alpha, gamma):
    # logits = logits.float()
    probs = torch.sigmoid(logits)                    # p = sigmoid(x)
    coeff = (label - probs).abs().pow(gamma).neg_()  # -|label - p|^gamma
    # log(p), split into two numerically stable branches
    log_probs = torch.where(logits >= 0,
                            F.softplus(logits, -1, 50),
                            logits - F.softplus(logits, 1, 50))
    # log(1 - p), likewise
    log_1_probs = torch.where(logits >= 0,
                              -logits + F.softplus(logits, -1, 50),
                              -F.softplus(logits, 1, 50))
    ce_term1 = log_probs.mul(label).mul_(alpha)              # alpha * label * log(p)
    ce_term2 = log_1_probs.mul(1. - label).mul_(1. - alpha)  # (1-alpha) * (1-label) * log(1-p)
    ce = ce_term1.add(ce_term2)  # <= 0; coeff's minus sign makes the loss positive
    loss = ce * coeff

    ctx.vars = (coeff, probs, ce, label, gamma, alpha)
    return loss
```
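As a concrete way to compare the outputs of the two implementations, here is a minimal self-contained check (a sketch, not code from the repo) pitting the stable softplus form above against a naive `-alpha_t * (1 - pt)^gamma * log(pt)`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8) * 10  # include fairly large magnitudes
label = torch.randint(0, 2, (8,)).float()
alpha, gamma = 0.25, 2.0

# stable softplus form, as in the forward above
probs = torch.sigmoid(logits)
coeff = (label - probs).abs().pow(gamma).neg()
log_probs = torch.where(logits >= 0,
                        F.softplus(logits, -1, 50),
                        logits - F.softplus(logits, 1, 50))
log_1_probs = torch.where(logits >= 0,
                          -logits + F.softplus(logits, -1, 50),
                          -F.softplus(logits, 1, 50))
ce = alpha * label * log_probs + (1. - alpha) * (1. - label) * log_1_probs
loss_stable = ce * coeff

# naive form: -alpha_t * (1 - pt)^gamma * log(pt)
pt = probs * label + (1. - probs) * (1. - label)
alpha_t = alpha * label + (1. - alpha) * (1. - label)
loss_naive = -alpha_t * (1. - pt).pow(gamma) * pt.log()

print(torch.allclose(loss_stable, loss_naive, atol=1e-6))  # expected: True
```

On moderate logits the two agree to float precision; the softplus/where form only differs in that it stays finite when |logits| is very large, where the naive `pt.log()` can hit log(0).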
I can't see which term pt is. Could you write out the formula you want to achieve at each step?
If you think my focal loss is wrong, please post example code showing the difference from a correct implementation, and I will look into where the problem is in my code.
1. focal_loss = -α·(1 - pt)^γ·log(pt), where pt = p for label = 1 and pt = 1 - p for label = 0
2. Your code:
```python
log_probs = torch.where(logits >= 0,
                        F.softplus(logits, -1, 50),
                        logits - F.softplus(logits, 1, 50))
```
3. Do you mean to compute log(pt)? But F.softplus(x, β) is (1/β)·log(1 + e^(βx)).
log(pt) = label·log(p) + (1 - label)·log(1 - p),
while log_probs = -log(1 + e^(-x)) (for x ≥ 0) or x - log(1 + e^(x)) (for x < 0), which looks different.
I have the same question. Could you please explain this line?
```python
log_probs = torch.where(logits >= 0,
                        F.softplus(logits, -1, 50),
                        logits - F.softplus(logits, 1, 50))
```
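For what it's worth, both branches of that where are the same function, log(sigmoid(x)) = log(p), just written in two forms so neither overflows: for x ≥ 0, F.softplus(x, -1) = -log(1 + e^(-x)) = log σ(x), and for x < 0, x - F.softplus(x, 1) = x - log(1 + e^x) = log σ(x) as well. The label is not applied at this point; it enters afterwards through ce_term1 and ce_term2. A quick sketch to verify, using F.logsigmoid as the reference:

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-80., 80., 321)
log_p = torch.where(x >= 0,
                    F.softplus(x, -1, 50),
                    x - F.softplus(x, 1, 50))
# log(sigmoid(x)) computed by PyTorch's own stable routine
print(torch.allclose(log_p, F.logsigmoid(x)))  # expected: True
```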
Just work through the formula derivations: these implementations are exactly the same, and this version is more numerically stable. (Awesome work!) @noobgrow @Jack-zz-ze
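To spell that out (a sketch of the derivation, with p = sigmoid(logits)):

- log_probs = log(p) and log_1_probs = log(1 - p); each where only splits the same function into two numerically stable branches.
- ce = α·label·log(p) + (1 - α)·(1 - label)·log(1 - p)
- coeff = -|label - p|^γ, i.e. -(1 - p)^γ when label = 1 and -p^γ when label = 0.

So loss = ce·coeff is -α·(1 - p)^γ·log(p) for label = 1 and -(1 - α)·p^γ·log(1 - p) for label = 0, which is exactly focal_loss = -αt·(1 - pt)^γ·log(pt) with pt = p, αt = α for positives and pt = 1 - p, αt = 1 - α for negatives.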
Closing this because it is not active anymore.
```python
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce
```
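The forward for this class got cut off above. A minimal sketch of a forward consistent with this __init__ signature (my reconstruction of the commonly circulated version, not taken verbatim from the comment; the __init__ is repeated so the block runs standalone):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        # per-element binary cross-entropy; `logits` selects whether
        # inputs are raw scores or probabilities
        if self.logits:
            bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        else:
            bce = F.binary_cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-bce)  # bce = -log(pt), so pt = exp(-bce)
        focal = self.alpha * (1. - pt) ** self.gamma * bce
        return torch.mean(focal) if self.reduce else focal
```

Note that this variant multiplies every sample by the same scalar alpha, whereas the softplus version above uses alpha for positives and 1 - alpha for negatives, so the two only produce identical numbers once that weighting difference is accounted for.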