mlyg / unified-focal-loss

Apache License 2.0

High Number of False positives for Binary class with Softmax Layer #15

Open burhr2 opened 1 year ago

burhr2 commented 1 year ago

Hello, thanks for the excellent repo you have put together. We are working on a 3D binary segmentation task for detecting lesions in spinal cord MRI images. We have a class imbalance problem: the lesion (foreground) class is far less represented than the background class. The proportion of foreground voxels per training patch (48x48x48) is 0.9%, averaged over patches that contain lesions. We are using a 3D U-Net model with sigmoid output, which works well. When we switch the 3D U-Net to a softmax output, it tends to produce many more false positive predictions than with the sigmoid output. We train on randomly selected patches, so we can easily get training patches that contain only the background class.

Can you give some insight, or tell us whether we are doing something wrong (please see the updated softmax output code below)? My intuition is that since the background class covers a large proportion of voxels, the model tends to learn the background class more than the foreground, even with the penalties from the losses. For example, using asymmetric_focal_loss resulted in a model predicting only the background class. Another example is a dice_coefficient that is calculated per class and returned as the average Dice: the average seems to be strongly influenced by the good Dice we get for the background class compared to the foreground.
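To make the point about the class-averaged Dice concrete, here is a minimal sketch on a made-up 1000-voxel patch with 1% foreground. The `dice` helper and all numbers are illustrative assumptions, not the repo's `dice_coefficient` or the poster's data.

```python
import numpy as np

def dice(pred, truth):
    """Dice overlap between two boolean masks."""
    intersection = np.logical_and(pred, truth).sum()
    return 2.0 * intersection / (pred.sum() + truth.sum())

# Toy patch with 1% foreground (made-up numbers, not from the experiments).
truth = np.zeros(1000, dtype=bool)
truth[:10] = True   # 10 lesion voxels
pred = np.zeros(1000, dtype=bool)
pred[5:15] = True   # 5 hits, 5 false positives, 5 misses

dice_fg = dice(pred, truth)     # foreground channel
dice_bg = dice(~pred, ~truth)   # background channel
mean_dice = (dice_fg + dice_bg) / 2
print(dice_fg, dice_bg, mean_dice)
```

In this toy case the foreground Dice is 0.5 while the background Dice is above 0.99, so the class average looks much better than the lesion segmentation actually is. Reporting the foreground Dice on its own is one way to avoid that.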

Lesion-level results with default parameters (TP = true positives, FP = false positives, FN = false negatives, GT = number of lesions in the ground truth):

No.  Loss                            TP    FP  FN  GT
1    Asymetric_unified_loss          31   580  26  57
2    Symetric_unified_loss           23   183  34  57
3    asymmetric_focal_tversky_loss   29   358  28  57
4    asymmetric_focal_loss            0  1166  57  57
5    symmetric_focal_tversky_loss     0     0  57  57
6    tversky_loss                    24   578  33  57
7    combo_loss                      22   267  35  57
8    focal_tversky_loss              27   281  30  57
9    focal_loss                       7    48  50  57
10   symmetric_focal_loss             0   923  57  57
11   dice_loss                       31   382  26  57
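The counts above can be condensed into lesion-level precision and recall, which makes the trade-off between the losses easier to compare. This is a small sketch; the `lesion_metrics` helper is a hypothetical name, and only three rows of the table are copied in as examples.

```python
def lesion_metrics(tp, fp, fn):
    """Return (precision, recall, F1) from lesion-level counts."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1

# (TP, FP, FN) copied from rows 1, 9 and 5 of the table above.
results = {
    "Asymetric_unified_loss": (31, 580, 26),
    "focal_loss": (7, 48, 50),
    "symmetric_focal_tversky_loss": (0, 0, 57),
}

for name, (tp, fp, fn) in results.items():
    p, r, f1 = lesion_metrics(tp, fp, fn)
    print(f"{name}: precision={p:.3f} recall={r:.3f} f1={f1:.3f}")
```

Rows with TP = 0 are reported as 0.0 here by convention, since precision is undefined when nothing is predicted.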

The input image and the two-channel mask

# Sigmoid version (array shapes)
input_img = (1, 48, 48, 48, 1)
single_channel_mask = (1, 48, 48, 48, 1)

# Softmax version: one-hot encode the mask.
# num_classes=2 is an added safeguard so that background-only patches still get two channels.
two_channel_mask = tensorflow.keras.utils.to_categorical(single_channel_mask, num_classes=2)

# Inputs of the model (array shapes)
input_img = (1, 48, 48, 48, 1)
two_channel_mask = (1, 48, 48, 48, 2)  # 1st channel: background, 2nd channel: foreground
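One thing worth checking with randomly sampled patches: when `num_classes` is not passed, `to_categorical` infers the number of classes from the largest label present, so a background-only patch may come out with a single channel. The sketch below is a NumPy stand-in with a fixed class count (the `one_hot_mask` helper is hypothetical and is not part of Keras or this repo).

```python
import numpy as np

NUM_CLASSES = 2  # background, foreground

def one_hot_mask(mask, num_classes=NUM_CLASSES):
    """Turn an integer mask of shape (..., 1) into a one-hot mask of shape (..., num_classes)."""
    labels = np.squeeze(mask, axis=-1).astype(np.int64)
    return np.eye(num_classes, dtype=np.float32)[labels]

# A patch that contains only background voxels.
background_only = np.zeros((1, 48, 48, 48, 1), dtype=np.uint8)
two_channel_mask = one_hot_mask(background_only)
print(two_channel_mask.shape)
```

Because the class count is fixed, the background-only patch still yields two channels, with the foreground channel all zeros.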

3Dunet

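import tensorflow as tf

# Define the global variables
KERNEL_SIZE = (3, 3, 3)
POOLING_SIZE = (2, 2, 2)
FILTERS = [16, 32, 64]
shape = (48, 48, 48, 1)
depth = 2
# Note: dropout, lr, get_down_block and get_up_block are used below but are not shown in this post.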

def unet3D_softmax(num_classes=2):
    """Whole U-Net architecture from the predefined blocks."""

    inputs = tf.keras.layers.Input(shape=shape)
    layer = inputs
    hist = []  # skip-connection tensors saved by the down blocks
    for i in range(depth):
        layer, save = get_down_block(i, layer, dropout=dropout)
        hist.append(save)

    # Bottleneck: two Conv3D -> BatchNorm -> ReLU stages
    layer = tf.keras.layers.Conv3D(FILTERS[depth], KERNEL_SIZE, padding="same")(layer)
    layer = tf.keras.layers.BatchNormalization()(layer)
    layer = tf.keras.layers.Activation("relu")(layer)
    layer = tf.keras.layers.Conv3D(FILTERS[depth] * 2, KERNEL_SIZE, padding="same")(layer)
    layer = tf.keras.layers.BatchNormalization()(layer)
    layer = tf.keras.layers.Activation("relu")(layer)

    for i in reversed(range(depth)):
        layer = get_up_block(layer, hist[i], i, dropout=dropout)
    layer = tf.keras.layers.Dropout(dropout)(layer)

    # One output channel -> sigmoid (binary); otherwise softmax over the class channels
    activation = "sigmoid" if num_classes == 1 else "softmax"

    layer = tf.keras.layers.Conv3D(num_classes, 1, padding="same", activation=activation)(layer)
    model = tf.keras.Model(inputs=inputs, outputs=layer)
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

    model.compile(
        optimizer=optimizer, loss=asym_unified_focal_loss(), metrics=[dice_coefficient()]
    )

    return model

During Inference

import numpy as np

predictions_list = []
for patch in test_image_patches:
    single_patch_prediction = model.predict(patch)
    # prediction has shape (1, 48, 48, 48, 2): output probabilities
    single_patch_prediction_argmax = np.argmax(single_patch_prediction, axis=4)  # shape (1, 48, 48, 48)
    # restore the channel axis, shape (1, 48, 48, 48, 1), for compatibility with our pipeline
    single_patch_prediction_argmax = np.expand_dims(single_patch_prediction_argmax, axis=-1)
    predictions_list.append(single_patch_prediction_argmax)

mlyg commented 1 year ago

Hi Burhan, Thank you very much for your question; it sounds like an interesting project! The result is surprising, because I would not expect differences between using sigmoid or softmax as the last activation, given that softmax with two outputs is effectively equivalent to using a sigmoid activation.
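The equivalence mentioned above can be checked numerically: for two logits, the softmax probability of the foreground channel equals the sigmoid of the difference between the foreground and background logits. This is a minimal NumPy sketch on random logits (the `softmax` and `sigmoid` helpers are written out here for illustration). It only shows that the output functions agree; a loss implementation may still treat the two channels differently during training.

```python
import numpy as np

def softmax(logits, axis=-1):
    """Numerically stable softmax along the given axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 4, 4, 2))  # channel 0: background, channel 1: foreground

p_fg_softmax = softmax(logits)[..., 1]
p_fg_sigmoid = sigmoid(logits[..., 1] - logits[..., 0])
print(np.allclose(p_fg_softmax, p_fg_sigmoid))

# argmax over the two channels matches thresholding the sigmoid at 0.5
labels_softmax = np.argmax(softmax(logits), axis=-1)
labels_sigmoid = (p_fg_sigmoid > 0.5).astype(int)
```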

Just a few things I would like to check first:

  1. What are the results you are getting with the sigmoid layer?
  2. Have you tried the same loss function (although perhaps requiring different code) with sigmoid and softmax layers separately?
  3. Have you repeated any of these experiments or used cross validation (same loss function and activation)? It might be useful to see how the performance varies despite using the same parameters.

Unrelated to loss functions, two suggestions that come to mind that might help with performance:

  1. The U-Net depth used is quite low, which might cause problems with learning this difficult task. I wonder if you have tried using higher depths like 3 or 4?
  2. With patch-based predictions, I have found it useful to make overlapping predictions, average the activations (i.e. probabilities), and then apply argmax. I found this very useful, particularly for reducing false positives. The functions to do that can be found in: https://github.com/frankkramer-lab/MIScnn/blob/master/miscnn/utils/patch_operations.py
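The overlapping-prediction idea in the second suggestion can be sketched as follows. This is an independent NumPy illustration, not the MIScnn functions linked above; `predict_with_overlap`, `predict_fn` and `dummy_predict` are hypothetical names, and the sketch assumes that the stride is no larger than the patch size and that the volume is at least one patch wide in every dimension.

```python
import numpy as np
from itertools import product

def predict_with_overlap(volume, predict_fn, patch=48, stride=24, num_classes=2):
    """Average class probabilities from overlapping patches, then apply argmax.

    volume: (D, H, W) array, each dimension >= patch.
    predict_fn: hypothetical callable mapping a (patch, patch, patch) array to
                probabilities of shape (patch, patch, patch, num_classes).
    """
    prob_sum = np.zeros(volume.shape + (num_classes,), dtype=np.float32)
    counts = np.zeros(volume.shape + (1,), dtype=np.float32)

    def starts(size):
        # Regular grid of start indices, plus a final patch flush with the border.
        s = list(range(0, size - patch + 1, stride))
        if s[-1] != size - patch:
            s.append(size - patch)
        return s

    for z, y, x in product(*(starts(n) for n in volume.shape)):
        window = (slice(z, z + patch), slice(y, y + patch), slice(x, x + patch))
        prob_sum[window] += predict_fn(volume[window])
        counts[window] += 1.0

    # Mean probability per voxel, then the most likely class.
    return np.argmax(prob_sum / counts, axis=-1)

def dummy_predict(patch_array):
    """Stand-in for a model: voxels above 0.5 are called foreground."""
    fg = (patch_array > 0.5).astype(np.float32)
    return np.stack([1.0 - fg, fg], axis=-1)

volume = np.random.default_rng(0).random((20, 20, 20))
labels = predict_with_overlap(volume, dummy_predict, patch=8, stride=4)
print(labels.shape)
```

With a real model, `predict_fn` would wrap `model.predict` and add or remove the batch and channel axes as needed. Averaging probabilities before the argmax means an isolated high activation in one patch has to be confirmed by its overlapping neighbours, which is a plausible reason for the reduction in false positives.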

Best wishes, Michael