mlyg / unified-focal-loss

Apache License 2.0

Can we use unified loss function for multiclass segmentation? #8

Closed: xlar-sanjeet closed this issue 2 years ago

mlyg commented 2 years ago

Absolutely! The easiest way is to modify the code where the scores are calculated for the individual classes. To do this, you need a one-hot encoding of the classes, and you need to know which class corresponds to each index along the class axis. For example, the KiTS19 dataset example has three classes (background, kidney and kidney tumour, corresponding to indices 0, 1 and 2 respectively), where the kidney and kidney tumour are much smaller than the background, so the asymmetric Focal and asymmetric Focal Tversky losses can be modified as follows:

import tensorflow as tf
from tensorflow.keras import backend as K

################################
#     Asymmetric Focal loss    #
################################
def asymmetric_focal_loss(delta=0.7, gamma=2.):
    """For imbalanced datasets (multiclass version for the three KiTS19 classes)
    Parameters
    ----------
    delta : float, optional
        controls weight given to false positives and false negatives, by default 0.7
    gamma : float, optional
        focal parameter controls degree of down-weighting of easy examples, by default 2.0
    """
    def loss_function(y_true, y_pred):
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1. - epsilon)
        cross_entropy = -y_true * K.log(y_pred)

        # calculate losses separately for each class, only suppressing background class
        # modify this section below for multiclass segmentation; [..., c] selects class c
        # on the last (channel) axis for both 2D and 3D (extra spatial axis) inputs
        back_ce = K.pow(1 - y_pred[..., 0], gamma) * cross_entropy[..., 0]
        back_ce = (1 - delta) * back_ce

        kidney_ce = cross_entropy[..., 1]
        kidney_ce = delta * kidney_ce

        tumour_ce = cross_entropy[..., 2]
        tumour_ce = delta * tumour_ce

        loss = K.mean(K.sum(tf.stack([back_ce, kidney_ce, tumour_ce], axis=-1), axis=-1))

        return loss

    return loss_function
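The three hard-coded class terms above generalise to any number of classes. As a minimal, framework-agnostic sketch (NumPy rather than Keras, so the values can be checked by hand), the function below assumes channels-last one-hot labels, treats one class index as the background (suppressed by the focal term) and every other class as rare. The name `asymmetric_focal_loss_np` and the `background` parameter are hypothetical, not part of the repository.

```python
import numpy as np


def asymmetric_focal_loss_np(y_true, y_pred, delta=0.7, gamma=2.0,
                             background=0, eps=1e-7):
    """Hypothetical NumPy sketch of the multiclass asymmetric Focal loss.

    y_true: one-hot labels, channels-last, shape (batch, *spatial, n_classes)
    y_pred: softmax probabilities with the same shape
    background: index of the dominant class; all other classes are treated as rare
    """
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    cross_entropy = -y_true * np.log(y_pred)

    per_class = []
    for c in range(y_true.shape[-1]):
        ce = cross_entropy[..., c]
        if c == background:
            # Focal modulation suppresses easy background voxels, weighted by (1 - delta)
            per_class.append((1 - delta) * np.power(1 - y_pred[..., c], gamma) * ce)
        else:
            # Rare classes keep the plain cross entropy, weighted by delta
            per_class.append(delta * ce)

    # Sum the class terms per voxel, then average over batch and voxels
    return np.mean(np.sum(np.stack(per_class, axis=-1), axis=-1))


# Usage: build one-hot labels from an integer label map with np.eye
labels = np.array([[[[0, 1], [2, 0]]]])      # shape (1, 1, 2, 2): a tiny 3D volume
y_true = np.eye(3)[labels]                   # shape (1, 1, 2, 2, 3)
y_pred = np.full(y_true.shape, 1.0 / 3.0)    # uniform "prediction"
print(asymmetric_focal_loss_np(y_true, y_pred))
```

With `background=0` and three classes, the loop produces the same background, kidney and tumour terms as the Keras code above; moving back to Keras is a matter of swapping the `np.*` calls for the corresponding `K.*` ops.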

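# Minimal version of the repository's identify_axis helper: returns the
# spatial axes to sum over for 2D (4D tensor) or 3D (5D tensor) inputs
def identify_axis(shape):
    if len(shape) == 5:
        return [1, 2, 3]
    elif len(shape) == 4:
        return [1, 2]
    raise ValueError('Metric: Shape of tensor is neither 2D or 3D.')

#################################
# Asymmetric Focal Tversky loss #
#################################
def asymmetric_focal_tversky_loss(delta=0.7, gamma=0.75):
    """Multiclass version for the three KiTS19 classes (background, kidney, kidney tumour).
    Parameters
    ----------
    delta : float, optional
        controls weight given to false positives and false negatives, by default 0.7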
    gamma : float, optional
        focal parameter controls degree of down-weighting of easy examples, by default 0.75
    """
    def loss_function(y_true, y_pred):
        # Clip values to prevent division by zero error
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1. - epsilon)

        axis = identify_axis(y_true.get_shape())
        # Calculate true positives (tp), false negatives (fn) and false positives (fp)     
        tp = K.sum(y_true * y_pred, axis=axis)
        fn = K.sum(y_true * (1-y_pred), axis=axis)
        fp = K.sum((1-y_true) * y_pred, axis=axis)
        dice_class = (tp + epsilon)/(tp + delta*fn + (1-delta)*fp + epsilon)

        # calculate losses separately for each class, only enhancing foreground class
        # modify this section below for multiclass segmentation
        back_dice = (1-dice_class[:,0]) 
        kidney_dice = (1-dice_class[:,1]) * K.pow(1-dice_class[:,1], -gamma) 
        tumour_dice = (1-dice_class[:,2]) * K.pow(1-dice_class[:,2], -gamma)

        # Average class scores
        loss = K.mean(tf.stack([back_dice, kidney_dice, tumour_dice],axis=-1))
        return loss

    return loss_function
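The same generalisation applies to the asymmetric Focal Tversky loss. The sketch below is again a hypothetical NumPy version (`asymmetric_focal_tversky_loss_np` and `background` are illustrative names), assuming channels-last one-hot labels with shape (batch, *spatial, n_classes). It sums over the spatial axes only, as `identify_axis` does in the Keras code, and writes the focal enhancement `(1 - TI) * (1 - TI)**(-gamma)` in the equivalent form `(1 - TI)**(1 - gamma)`, which avoids an `inf * 0` when a class is predicted perfectly.

```python
import numpy as np


def asymmetric_focal_tversky_loss_np(y_true, y_pred, delta=0.7, gamma=0.75,
                                     background=0, eps=1e-7):
    """Hypothetical NumPy sketch of the multiclass asymmetric Focal Tversky loss.

    y_true: one-hot labels, channels-last, shape (batch, *spatial, n_classes)
    y_pred: softmax probabilities with the same shape
    background: index of the dominant class; all other classes get focal enhancement
    """
    y_pred = np.clip(y_pred, eps, 1.0 - eps)

    # Spatial axes only: everything between the batch axis and the class axis
    axis = tuple(range(1, y_true.ndim - 1))
    tp = np.sum(y_true * y_pred, axis=axis)
    fn = np.sum(y_true * (1 - y_pred), axis=axis)
    fp = np.sum((1 - y_true) * y_pred, axis=axis)
    tversky = (tp + eps) / (tp + delta * fn + (1 - delta) * fp + eps)  # (batch, n_classes)

    per_class = 1 - tversky
    # (1 - TI) * (1 - TI)**(-gamma) == (1 - TI)**(1 - gamma) for 1 - TI > 0
    enhanced = np.power(per_class, 1 - gamma)
    is_background = np.arange(per_class.shape[-1]) == background
    scores = np.where(is_background, per_class, enhanced)

    # Average over batch and classes, like K.mean(tf.stack(...)) above
    return np.mean(scores)


# Usage: two pixels of a 2D image (batch=1, H=2, W=1, 2 classes), 50/50 predictions
y_true = np.array([[[[1.0, 0.0]], [[0.0, 1.0]]]])   # shape (1, 2, 1, 2)
y_pred = np.full(y_true.shape, 0.5)
print(asymmetric_focal_tversky_loss_np(y_true, y_pred))
```

Since the Unified Focal loss is a weighted sum of the asymmetric Focal and asymmetric Focal Tversky terms, multiclass versions of the two components can be combined in the same way.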

xlar-sanjeet commented 2 years ago

Thank you

On Thu, 14 Apr 2022 at 15:18, mlyg wrote:

Closed #8 https://github.com/mlyg/unified-focal-loss/issues/8.


pedrohbd commented 1 year ago

Hi @mlyg. Could you update your code to add this multiclass version? It has been some time since your paper and I guess lots of people would benefit from a multiclass version, but I can't find one anywhere besides this issue. Thanks!

mlyg commented 1 year ago

@pedrohbd Thanks for bringing this up; it makes a lot of sense to add the multiclass versions to the main code. I will look to get this done as soon as I can.

Best wishes, Michael