pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

Add pos_weight in Focal Loss to trade off recall and precision #6229

Open MrShevan opened 2 years ago

MrShevan commented 2 years ago

🚀 The feature

I suggest adding a pos_weight argument (an extra weight applied to positive examples) to the Focal Loss implementation here. F.binary_cross_entropy_with_logits already has a pos_weight argument, so adding this feature only requires creating a new parameter and passing it through:

ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none", pos_weight=pos_weight)

With the addition of this feature it will be possible to trade off recall and precision by adding weights to positive examples.
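To make the proposal concrete, here is a minimal sketch of the change, following the structure of torchvision's current sigmoid_focal_loss; the function name sigmoid_focal_loss_with_pos_weight and the default values are illustrative, not a final API:

```python
import torch
import torch.nn.functional as F
from typing import Optional


def sigmoid_focal_loss_with_pos_weight(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2.0,
    reduction: str = "none",
    pos_weight: Optional[torch.Tensor] = None,  # proposed new parameter
) -> torch.Tensor:
    # pos_weight rescales only the positive (targets == 1) term of the BCE.
    ce_loss = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction="none", pos_weight=pos_weight
    )
    p = torch.sigmoid(inputs)
    # p_t is the predicted probability of the true class for each element.
    p_t = p * targets + (1 - p) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()
    return loss
```

Because pos_weight is simply forwarded to F.binary_cross_entropy_with_logits, it rescales only the positive-class term; the focal modulating factor $(1-p_t)^{\gamma}$ and the $\alpha$ weighting are unchanged.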

Motivation, pitch

It’s possible to trade off recall and precision by adding weights to positive examples. For example, if a dataset contains $100$ positive and $300$ negative examples of a single class, then pos_weight for the class should be equal to $\frac{300}{100}=3$. The loss would act as if the dataset contains $3*100 = 300$ positive examples.
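As a small illustration of this arithmetic (the variable names here are hypothetical):

```python
import torch

# Counts from the example above: 100 positive, 300 negative examples.
num_pos, num_neg = 100, 300

# One weight per class; here a single binary class.
pos_weight = torch.tensor([num_neg / num_pos])  # tensor([3.])
```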

This idea is already used in the PyTorch BCEWithLogitsLoss implementation (see its description).

For Binary Focal Loss this improvement will mean the following: $$FL_c = ( l_{1,c},\ \dots,\ l_{N,c} )^T$$

$$ l_{n,c} = -\left[\, p_c\, y_{n,c}\, \alpha\, (1-\sigma(x_{n,c}))^{\gamma}\, \log(\sigma(x_{n,c})) + (1 - y_{n,c})\, (1-\alpha)\, \sigma(x_{n,c})^{\gamma}\, \log(1-\sigma(x_{n,c})) \,\right] $$

where $c$ is the class number ($c > 1$ for multi-label binary classification, $c=1$ for single-label binary classification), $n$ is the number of the sample in the batch and $p_c$ is the weight of the positive answer (pos_weight) for the class $c$.

Setting $p_c > 1$ increases recall, while $p_c < 1$ increases precision.
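As a quick numeric sanity check (assuming the sigmoid_focal_loss_with_pos_weight sketch above), the closed-form $l_{n,c}$ matches the implementation element-wise:

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, 4)                      # logits: batch of 8, 4 binary labels
y = torch.randint(0, 2, (8, 4)).float()    # binary targets
alpha, gamma = 0.25, 2.0
p_c = torch.tensor([3.0, 1.0, 0.5, 2.0])   # per-class pos_weight

sig = torch.sigmoid(x)
# Element-wise closed form l_{n,c} from the formula above.
closed_form = -(
    p_c * y * alpha * (1 - sig) ** gamma * torch.log(sig)
    + (1 - y) * (1 - alpha) * sig ** gamma * torch.log(1 - sig)
)

loss = sigmoid_focal_loss_with_pos_weight(x, y, alpha, gamma, pos_weight=p_c)
assert torch.allclose(closed_form, loss, atol=1e-5)
```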

Alternatives

No response

Additional context

Ready to code it :)

datumbox commented 2 years ago

@MrShevan Thanks for the proposal.

Isn't this already factored in by the alpha parameter of the focal_loss? Our code implements it exactly as introduced in the RetinaNet paper, and I'm not sure this proposal aligns with the original usage.

@fmassa if you have thoughts here I would love to hear them. Thanks!