qubvel-org / segmentation_models.pytorch

Semantic segmentation models with 500+ pretrained convolutional and transformer-based backbones.
https://smp.readthedocs.io/
MIT License

How to specify class weights for loss function? #881

Open · tanmay4269 opened this issue 3 months ago

tanmay4269 commented 3 months ago

How can I set class weights for different loss functions, such as Jaccard loss?

qubvel commented 3 months ago

At the moment this is not supported out of the box, but you can do it by computing the loss for each class separately and then taking a weighted average of the per-class losses.
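Written out, the weighted average looks like this. The notation is introduced here for illustration: $C$ is the number of classes, $w_c$ the weight of class $c$, and $\mathcal{L}_c$ the loss computed on channel $c$ alone.

$$
\mathcal{L} = \frac{\sum_{c=1}^{C} w_c \, \mathcal{L}_c}{\sum_{c=1}^{C} w_c}
$$

Dropping the denominator (a plain weighted sum) only rescales the loss by a constant factor.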

Here is a sketch:

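import torch
import segmentation_models_pytorch as smp

n_classes = 7
class_weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])  # one weight per class

# Dummy data: predicted logits and binary ground-truth masks, shape (B, C, H, W)
predicted_tensor = torch.ones([5, n_classes, 512, 512])
gt_tensor = torch.ones([5, n_classes, 512, 512])

# "binary" mode: each single-channel slice is treated as its own binary mask
criterion = smp.losses.JaccardLoss("binary")

loss = 0.0
for i in range(n_classes):
    class_prediction = predicted_tensor[:, i].unsqueeze(1)  # B, 1, H, W
    class_gt = gt_tensor[:, i].unsqueeze(1)  # B, 1, H, W
    class_weight = class_weights[i]
    class_loss = criterion(class_prediction, class_gt)
    loss += class_loss * class_weight

# Normalize so the result is a weighted average over classes
loss = loss / class_weights.sum()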
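Since the question mentions different loss functions, the loop above can be wrapped into a small reusable module. This is a minimal sketch, not part of smp's API: the class name `ClassWeightedLoss` is hypothetical, and it assumes the wrapped criterion accepts `(B, 1, H, W)` predictions and targets and returns a scalar (as `torch.nn.BCEWithLogitsLoss` with its default reduction does).

```python
import torch
from torch import nn


class ClassWeightedLoss(nn.Module):
    """Hypothetical wrapper: applies a binary criterion to each class channel
    separately and returns the weighted average of the per-class losses."""

    def __init__(self, criterion: nn.Module, class_weights):
        super().__init__()
        self.criterion = criterion
        # Buffer, so the weights follow module.to(device) calls
        self.register_buffer(
            "class_weights", torch.as_tensor(class_weights, dtype=torch.float32)
        )

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        # y_pred, y_true: (B, C, H, W); each channel is an independent binary problem
        n_classes = y_pred.shape[1]
        if n_classes != self.class_weights.numel():
            raise ValueError("expected exactly one weight per class channel")
        # Slicing with i : i + 1 keeps the channel dim -> (B, 1, H, W)
        per_class = torch.stack([
            self.criterion(y_pred[:, i : i + 1], y_true[:, i : i + 1])
            for i in range(n_classes)
        ])
        # Weighted average over classes
        return (per_class * self.class_weights).sum() / self.class_weights.sum()
```

With smp installed, usage might look like `ClassWeightedLoss(smp.losses.JaccardLoss("binary"), [1, 1, 2, 1, 1, 1, 1])(predicted_tensor, gt_tensor)`, and the same wrapper accepts other binary-mode criteria such as `smp.losses.DiceLoss("binary")`.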