HeyLynne / FocalLoss_for_multiclass

Focal loss for multiple class classification

Potential BUG #3

Open ramsever opened 2 years ago

ramsever commented 2 years ago

There is a bug in the focal loss function.

The following line is incorrect:

pt = label_onehot * log_p

it should be:

pt = label_onehot * p

where p is the probability, not the log probability.

Please confirm.
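A quick numerical check supports this report: multiplying the one-hot labels by log probabilities yields a negative "pt", which cannot be a probability, while exponentiating first gives a value in (0, 1). The logits below are illustrative, not from the repo:

```python
import torch
import torch.nn.functional as F

# Illustrative logits and a one-hot label for class 0 (assumed values).
logits = torch.tensor([[2.0, 0.5, -1.0]])
label_onehot = F.one_hot(torch.tensor([0]), num_classes=3).float()

log_p = F.log_softmax(logits, dim=-1)

# The line as reported: this is a (negative) log-probability, not pt.
pt_wrong = (label_onehot * log_p).sum(dim=-1)

# The proposed fix: exponentiate to recover the actual probability.
pt_right = (label_onehot * log_p.exp()).sum(dim=-1)

print(pt_wrong.item())   # negative
print(pt_right.item())   # between 0 and 1
```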

FT115 commented 2 years ago

I think you are right! When I use the log probability I get a NaN loss, but when I use the probability (p = F.softmax(logits)), my loss looks normal!

HeyLynne commented 1 year ago

Thanks. Will fix it.

aliwaqas333 commented 1 year ago

pt = label_onehot * torch.exp(log_p)
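Putting the thread's fix together, a minimal sketch of a multi-class focal loss using `torch.exp(log_p)` could look like the following. The function name, signature, and `gamma` default are assumptions for illustration, not the repo's actual code:

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, labels, gamma=2.0, num_classes=3):
    """Sketch of multi-class focal loss with the fix discussed above:
    pt must be a probability, so exponentiate the log-probabilities
    rather than using them directly."""
    log_p = F.log_softmax(logits, dim=-1)                 # log-probabilities
    label_onehot = F.one_hot(labels, num_classes).float()
    pt = (label_onehot * torch.exp(log_p)).sum(dim=-1)    # p_t in (0, 1)
    log_pt = (label_onehot * log_p).sum(dim=-1)           # log p_t
    loss = -((1.0 - pt) ** gamma) * log_pt                # focal weighting
    return loss.mean()
```

As a sanity check, setting `gamma=0` removes the modulating factor and the result should match ordinary cross-entropy with mean reduction.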