pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

New Feature: Dice Loss #6435

Open oke-aditya opened 2 years ago

oke-aditya commented 2 years ago

🚀 The feature

Followup to #6323

Addition of Dice Loss to torchvision.

Motivation, pitch

Dice loss is used mainly for semantic segmentation.

I want to understand the technical aspects of adding it to torchvision. Are we going to support boolean tensors or outputs from the semantic segmentation models?
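To make the question above concrete, here is a minimal sketch of a binary Dice loss that accepts either raw model outputs (logits) or boolean masks. The function name `binary_dice_loss`, the `from_logits` flag, and the `eps` smoothing term are assumptions made for illustration; none of them is an existing torchvision API.

```python
import torch


def binary_dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    from_logits: bool = True,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Hypothetical helper, not a torchvision API.

    inputs:  (N, H, W) raw logits, or probabilities / boolean masks
             when from_logits=False.
    targets: (N, H, W) boolean or 0/1 ground-truth masks.
    """
    # Logits are squashed to probabilities; masks are only cast to float.
    probs = torch.sigmoid(inputs) if from_logits else inputs.float()
    targets = targets.float()

    # Flatten the spatial dims so each sample yields one Dice score.
    probs = probs.flatten(1)
    targets = targets.flatten(1)

    intersection = (probs * targets).sum(dim=1)
    denominator = probs.sum(dim=1) + targets.sum(dim=1)

    # eps (an assumed smoothing term) avoids 0/0 when both masks are empty.
    dice = (2 * intersection + eps) / (denominator + eps)
    return (1 - dice).mean()
```

With `from_logits=False` the same function covers boolean tensors, so one possible design is a single entry point with a flag rather than two separate losses.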

A few references:

https://github.com/pytorch/pytorch/issues/1249#issuecomment-305088398

https://github.com/rogertrullo/pytorch/blob/rogertrullo-dice_loss/torch/nn/functional.py#L708

MONAI

https://docs.monai.io/en/stable/_modules/monai/losses/dice.html#DiceLoss

https://kornia.readthedocs.io/en/latest/_modules/kornia/losses/dice.html#dice_loss

Alternatives

No response

Additional context

No response

datumbox commented 2 years ago

@oke-aditya Just wanted to follow up on this. Is this a feature you still intend to build? Thanks!

oke-aditya commented 2 years ago

Yes, once I finish Swin Transformer 3D over the weekend. This would be next.

datumbox commented 2 years ago

That's awesome! I didn't know you were prioritizing Swin 3d. Sounds awesome! No rush!

pri1311 commented 1 year ago

Could I give it a try? Seems like a good issue to start with.

oke-aditya commented 1 year ago

Sure. I can help you with this.

@datumbox this was the next one on my plate, and I will help @pri1311 with it.

@pri1311 Feel free to reach out if you need any help, e.g. with setting up your development environment.

pri1311 commented 1 year ago

Thank you @oke-aditya! A question - we are looking to implement multiclass dice loss, correct?

oke-aditya commented 1 year ago

Well, I was originally thinking of dice loss as a binary classification loss for semantic segmentation models, where you want to separate the object from the background.

I think the output of our semantic segmentation models is a logit score for each pixel belonging to a class. In a sense, semantic segmentation models are not multi-class: they distinguish object vs. background. Refer to https://pytorch.org/vision/stable/auto_examples/plot_visualization_utils.html#semantic-segmentation-models

A multi-class case can be handled with a binary classification loss: do one-vs-rest for each class and combine the per-class losses.
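The one-vs-rest idea can be sketched roughly as follows. This is an illustration only: the function name `one_vs_rest_dice_loss`, the use of softmax to obtain per-class probabilities, the `eps` smoothing term, and the unweighted mean used to combine the per-class losses are all assumptions, not an agreed design for torchvision.

```python
import torch


def one_vs_rest_dice_loss(
    logits: torch.Tensor, target: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    """Hypothetical helper, not a torchvision API.

    logits: (N, C, H, W) raw per-class scores.
    target: (N, H, W) integer class indices in [0, C).
    """
    # Assumption: softmax over the class dim gives per-class probabilities.
    probs = logits.softmax(dim=1)

    losses = []
    for c in range(logits.shape[1]):
        # Treat class c as "object" and every other class as "background".
        pred_c = probs[:, c].flatten(1)             # (N, H*W)
        true_c = (target == c).float().flatten(1)   # (N, H*W)

        intersection = (pred_c * true_c).sum(dim=1)
        denominator = pred_c.sum(dim=1) + true_c.sum(dim=1)
        dice = (2 * intersection + eps) / (denominator + eps)
        losses.append(1 - dice)                     # (N,)

    # Combine the per-class binary losses with an unweighted mean.
    return torch.stack(losses, dim=1).mean()
```

Other ways of combining the losses (for example, class weights or skipping the background class) would fit the same structure; the mean is just the simplest choice.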

My reference implementation: https://github.com/oke-aditya/quickvision/blob/master/quickvision/losses/functional/dice_loss.py

Additionally, for a discussion of multi-class dice loss, read https://github.com/keras-team/keras/issues/9395

cc @datumbox, as he has a lot of experience with semantic segmentation models and can give you a definite answer.

pri1311 commented 1 year ago

Ahh, alright. It shouldn't be too hard to add binary dice loss.

The discussion at https://github.com/pytorch/pytorch/issues/1249#issuecomment-305088398 mentions multiclass dice loss, and I think Kornia supports it as well, so I just wanted to confirm.