PaddlePaddle / PaddleVideo

Awesome video understanding toolkit based on PaddlePaddle. It supports video data annotation tools, lightweight RGB- and skeleton-based action recognition models, and practical applications for video tagging and sports action detection.
Apache License 2.0

how to change entropy loss to focal loss? #241

luoolu opened this issue 2 years ago

luoolu commented 2 years ago

In order to solve the problem of unbalanced categories and difficult samples, how can I replace the cross-entropy loss with focal loss?

luyao-cv commented 5 months ago

To change an existing entropy loss function to Focal Loss, you'll need to modify the cross-entropy formula to incorporate the focal term, which helps in handling class imbalance by reducing the loss contribution from well-classified examples. Here's how:

Focal Loss was introduced in the paper "Focal Loss for Dense Object Detection" by Tsung-Yi Lin et al., and it modifies the cross-entropy loss as follows:

The standard cross-entropy loss for binary classification is given by:

```
loss = -y * log(p) - (1 - y) * log(1 - p)
```

where:

- `y` is the ground-truth label (0 or 1).
- `p` is the predicted probability of the positive class.

For focal loss, we add a modulating factor `(1 - pt)^gamma` to the cross-entropy, where `pt` is the probability assigned to the true class (`p` when `y == 1`, `1 - p` otherwise):

```python
# Pseudocode for a single binary example:
pt = p if y == 1 else 1 - p                              # probability of the true class
alpha_t = alpha_factor if y == 1 else 1 - alpha_factor   # weighting factor for class imbalance
focal_weight = alpha_t * (1 - pt) ** gamma
focal_loss = -focal_weight * (y * log(p) + (1 - y) * log(1 - p))
```
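To see why this helps, here is a quick check with illustrative numbers (plain Python; `gamma = 2`, `alpha` omitted for simplicity):

```python
from math import log

p = 0.9                  # an "easy" positive the model already classifies well
ce = -log(p)             # plain cross-entropy: ~0.105
fl = (1 - p) ** 2 * ce   # focal loss with gamma = 2: ~0.00105
print(ce, fl)            # the easy example now contributes ~100x less
```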

In the pseudocode above,

- `gamma` is the focusing parameter that adjusts the rate at which easy examples are down-weighted.
- `alpha_factor` is a balancing parameter for class imbalance.

For multi-class problems, you can extend this concept to each class:

```python
import torch
import torch.nn.functional as F

def focal_loss(input, target, alpha=0.25, gamma=2.0):
    # Unreduced BCE so each element can be reweighted individually
    BCE_loss = F.binary_cross_entropy_with_logits(input, target, reduction='none')
    pt = torch.exp(-BCE_loss)  # pt = probability assigned to the true class
    F_loss = alpha * (1 - pt) ** gamma * BCE_loss

    return F_loss.mean()

# Assuming input is logits and target is one-hot encoded
input = torch.randn(64, 10)                  # batch_size x number_of_classes
target = torch.randint(0, 10, (64,)).long()  # ground-truth labels
# F.one_hot returns a LongTensor; BCE-with-logits expects float targets
loss = focal_loss(input, F.one_hot(target, num_classes=10).float())
```

This example shows how to implement focal loss for multi-class classification in PyTorch. Note that the `.mean()` call averages the loss over all elements of the batch; summing is also common, so match whatever reduction your training loop expects. Also, adjust `alpha` and `gamma` to your task; the original paper reports `alpha = 0.25`, `gamma = 2` working best for dense detection.
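Since PaddleVideo is built on PaddlePaddle rather than PyTorch, here is the same function sketched against the paddle API (a minimal, illustrative translation; the names and shapes are placeholders, not PaddleVideo's own loss interface):

```python
import paddle
import paddle.nn.functional as F

def focal_loss(logits, target_one_hot, alpha=0.25, gamma=2.0):
    # Unreduced BCE on logits so each element can be reweighted
    bce = F.binary_cross_entropy_with_logits(logits, target_one_hot, reduction='none')
    pt = paddle.exp(-bce)  # probability assigned to the true class
    return (alpha * (1 - pt) ** gamma * bce).mean()

logits = paddle.randn([64, 10])       # batch_size x number_of_classes
labels = paddle.randint(0, 10, [64])  # ground-truth labels
# paddle's F.one_hot already returns a float32 tensor
loss = focal_loss(logits, F.one_hot(labels, num_classes=10))
```

To actually use this in PaddleVideo you would still need to wrap it in a loss class and reference it from your model config, following however custom losses are registered in your version of the repo.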