ramsever opened this issue 2 years ago
There is a bug in the focal loss function. The following line is incorrect:
pt = label_onehot * log_p
It should be:
pt = label_onehot * p
where p is the probability, not the log-probability. Please confirm.
I think you are right! When I use the log-probability I get a NaN loss, but when I use the probability (p = F.softmax(logits, dim=-1)), my loss looks normal!
Thanks. Will fix it.
pt = label_onehot * torch.exp(log_p)
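For reference, here is a minimal sketch of the corrected focal loss with this fix applied, assuming the usual multi-class setup. The names focal_loss, logits, labels, and gamma are illustrative, not the repository's actual API:

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, labels, gamma=2.0):
    """Multi-class focal loss sketch (hypothetical signature, not the repo's API)."""
    log_p = F.log_softmax(logits, dim=-1)                      # log-probabilities, shape (N, C)
    label_onehot = F.one_hot(labels, logits.size(-1)).float()  # shape (N, C)
    # The fix: pt must be a probability, so exponentiate the log-probability
    pt = (label_onehot * torch.exp(log_p)).sum(dim=-1)         # p of the true class, shape (N,)
    log_pt = (label_onehot * log_p).sum(dim=-1)                # log p of the true class, shape (N,)
    return (-(1.0 - pt) ** gamma * log_pt).mean()

# Quick sanity check with random inputs
logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
print(focal_loss(logits, labels))  # a positive scalar
```

Using torch.exp(log_p) rather than recomputing F.softmax reuses the already-computed log-softmax output, so the probabilities and log-probabilities stay consistent with each other.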