keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Generalized dice loss for multi-class segmentation #9395

Closed emoebel closed 5 years ago

emoebel commented 6 years ago

Hey guys, I just implemented the generalised dice loss (multi-class version of dice loss), as described in ref : (my targets are defined as: (batch_size, image_dim1, image_dim2, image_dim3, nb_of_classes))

import numpy as np
from keras import backend as K

def generalized_dice_loss_w(y_true, y_pred):
    # Compute weights: "the contribution of each label is corrected by the inverse of its volume"
    Ncl = y_pred.shape[-1]
    w = np.zeros((Ncl,))
    for l in range(0,Ncl): w[l] = np.sum( np.asarray(y_true[:,:,:,:,l]==1,np.int8) )
    w = 1/(w**2+0.00001)

    # Compute gen dice coef:
    numerator = y_true*y_pred
    numerator = w*K.sum(numerator,(0,1,2,3))
    numerator = K.sum(numerator)

    denominator = y_true+y_pred
    denominator = w*K.sum(denominator,(0,1,2,3))
    denominator = K.sum(denominator)

    gen_dice_coef = numerator/denominator

    return 1-2*gen_dice_coef
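Written out, the function above computes the following quantity. The symbol names are chosen here for illustration: $r_{ln}$ is the one-hot target and $p_{ln}$ the predicted probability for class $l$ at voxel $n$, and the sums over $n$ run over the batch and the three spatial dimensions.

```latex
\mathrm{GDL}
  = 1 - 2\,
    \frac{\sum_{l} w_l \sum_{n} r_{ln}\, p_{ln}}
         {\sum_{l} w_l \sum_{n} \left( r_{ln} + p_{ln} \right)},
\qquad
w_l = \frac{1}{\left( \sum_{n} r_{ln} \right)^{2} + 10^{-5}}
```

Because $w_l$ appears in both the numerator and the denominator, only the relative size of the weights across classes affects the result, not their absolute scale.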

But something must be wrong. I'm working with 3D images that I have to segment into 4 classes (1 background class and 3 object classes; my dataset is imbalanced). First odd thing: while my training loss and accuracy improve during training (and converge really fast), my validation loss and accuracy stay constant through the epochs (see image). Second, when predicting on test data, only the background class is predicted: I get a constant volume.

I used the exact same data and script but with categorical cross-entropy loss and get plausible results (object classes are segmented). Which means something is wrong with my implementation. Any idea what it could be?

Plus, I believe it would be useful to the Keras community to have a generalised dice loss implementation, as it seems to be used in most recent semantic segmentation tasks (at least in the medical imaging community).

PS: the way the weights are defined seems odd to me; I get values around 10^-10. Has anyone else tried to implement this? I also tested my function without the weights but got the same problems.
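As a way to check the numbers, here is a minimal NumPy sketch of the same formula, assuming channels-last one-hot targets as described in the post. It is a reference for concrete arrays, not a drop-in Keras loss: inside a Keras loss, `y_true` and `y_pred` are backend tensors, so the per-class sums would have to be expressed with backend ops such as `K.sum` rather than with NumPy calls. The function name and the toy data are made up for illustration.

```python
import numpy as np


def generalized_dice_loss_np(y_true, y_pred, eps=1e-5):
    """NumPy reference for the generalized Dice loss.

    y_true, y_pred: arrays of shape (batch, d1, d2, d3, n_classes),
    with y_true one-hot encoded (channels last, as in the post).
    """
    reduce_axes = (0, 1, 2, 3)
    # Per-class volume: number of ground-truth voxels of each class.
    volume = y_true.sum(axis=reduce_axes)
    # Inverse squared volume, with the same small constant as in the post.
    w = 1.0 / (volume ** 2 + eps)

    numerator = np.sum(w * np.sum(y_true * y_pred, axis=reduce_axes))
    denominator = np.sum(w * np.sum(y_true + y_pred, axis=reduce_axes))
    return 1.0 - 2.0 * numerator / denominator


# Toy data (illustrative only): 4 classes, class 0 is a dominant background.
rng = np.random.default_rng(0)
labels = rng.choice(4, size=(2, 16, 16, 16), p=[0.94, 0.03, 0.02, 0.01])
y_true = np.eye(4)[labels]  # one-hot, shape (2, 16, 16, 16, 4)

# A perfect prediction versus a constant "everything is background" volume.
perfect = generalized_dice_loss_np(y_true, y_true)
all_background = np.zeros_like(y_true)
all_background[..., 0] = 1.0
constant = generalized_dice_loss_np(y_true, all_background)

# The weights themselves: tiny for large classes, larger for rare ones.
weights = 1.0 / (y_true.sum(axis=(0, 1, 2, 3)) ** 2 + 1e-5)
print(perfect, constant, weights)
```

Very small weights are expected with this definition: a class with about 10^5 voxels gets a weight of about 10^-10. Since the weights appear in both the numerator and the denominator, their absolute scale cancels and only the ratio between classes matters.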

Tarandeep97 commented 2 years ago

@gattia

@lazyleaf, I just stumbled upon this. I am doing 3D segmentation on multiclass. I will definitely try out the proposed method and see how it works. However, I also have another solution that has worked for me in the past:

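from keras import backend as K

def dice_coef(y_true, y_pred, smooth=1.0):
    # `smooth` is not defined in the snippet as posted; 1.0 is an assumed default.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)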

def dice_coef_multilabel(y_true, y_pred, numLabels=5):
    dice=0
    for index in range(numLabels):
        dice -= dice_coef(y_true[:,index,:,:,:], y_pred[:,index,:,:,:])
    return dice

This simply calculates the dice score for each individual label, including the background, and sums them. The best dice score you will ever get is numLabels * -1.0. When monitoring, I always keep in mind that the dice for the background is almost always near 1.0.
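The claim about the best achievable score can be checked with a small NumPy sketch of the two functions above. The `_np` names, the smoothing value of 1.0 and the toy data are assumptions for illustration. Note that the quoted code indexes the label on axis 1 (channels first), whereas the original post uses channels-last targets.

```python
import numpy as np

SMOOTH = 1.0  # assumed value for the smoothing constant


def dice_coef_np(y_true, y_pred, smooth=SMOOTH):
    # Flatten both arrays and compute the smoothed Dice coefficient.
    y_true_f = y_true.ravel()
    y_pred_f = y_pred.ravel()
    intersection = np.sum(y_true_f * y_pred_f)
    return (2.0 * intersection + smooth) / (y_true_f.sum() + y_pred_f.sum() + smooth)


def dice_coef_multilabel_np(y_true, y_pred, num_labels=5):
    # Channels-first layout: (batch, label, d1, d2, d3), as in the quoted code.
    dice = 0.0
    for index in range(num_labels):
        dice -= dice_coef_np(y_true[:, index], y_pred[:, index])
    return dice


# Toy one-hot volume with 5 labels, moved to channels-first.
rng = np.random.default_rng(0)
labels = rng.integers(0, 5, size=(2, 8, 8, 8))
y_true = np.moveaxis(np.eye(5)[labels], -1, 1)  # shape (2, 5, 8, 8, 8)

best = dice_coef_multilabel_np(y_true, y_true)          # perfect prediction
worst = dice_coef_multilabel_np(y_true, 1.0 - y_true)   # no overlap at all
print(best, worst)
```

With a perfect prediction every per-label coefficient equals 1, so the sum reaches `-num_labels`; with no overlap each coefficient shrinks towards 0 and the sum approaches 0 from below.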

Shouldn't there be a weight for each label, multiplied with dice_coef, that reflects the number of pixels for that label?

gattia commented 2 years ago

What you are describing is one version of the DSC that has been called generalized-dsc.

I think they weighted each class inversely by the number of pixels labeled as that class.

weight = 1/ sum(y_true[:,index,:,:,:])

dice -= weight * dice_coef(y_true[:,index,:,:,:], y_pred[:,index,:,:,:])

The above would replace the code in the for loop, so that each label is weighted inversely by the number of pixels labeled as that particular class (in the ground truth).
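Put together, the weighted loop might look like the following NumPy sketch. The `_np` names, the smoothing value and the `eps` guard against a label that is absent from the ground truth are assumptions added here, not part of the suggestion above.

```python
import numpy as np


def dice_coef_np(y_true, y_pred, smooth=1.0):
    # Smoothed Dice coefficient on flattened arrays (smooth=1.0 is assumed).
    y_true_f = y_true.ravel()
    y_pred_f = y_pred.ravel()
    intersection = np.sum(y_true_f * y_pred_f)
    return (2.0 * intersection + smooth) / (y_true_f.sum() + y_pred_f.sum() + smooth)


def weighted_dice_multilabel_np(y_true, y_pred, num_labels=5, eps=1e-7):
    # Channels-first layout: the label index is on axis 1.
    dice = 0.0
    for index in range(num_labels):
        t = y_true[:, index]
        p = y_pred[:, index]
        # Inverse of the ground-truth pixel count for this label;
        # eps (an assumption) avoids division by zero for an absent label.
        weight = 1.0 / (t.sum() + eps)
        dice -= weight * dice_coef_np(t, p)
    return dice


# Toy 2-label image: 90 background pixels, 10 foreground pixels.
labels = np.zeros((1, 10, 10), dtype=int)
labels[0, 0, :] = 1
y_true = np.moveaxis(np.eye(2)[labels], -1, 1)  # shape (1, 2, 10, 10)

best = weighted_dice_multilabel_np(y_true, y_true, num_labels=2)
print(best)
```

One consequence of this weighting is that the best achievable value is no longer `-num_labels`: it becomes minus the sum of the per-label weights, which depends on the ground truth of each batch.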