qubvel-org / segmentation_models.pytorch

Semantic segmentation models with 500+ pretrained convolutional and transformer-based backbones.
https://smp.readthedocs.io/
MIT License

Dice Loss resulting in unexpected logit outputs #900

Open trchudley opened 1 month ago

trchudley commented 1 month ago

Hi, thanks for such a great package!

I'm currently training a simple U-Net on a binary segmentation problem (in this case, identifying water in satellite imagery). I am exploring different segmentation_models_pytorch loss functions to train the model, and have been using a focal loss function to initially build and test it:

loss_function = smp.losses.FocalLoss(mode="binary", alpha=alpha, gamma=gamma)
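For context, the loss is wired into a fairly standard training step, roughly as below (the encoder, alpha/gamma values, and optimizer settings here are illustrative placeholders rather than my exact configuration):

import torch
import segmentation_models_pytorch as smp

# illustrative placeholders - the actual encoder and hyperparameters differ
model = smp.Unet(encoder_name="resnet34", in_channels=3, classes=1)
loss_function = smp.losses.FocalLoss(mode="binary", alpha=0.25, gamma=2.0)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

def training_step(images, masks):
    # images: (N, 3, H, W) floats; masks: (N, 1, H, W) with values in {0, 1}
    logits = model(images)               # raw, unbounded logits, shape (N, 1, H, W)
    loss = loss_function(logits, masks)  # the loss consumes raw logits directly
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()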

After training, the model produces relatively sensible output both as raw logit values and as probability values once a sigmoid activation layer is applied:

[screenshot: focal-loss-trained model output - raw logits and sigmoid probabilities]

(NB: this is just a test run, the final accuracy doesn't really matter at this point)
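For reference, the probability map above is just a sigmoid over the raw logits at inference time - something along these lines, where image_batch stands in for a preprocessed input batch:

import torch

model.eval()
with torch.no_grad():
    logits = model(image_batch)         # roughly centred on zero for the focal-loss model
    probs = torch.sigmoid(logits)       # probabilities in (0, 1)
    water_mask = (probs > 0.5).float()  # hard binary mask at a 0.5 threshold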

I have been looking to explore and switch to dice loss. My understanding is that, by setting the from_logits argument, I could simply swap in the DiceLoss class as a drop-in replacement for FocalLoss, as follows:

loss_function = smp.losses.DiceLoss(mode="binary", from_logits=True)
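Since both loss objects are called with the same raw logits and masks, I expected a straight swap - something like this (shapes are just for illustration):

import torch
import segmentation_models_pytorch as smp

focal = smp.losses.FocalLoss(mode="binary")
dice = smp.losses.DiceLoss(mode="binary", from_logits=True)

logits = torch.randn(4, 1, 256, 256)                   # raw model outputs
masks = torch.randint(0, 2, (4, 1, 256, 256)).float()  # binary ground truth

print(focal(logits, masks), dice(logits, masks))       # both consume the same logits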

Training the model using this DiceLoss class results in the following when applied to the same image:

[screenshot: dice-loss-trained model output on the same image - raw logits and sigmoid probabilities]

Looking at the logit output, this is great - the new dice-loss-trained model appears to perform even better qualitatively than the focal-loss-trained model! However, the raw outputs are no longer centred around zero. Instead, they are all positive, ranging from roughly 400 to 9000 (depending on the image the model is applied to). As a result, applying a sigmoid activation no longer produces a sensible probability distribution between zero and one - because the logits are all large and positive, the apparent probabilities are all 1.
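For reference, logits of that magnitude saturate the sigmoid completely, so the collapse to probability 1 follows directly from the values themselves:

import torch

logits = torch.tensor([-5.0, 0.0, 5.0, 400.0, 9000.0])
print(torch.sigmoid(logits))
# approximately [0.0067, 0.5000, 0.9933, 1.0000, 1.0000] - in float32, any logit
# beyond a few tens is already exactly 1.0, so an all-positive range of ~400 to
# ~9000 maps every pixel to probability 1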

I've examined the source code and I can't see anything that would explain such a difference. Am I missing something that prevents DiceLoss from being a drop-in replacement for FocalLoss when the goal is probabilistic model predictions?

qubvel commented 1 month ago

Hi @trchudley, it should be a drop-in replacement. If from_logits=True is specified with binary mode, a sigmoid function is applied inside the loss. It's an interesting observation - you can see in the example notebook for binary segmentation that the DiceLoss function works fine and the predicted logits are distributed around 0.
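In simplified form, the binary branch with from_logits=True does roughly the following (a paraphrased sketch, not the exact source - see the DiceLoss implementation for the real thing):

import torch
import torch.nn.functional as F

def binary_soft_dice_from_logits(logits, targets, smooth=0.0, eps=1e-7):
    # from_logits=True means a (numerically stable) sigmoid is applied inside
    # the loss, so it always operates on probabilities even when fed raw logits
    probs = F.logsigmoid(logits).exp()  # equivalent to sigmoid(logits)
    probs = probs.reshape(probs.size(0), -1)
    targets = targets.reshape(targets.size(0), -1).type_as(probs)
    intersection = (probs * targets).sum(dim=1)
    cardinality = (probs + targets).sum(dim=1)
    dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps)
    return 1.0 - dice_score.mean()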

trchudley commented 1 month ago

Thanks @qubvel. Yes, my understanding from looking at the code was that FocalLoss accepts logit inputs, and that DiceLoss also accepts logits when from_logits is set to True, so both should accept the same (logit) input.

Thanks for linking the notebook - I will have a detailed look and see whether there's anything I might be missing. I'll get back to you...