mlyg / unified-focal-loss


How to give more weights to a certain label #14

Closed: Minseok-Sakong closed this issue 2 years ago

Minseok-Sakong commented 2 years ago

Hello mlyg,

First of all, thank you so much for your awesome work. I am learning a lot from it!

I am trying to segment 3D datasets with multiple classes using a 3D U-Net. The classes are background, foreground, and edge. The datasets also have a class imbalance problem, where background >> foreground > edge.

I tried segmenting the datasets with the focal Tversky loss. It handled the class imbalance problem pretty well, but some problems remained. The most important label to segment is the edge label. However, with just the focal Tversky loss, some regions that should be segmented as edge are predicted as foreground. In summary, a false positive (FP) on the edge label is okay, but a false negative (FN) should be severely penalized. Would there be a way to modify the loss function to solve this problem?

Minseok-Sakong commented 2 years ago

[image] This is an example of my dataset.

mlyg commented 2 years ago

Hi Minseok,

I am very glad you have found the work useful!

Segmenting edges is often very challenging. As you mentioned, one thing worth trying is to alter the class weights so that the edge label loss contributes more to the overall loss. I made a slight modification to the code you sent before to include weights:

import tensorflow as tf
from tensorflow.keras import backend as K

def categorical_focal_tversky_loss(delta=0.7, gamma=0.75, weights=[1.,1.,1.]):
    """This is the implementation for multiclass segmentation.
    Parameters
    ----------
    delta : float, optional
        controls weight given to false positives and false negatives, by default 0.7
    gamma : float, optional
        focal parameter controls degree of down-weighting of easy examples, by default 0.75
    weights : list, optional
        controls weight given to each class, by default 1. (equal weighting)
    """
    def loss_function(y_true, y_pred):
        # Clip values to prevent division by zero error
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1. - epsilon)

        # identify_axis (helper from this repository) selects the spatial axes to sum over
        axis = identify_axis(y_true.get_shape())
        # Calculate true positives (tp), false negatives (fn) and false positives (fp)     
        tp = K.sum(y_true * y_pred, axis=axis)
        fn = K.sum(y_true * (1-y_pred), axis=axis)
        fp = K.sum((1-y_true) * y_pred, axis=axis)
        dice_class = (tp + epsilon)/(tp + delta*fn + (1-delta)*fp + epsilon)

        # Calculate the focal Tversky loss separately for each class and apply the class weights
        back_dice = (1-dice_class[:,0]) * K.pow(1-dice_class[:,0], -gamma) * weights[0]
        edge_dice = (1-dice_class[:,1]) * K.pow(1-dice_class[:,1], -gamma) * weights[1]
        brain_dice = (1-dice_class[:,2]) * K.pow(1-dice_class[:,2], -gamma) * weights[2]

        # Average the weighted class losses over the batch
        loss = K.mean(tf.stack([back_dice, edge_dice, brain_dice], axis=-1))
        return loss

    return loss_function

You can try adding more weight to the edge loss relative to the other labels and see whether that helps.
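For example (a hypothetical setup, assuming a Keras model named model and one-hot labels with the channels ordered [background, edge, brain]):

# Hypothetical usage: double the contribution of the edge class (channel 1)
# relative to the background and brain classes
loss = categorical_focal_tversky_loss(delta=0.7, gamma=0.75, weights=[1., 2., 1.])
model.compile(optimizer='adam', loss=loss)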

It is also possible to change the loss to penalise the false positives/negatives differently for each class, and below I have written an implementation that allows you to set delta, gamma and the weights for each class separately:

import numpy as np
from tensorflow.keras import backend as K

def categorical_focal_tversky_loss(delta=[0.7,0.7,0.7], gamma=[0.75,0.75,0.75], weights=[1.,1.,1.]):
    """This is the implementation for multiclass segmentation.
    Parameters
    ----------
    delta : array, optional
        controls weight given to false positives and false negatives for each class, by default 0.7
    gamma : array, optional
        focal parameter controls degree of down-weighting of easy examples for each class, by default 0.75
    weights : array, optional
        controls weight given to each class, by default 1. (equal weighting)

    array values are for [background, edge, brain] respectively

    """
    # convert lists to float32 arrays (matching the Keras tensor dtype) for vectorised operations
    delta = np.array(delta, dtype=np.float32)
    gamma = np.array(gamma, dtype=np.float32)
    weights = np.array(weights, dtype=np.float32)

    def loss_function(y_true, y_pred):
        # Clip values to prevent division by zero error
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1. - epsilon)

        axis = identify_axis(y_true.get_shape())
        # Calculate true positives (tp), false negatives (fn) and false positives (fp)     
        tp = K.sum(y_true * y_pred, axis=axis)
        fn = K.sum(y_true * (1-y_pred), axis=axis)
        fp = K.sum((1-y_true) * y_pred, axis=axis)

        tversky_class = (tp + epsilon)/(tp + delta*fn + (1-delta)*fp + epsilon)

        # Calculate the focal Tversky loss for each class and apply the class weights
        class_loss = (1-tversky_class) * K.pow(1-tversky_class, -gamma) * weights

        # Average the weighted class losses over all classes and the batch
        loss = K.mean(class_loss)

        return loss

    return loss_function

I used a vectorised implementation for this one to make the code shorter and a little faster, although less explicit. There is a risk of introducing too many hyperparameters to tune here. I think the main changes you want to make are increasing delta for the edge loss (to penalise false negatives more heavily; maybe try 0.8/0.9) and increasing the weight of the edge loss.
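Concretely, for each class c the code above computes (1 - TI_c)^(1 - gamma_c) * w_c, where TI_c = (TP_c + eps) / (TP_c + delta_c*FN_c + (1 - delta_c)*FP_c + eps), so raising delta_c makes false negatives for that class more costly. A hypothetical starting point for your case, with the classes ordered [background, edge, brain] as in the docstring:

# Hypothetical settings: penalise edge false negatives more heavily (delta 0.9)
# and double the edge class weight; the other classes keep the defaults
loss = categorical_focal_tversky_loss(delta=[0.7, 0.9, 0.7],
                                      gamma=[0.75, 0.75, 0.75],
                                      weights=[1., 2., 1.])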

I hope this helps!

Minseok-Sakong commented 2 years ago

Hello mlyg,

Thank you so much for the reply. I tested the loss function where I can modify both the deltas and the weights. However, the training loss stops decreasing at a certain point. For example, when I set delta to [0.7,0.8,0.7], the training loss starts at 0.9 and decreases to 0.7, but after that it just fluctuates around 0.7.

Also, I tried predicting with the trained model whose loss had plateaued around 0.7, and it predicts only labels 0 and 1 (just the background and edge). I have tried several modifications to the hyperparameters, but they all end up at a loss of around 0.7, and the model fails to learn the dataset properly.

Would there be a way to improve this issue?

mlyg commented 2 years ago

Hi Minseok,

I am sorry to hear that your problem is still not solved.

Just to check first: do you get similar results to before when you use the default hyperparameters in the above implementation? I want to rule out an implementation problem, as the above implementation should give the same results as the original if the hyperparameters are set the same.
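For a quick check (a minimal sketch; it assumes the two implementations above are bound to the hypothetical names ftl_explicit and ftl_vectorised, and uses random tensors of an illustrative shape):

import numpy as np
import tensorflow as tf

# Random one-hot labels and softmax predictions for a small 3D batch:
# (batch, depth, height, width, classes) = (2, 8, 8, 8, 3)
rng = np.random.default_rng(0)
y_true = tf.one_hot(rng.integers(0, 3, size=(2, 8, 8, 8)), depth=3)
logits = rng.normal(size=(2, 8, 8, 8, 3)).astype(np.float32)
y_pred = tf.nn.softmax(tf.constant(logits), axis=-1)

# With matching default hyperparameters, the two versions should agree closely
print(float(ftl_explicit()(y_true, y_pred)))
print(float(ftl_vectorised()(y_true, y_pred)))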

Minseok-Sakong commented 2 years ago

Hello mlyg,

Thank you for the reply. Unfortunately, I tried the default hyperparameters in this loss function, and the problem still exists: the training loss does not decrease below 0.7. I am afraid it might be an implementation problem, because with the default focal_tversky_loss my model learned the datasets well. The version I tested was:

def categorical_focal_tversky_loss(delta=[0.7,0.7,0.7], gamma=[0.75,0.75,0.75], weights=[1.,1.,1.]):

mlyg commented 2 years ago

Hi Minseok,

It is very difficult for me to give useful advice because I do not know what the project is about and cannot look at the implementation. The poor performance could be due to an implementation problem, to the way the problem is defined (e.g. learning edges is a hard problem), or to the loss function used.

Without being able to rule out these other, more fundamental issues, I am afraid I am not sure how much more I can help from a loss function point of view. I will close this thread to keep the information here relevant to loss functions.

This was a very useful issue to raise, as users can now find here an implementation that lets them change the weights assigned to the classes. Thank you very much.

You can contact me directly through my email (my1721@imperial.ac.uk) if you would like to discuss more specifics related to your project.