isaaccorley / torchseg

Segmentation models with pretrained backbones. PyTorch.
MIT License
104 stars, 8 forks

Can I use a mixed loss? #6

Closed: PeterKim1 closed this issue 9 months ago

PeterKim1 commented 9 months ago

Hello. Thanks for your great work.

I want to use a mixed loss (for example, Dice loss + focal loss + etc.), but segmentation_models_pytorch doesn't support this.

So I think this feature needs to be added.

Could you add a mixed loss function?

Thanks.

isaaccorley commented 9 months ago

Hi @PeterKim1, there are many possible combinations of the loss functions in this library, and the right one seems specific to each user. I definitely tend to use a CE + Jaccard loss in my projects!
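For reference, a CE + Jaccard combination like the one mentioned here can be sketched in plain PyTorch. This is a minimal sketch, not torchseg's implementation: `CEJaccardLoss` is a hypothetical name, the soft Jaccard term is hand-written rather than torchseg's `JaccardLoss`, and it assumes multiclass logits of shape (N, C, H, W) with integer class targets of shape (N, H, W). The 0.5/0.5 weights are arbitrary defaults.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CEJaccardLoss(nn.Module):
    """Hypothetical weighted sum of cross-entropy and a soft Jaccard (IoU) loss.

    Assumes logits of shape (N, C, H, W) and integer targets of shape (N, H, W).
    """

    def __init__(self, ce_weight: float = 0.5, jaccard_weight: float = 0.5, eps: float = 1e-7):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()
        self.ce_weight = ce_weight
        self.jaccard_weight = jaccard_weight
        self.eps = eps  # avoids division by zero for empty classes

    def soft_jaccard(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        num_classes = logits.shape[1]
        probs = logits.softmax(dim=1)
        # One-hot encode targets from (N, H, W) to (N, C, H, W) to match probs.
        one_hot = F.one_hot(targets, num_classes).permute(0, 3, 1, 2).to(probs.dtype)
        dims = (0, 2, 3)  # sum over batch and spatial dims, keep one value per class
        intersection = (probs * one_hot).sum(dims)
        union = probs.sum(dims) + one_hot.sum(dims) - intersection
        iou = (intersection + self.eps) / (union + self.eps)
        return 1.0 - iou.mean()  # average soft IoU over classes, turned into a loss

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        ce = self.ce(logits, targets)
        jaccard = self.soft_jaccard(logits, targets)
        return self.ce_weight * ce + self.jaccard_weight * jaccard
```

Cross-entropy gives smooth per-pixel gradients early in training, while the Jaccard term directly targets the overlap metric that segmentation is usually evaluated on, which is one common reason to pair them.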

However, I don't think it would be good to add since it would be more code to maintain.

Creating a multi-objective loss isn't much effort, though; see the example below. Hope this helps!

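import torch.nn as nn
from torchseg.losses import DiceLoss, FocalLoss

class FocalDiceLoss(nn.Module):
    # Constructor arguments are illustrative; pass whatever your task needs.
    def __init__(self, mode="multiclass", alpha=0.5, beta=0.5):
        super().__init__()  # required before assigning submodules
        self.focal_loss = FocalLoss(mode)
        self.dice_loss = DiceLoss(mode)
        self.alpha = alpha  # weight of the focal term
        self.beta = beta    # weight of the Dice term

    def forward(self, preds, targets):
        focal_loss = self.alpha * self.focal_loss(preds, targets)
        dice_loss = self.beta * self.dice_loss(preds, targets)
        return focal_loss + dice_loss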
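The same pattern extends to any number of terms (e.g. Dice + focal + something else). Below is a minimal sketch, assuming each loss is a module called as `loss_fn(preds, targets)`; `WeightedSumLoss` is a hypothetical helper, not part of torchseg.

```python
from typing import Sequence

import torch
import torch.nn as nn


class WeightedSumLoss(nn.Module):
    """Hypothetical helper: returns sum_i weights[i] * losses[i](preds, targets)."""

    def __init__(self, losses: Sequence[nn.Module], weights: Sequence[float]):
        super().__init__()
        if len(losses) != len(weights):
            raise ValueError("losses and weights must have the same length")
        # ModuleList registers each loss so .to(device) and .train()/.eval() reach them.
        self.losses = nn.ModuleList(losses)
        self.weights = list(weights)

    def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Weighted sum of every term; the result stays in the autograd graph.
        return sum(w * loss_fn(preds, targets) for w, loss_fn in zip(self.weights, self.losses))
```

With torchseg this might look like `WeightedSumLoss([DiceLoss("multiclass"), FocalLoss("multiclass")], [0.5, 0.5])`, assuming those constructors take the mode as their first argument as in segmentation_models_pytorch. Keeping the weights outside the loss modules makes it easy to tune or schedule them without touching the individual losses.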
PeterKim1 commented 9 months ago

@isaaccorley Thanks for your kind reply!

isaaccorley commented 9 months ago

@PeterKim1 Of course!