MrShevan opened this issue 2 years ago
@MrShevan Thanks for the proposal.
Isn't this already factored in by the `alpha` parameter of the focal loss? Our code implements it exactly as introduced in the RetinaNet paper, and I'm not sure this proposal aligns with the original usage.
@fmassa if you have thoughts here I would love to hear them. Thanks!
🚀 The feature
I suggest adding a `pos_weight` argument (extra weight for positive examples) to the Focal Loss implementation here. `F.binary_cross_entropy_with_logits` already has a `pos_weight` argument, so adding this feature only requires introducing a new parameter and passing it through to that call (see the sketch below). With this addition it becomes possible to trade off recall and precision by adding weights to positive examples.
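A minimal sketch of what the change could look like, modeled on the structure of torchvision's `sigmoid_focal_loss`; the function name `sigmoid_focal_loss_with_pos_weight` is only used here for illustration, and the exact signature is an assumption rather than the final API:

```python
import torch
import torch.nn.functional as F
from typing import Optional


def sigmoid_focal_loss_with_pos_weight(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2.0,
    reduction: str = "none",
    pos_weight: Optional[torch.Tensor] = None,  # proposed new argument, one weight per class
) -> torch.Tensor:
    # Sigmoid focal loss with pos_weight forwarded to the underlying
    # binary cross entropy, mirroring torch.nn.BCEWithLogitsLoss.
    p = torch.sigmoid(inputs)
    ce_loss = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction="none", pos_weight=pos_weight
    )
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    return loss
```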
Motivation, pitch
It’s possible to trade off recall and precision by adding weights to positive examples. For example, if a dataset contains $100$ positive and $300$ negative examples of a single class, then `pos_weight` for that class should be equal to $\frac{300}{100}=3$. The loss would then act as if the dataset contained $3 \cdot 100 = 300$ positive examples. This idea is already used in the PyTorch `BCEWithLogitsLoss` implementation (description).
For Binary Focal Loss this improvement means the following:

$$FL_c = (l_{1,c}, \dots, l_{N,c})^T$$

$$l_{n,c} = -\left[ p_c\, y_{n,c}\, \alpha\, (1-\sigma(x_{n,c}))^{\gamma}\, \log(\sigma(x_{n,c})) + (1 - y_{n,c})\, (1-\alpha)\, \sigma(x_{n,c})^{\gamma}\, \log(1-\sigma(x_{n,c})) \right]$$
where $c$ is the class index ($c > 1$ for multi-label binary classification, $c = 1$ for single-label binary classification), $n$ is the index of the sample in the batch, and $p_c$ is the weight of the positive answer (`pos_weight`) for class $c$. Setting $p_c > 1$ increases the recall, while $p_c < 1$ increases the precision.
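For reference, a small usage sketch with the hypothetical function above, using the $100$ positive / $300$ negative example from the motivation:

```python
import torch

# 100 positive and 300 negative samples for a single class,
# so pos_weight = 300 / 100 = 3 (one entry per class).
num_pos, num_neg = 100, 300
pos_weight = torch.tensor([num_neg / num_pos])

logits = torch.randn(8, 1)                     # (batch, num_classes)
targets = torch.randint(0, 2, (8, 1)).float()  # binary targets

loss = sigmoid_focal_loss_with_pos_weight(
    logits, targets, alpha=0.25, gamma=2.0,
    reduction="mean", pos_weight=pos_weight,
)
```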
Alternatives
No response
Additional context
Ready to code it :)