Closed: mehran66 closed this issue 2 years ago.
There are two focal loss classes in this module, BinaryFocalLoss and FocalLoss: https://github.com/BloodAxe/pytorch-toolbelt/blob/develop/pytorch_toolbelt/losses/focal.py
Both classes call the focal_loss_with_logits function, but the second one (FocalLoss) should call softmax_focal_loss_with_logits instead.
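To make the distinction concrete, here is a minimal stdlib-only sketch of the two formulations, using the standard focal loss FL = -(1 - p_t)^gamma * log(p_t). This is not pytorch-toolbelt's code. The helpers `sigmoid_focal_loss`, `softmax_focal_loss` and `one_vs_all_focal_loss` are hypothetical names chosen for illustration. In the sigmoid form, p_t comes from an independent sigmoid per logit. In the softmax form, p_t is the softmax probability of the target class, so all classes compete.

```python
import math
from typing import Sequence


def _log_sigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x)).
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def sigmoid_focal_loss(logit: float, target: int, gamma: float = 2.0) -> float:
    """Hypothetical binary focal loss for one logit (not the library API)."""
    # log p_t is log(sigmoid(z)) for target 1 and log(1 - sigmoid(z)) = log(sigmoid(-z)) for target 0.
    log_pt = _log_sigmoid(logit) if target == 1 else _log_sigmoid(-logit)
    pt = math.exp(log_pt)
    return -((1.0 - pt) ** gamma) * log_pt


def softmax_focal_loss(logits: Sequence[float], target: int, gamma: float = 2.0) -> float:
    """Hypothetical multi-class focal loss: p_t is the softmax probability of the target class."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))  # log-sum-exp
    log_pt = logits[target] - log_z
    pt = math.exp(log_pt)
    return -((1.0 - pt) ** gamma) * log_pt


def one_vs_all_focal_loss(logits: Sequence[float], target: int, gamma: float = 2.0) -> float:
    """Hypothetical: apply the binary (sigmoid) focal loss to every class against a one-hot target."""
    return sum(
        sigmoid_focal_loss(z, 1 if c == target else 0, gamma) for c, z in enumerate(logits)
    )
```

For two classes with logits [0, z], softmax gives the target class the same probability as sigmoid(z), so the two forms coincide. With three or more classes they diverge, which is why reusing the sigmoid-based function for a multi-class loss changes the result.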
Thanks for the tip; it is fixed in the 0.5.3 release.