keras-team / keras-cv

Industry-strength Computer Vision workflows with Keras

DICE loss #296

Closed: innat closed this issue 1 year ago

innat commented 2 years ago

Dice loss is one of the most widely used loss functions in semantic segmentation. Unfortunately, it is still not available in keras.losses.*. It has been requested many times: #3611, #13085, #9395, #10890.

(Not sure whether it fits better here or in keras.losses.*.)
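For readers new to the loss: the soft Dice loss is usually written as 1 - (2 * sum(p * g) + eps) / (sum(p) + sum(g) + eps), where p are predicted probabilities, g is the ground-truth mask, and eps is a small smoothing constant. The NumPy sketch below is only an illustration of that formula; the function name and the eps value are assumptions, not taken from any library.

```python
import numpy as np


def soft_dice_loss(y_true, y_pred, eps=1e-6):
    """Soft Dice loss: 1 - (2*sum(p*g) + eps) / (sum(p) + sum(g) + eps)."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    intersection = np.sum(y_true * y_pred)
    denominator = np.sum(y_true) + np.sum(y_pred)
    # eps keeps the ratio defined (and equal to 1) when both masks are empty.
    dice = (2.0 * intersection + eps) / (denominator + eps)
    return 1.0 - dice


mask = np.array([[1, 1], [0, 0]])
print(soft_dice_loss(mask, mask))                       # perfect overlap, loss ~0.0
print(soft_dice_loss(mask, 1 - mask))                   # no overlap, loss close to 1.0
print(soft_dice_loss(mask, np.array([[1, 0], [0, 0]])))  # partial overlap, loss ~0.33
```

Because the score depends only on overlap relative to mask size, it is less sensitive to foreground/background imbalance than plain pixel-wise cross-entropy, which is a common reason it is used for segmentation.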

bhack commented 2 years ago

https://github.com/tensorflow/addons/pull/2558#issuecomment-941220370

innat commented 2 years ago

Okay, someone has already pushed a PR there. I think this function should be available here, since it's widely used. cc @gamenerd457

gamenerd457 commented 2 years ago

Interested

bhack commented 2 years ago

@LukeWood If you want this, please assign it to @gamenerd457 so he can port his PR here.

innat commented 2 years ago

@LukeWood @qlzh727 Could you please give some feedback on this and also on keras-team/keras-cv#341?

qlzh727 commented 2 years ago

Yes, I think we can host this here (adding it to the Keras repo would require a more extensive API review). We should put it in the keras_cv/losses folder.

bhack commented 2 years ago

@gamenerd457 Green light for a PR here.

LukeWood commented 2 years ago

Yeah, seems like a great fit!

gamenerd457 commented 2 years ago

So should I make a PR here?

LukeWood commented 2 years ago

So should I make a PR here?

If the loss easily conforms to the Keras loss API (y_true, y_pred), then yes!
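To illustrate what conforming to the Keras loss API means in practice: any callable that takes (y_true, y_pred) and returns a loss tensor can be passed directly to model.compile. This is a minimal sketch assuming TensorFlow 2.x with tf.keras and channels-last masks; dice_loss, its smooth value, and the toy model are hypothetical names for illustration only.

```python
import tensorflow as tf


def dice_loss(y_true, y_pred, smooth=1e-6):
    # Per-sample soft Dice loss for channels-last masks of shape (B, H, W, C).
    y_true = tf.cast(y_true, y_pred.dtype)
    axes = [1, 2, 3]  # reduce over H, W, C and keep the batch axis
    intersection = tf.reduce_sum(y_true * y_pred, axis=axes)
    denominator = tf.reduce_sum(y_true, axis=axes) + tf.reduce_sum(y_pred, axis=axes)
    return 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)


# Any (y_true, y_pred) callable can be handed to compile() as the loss.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(8, 8, 3)),
    tf.keras.layers.Conv2D(1, 3, padding="same", activation="sigmoid"),
])
model.compile(optimizer="adam", loss=dice_loss)
```

Returning one value per sample (rather than a single scalar) lets Keras apply its own reduction and any sample_weight passed to fit().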

bhack commented 2 years ago

@LukeWood The intention was to port/refactor @gamenerd457's Addons PR: https://github.com/tensorflow/addons/pull/2558/files

innat commented 2 years ago

@gamenerd457 (cc. @bhack @LukeWood @qlzh727 )

If the loss easily conforms to the Keras loss API (y_true, y_pred), then yes!

Some refactoring may be needed (e.g. subclassing the tf.keras.losses.Loss class). I think you can start working on it: you can either move the PR from tf-addons or create a new one here. The implementation conforms to the Keras loss API.

Also, please follow the implementation details in segmentation_models/losses.py. That reference implementation offers some useful initialization parameters (e.g. per_image, class_weights), which I have found very useful in practice.


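import tensorflow as tf
from tensorflow import keras

SMOOTH = 1e-5  # assumed value; keeps the ratio defined when masks are empty


def dice(y_true, y_pred, beta=1, class_weights=1.0, per_image=False, smooth=SMOOTH):
    # Sketch modelled on segmentation_models' F-beta score; beta=1 gives Dice.
    y_true = tf.cast(y_true, y_pred.dtype)
    axes = [1, 2] if per_image else [0, 1, 2]  # channels-last (B, H, W, C) input
    tp = tf.reduce_sum(y_true * y_pred, axis=axes)
    fp = tf.reduce_sum(y_pred, axis=axes) - tp
    fn = tf.reduce_sum(y_true, axis=axes) - tp
    b2 = beta ** 2
    score = ((1 + b2) * tp + smooth) / ((1 + b2) * tp + b2 * fn + fp + smooth)
    return 1.0 - tf.reduce_mean(score * class_weights)


class Dice(keras.losses.Loss):
    def __init__(self, beta=1, class_weights=None, class_indexes=None,
                 per_image=False, smooth=SMOOTH, name='Dice'):
        super().__init__(name=name)
        self.beta, self.per_image, self.smooth = beta, per_image, smooth
        self.class_weights = 1.0 if class_weights is None else class_weights
        self.class_indexes = class_indexes  # optional subset of channels to score

    def call(self, y_true, y_pred):
        if self.class_indexes is not None:
            y_true, y_pred = (tf.gather(t, self.class_indexes, axis=-1) for t in (y_true, y_pred))
        return dice(y_true, y_pred, self.beta, self.class_weights, self.per_image, self.smooth)
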
LukeWood commented 2 years ago

What are some use cases for per_image?

Also note that in Keras we won't support class_weights; instead, we will support sample_weight when the loss is called.
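As a sketch of the sample_weight route: when a tf.keras.losses.Loss subclass returns one value per sample from call, Keras multiplies those values by the sample_weight passed to the loss object before reducing them. This assumes TensorFlow 2.x; PerSampleDice is a hypothetical class, and reduction="none" is used only so the weighted per-sample values stay visible (the default reduction would average them).

```python
import tensorflow as tf


class PerSampleDice(tf.keras.losses.Loss):
    """Hypothetical soft Dice loss that returns one value per sample."""

    def __init__(self, smooth=1e-6, reduction="sum_over_batch_size", name="per_sample_dice"):
        super().__init__(reduction=reduction, name=name)
        self.smooth = smooth

    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        axes = [1, 2, 3]  # reduce over H, W, C; keep the batch axis
        inter = tf.reduce_sum(y_true * y_pred, axis=axes)
        denom = tf.reduce_sum(y_true, axis=axes) + tf.reduce_sum(y_pred, axis=axes)
        return 1.0 - (2.0 * inter + self.smooth) / (denom + self.smooth)


# Two samples of shape (2, 1, 1): the first is predicted perfectly, the second is missed.
y_true = tf.constant([[[[1.0]], [[0.0]]], [[[1.0]], [[1.0]]]])
y_pred = tf.constant([[[[1.0]], [[0.0]]], [[[0.0]], [[0.0]]]])
loss = PerSampleDice(reduction="none")
print(loss(y_true, y_pred, sample_weight=tf.constant([1.0, 0.5])))  # approximately [0.0, 0.5]
```

Note that sample_weight weights whole samples, not classes, so it is not a drop-in replacement for per-class weights.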

innat commented 2 years ago

From HERE.

per_image: If True, the loss is calculated for each image in the batch and then averaged; otherwise, the loss is calculated over the whole batch.

Setting it to True can improve results, depending on the task.
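To make the per_image distinction concrete, here is a NumPy sketch (the function name and eps value are hypothetical) that computes Dice either once over the pooled batch or per image followed by averaging. The two differ when images have very different foreground sizes: pooled over the batch, large objects dominate the score, whereas per-image averaging gives each image equal weight.

```python
import numpy as np


def dice_loss(y_true, y_pred, per_image=False, eps=1e-6):
    """Soft Dice loss for masks shaped (batch, H, W)."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    # Sum over spatial axes only (per image) or over everything (whole batch).
    axes = (1, 2) if per_image else (0, 1, 2)
    inter = np.sum(y_true * y_pred, axis=axes)
    denom = np.sum(y_true, axis=axes) + np.sum(y_pred, axis=axes)
    dice = (2.0 * inter + eps) / (denom + eps)
    return float(1.0 - np.mean(dice))


# Image 0: large object predicted perfectly; image 1: a one-pixel object missed entirely.
y_true = np.zeros((2, 4, 4))
y_true[0] = 1.0          # 16 foreground pixels
y_true[1, 0, 0] = 1.0    # 1 foreground pixel
y_pred = y_true.copy()
y_pred[1] = 0.0          # the small object is missed

print(dice_loss(y_true, y_pred, per_image=False))  # ~0.03: the large object dominates
print(dice_loss(y_true, y_pred, per_image=True))   # ~0.5: each image counts equally
```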

LukeWood commented 2 years ago

Cool, we will need to include this in the docstring. Seems like a match!

innat commented 1 year ago

@DavidLandup0

Here is some info that might be helpful.