**Open** · burhr2 opened this issue 2 years ago
Hi Burhan, thank you very much for your question; it sounds like an interesting project! The result is surprising, because I would not expect differences between using sigmoid or softmax as the final activation, given that softmax with two outputs is effectively equivalent to a sigmoid activation.
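For concreteness, here is a quick NumPy check of that equivalence (a sketch, not code from this repo): the foreground probability of a two-logit softmax equals the sigmoid of the logit difference.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def softmax(z, axis=-1):
    e = np.exp(z - z.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# For two-channel logits [z_bg, z_fg], the softmax foreground probability
# equals sigmoid(z_fg - z_bg), so both heads can express the same
# predictive distribution.
z = np.random.default_rng(0).normal(size=(1000, 2))
assert np.allclose(softmax(z)[:, 1], sigmoid(z[:, 1] - z[:, 0]))
```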
Just a few things I would like to check first:
Unrelated to loss functions, two suggestions come to mind that might help with performance:
Best wishes, Michael
Hello, thanks for the excellent repo you have put together. We are working on a 3D binary segmentation task: detecting lesions in spinal cord MRI images. We have a class-imbalance situation, with the lesion (foreground) class far less represented than the background class: the proportion of foreground voxels per 48×48×48 training patch is 0.9% on average, over patches that contain lesions. We are using a 3D U-Net model with a sigmoid output, which works well. When we update the 3D U-Net to a softmax output, it tends to produce many more false positive predictions than the sigmoid output does. We train on randomly selected patches, so we can easily have training patches containing only the background class. Can you give any insight into whether we are doing something wrong (please see the updated softmax output code below)?

My intuition is that, since the background class covers a large proportion of the voxels, the model tends to learn the background class better than the foreground, even with the penalties from the losses. For example, using asymmetric_focal_loss resulted in a model that predicts only the background class. Another example is a Dice coefficient calculated per class and returned as the average across classes: the high Dice from the background class then heavily influences the reported score compared to the foreground.
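To make that averaging effect concrete, here is a minimal NumPy sketch (an illustration under stated assumptions, not the repo's metric code): on a 48×48×48 patch with roughly 0.9% foreground, a model that predicts background everywhere still gets a mean per-class Dice of about 0.5, because the near-perfect background Dice masks the zero foreground Dice.

```python
import numpy as np

def per_class_dice(y_true, y_pred, eps=1e-7):
    """Dice score per channel for one-hot masks of shape (..., C)."""
    spatial_axes = tuple(range(y_true.ndim - 1))
    intersection = np.sum(y_true * y_pred, axis=spatial_axes)
    denom = np.sum(y_true, axis=spatial_axes) + np.sum(y_pred, axis=spatial_axes)
    return (2.0 * intersection + eps) / (denom + eps)

rng = np.random.default_rng(0)

# Ground truth: 48^3 patch, ~0.9% foreground, one-hot channels [bg, fg].
fg = rng.random((48, 48, 48)) < 0.009
y_true = np.stack([1.0 - fg, fg.astype(float)], axis=-1)

# Degenerate prediction: background everywhere (no lesion detected).
y_pred = np.zeros_like(y_true)
y_pred[..., 0] = 1.0

dice = per_class_dice(y_true, y_pred)
print(dice)         # ~[0.995, 0.0]: background Dice is near perfect
print(dice.mean())  # ~0.5 despite the model finding no lesion at all
```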
Lesion-level results with default parameters (TP = true positives, FP = false positives, FN = false negatives, GT = number of lesions in the ground truth):
| No. | Loss | TP | FP | FN | GT |
|----:|------|---:|---:|---:|---:|
| 1 | Asymmetric_unified_loss | 31 | 580 | 26 | 57 |
| 2 | Symmetric_unified_loss | 23 | 183 | 34 | 57 |
| 3 | asymmetric_focal_tversky_loss | 29 | 358 | 28 | 57 |
| 4 | asymmetric_focal_loss | 0 | 1166 | 57 | 57 |
| 5 | symmetric_focal_tversky_loss | 0 | 0 | 57 | 57 |
| 6 | tversky_loss | 24 | 578 | 33 | 57 |
| 7 | combo_loss | 22 | 267 | 35 | 57 |
| 8 | focal_tversky_loss | 27 | 281 | 30 | 57 |
| 9 | focal_loss | 7 | 48 | 50 | 57 |
| 10 | symmetric_focal_loss | 0 | 923 | 57 | 57 |
| 11 | dice_loss | 31 | 382 | 26 | 57 |
[Image: the input image and the two-channel mask]
[Attachment: 3D U-Net (updated softmax output code)]
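The attached model code is not reproduced here, so the following is only a sketch of one common way a Keras 3D U-Net head changes when moving from sigmoid to softmax; `features` (the last decoder feature map) and the channel ordering [background, foreground] are assumptions, not the actual attachment.

```python
from tensorflow.keras import layers

def segmentation_head(features, use_softmax=True):
    if use_softmax:
        # Two output channels, assumed ordered [background, foreground],
        # normalised across the channel axis by the softmax.
        return layers.Conv3D(2, kernel_size=1, activation="softmax")(features)
    # Single foreground-probability channel.
    return layers.Conv3D(1, kernel_size=1, activation="sigmoid")(features)
```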
[Attachment: during inference]
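Again a sketch under stated assumptions (channels-last, two-channel softmax output), showing a sanity check worth running at this step: with two channels summing to one, argmax over channels and thresholding the foreground channel at 0.5 are the same decision rule, so any extra false positives must come from the predicted probabilities themselves rather than from how the mask is derived.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for model.predict on one patch: (1, 48, 48, 48, 2) softmax output.
logits = rng.normal(size=(1, 48, 48, 48, 2))
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)

# Argmax over the two channels is equivalent to thresholding the
# foreground channel at 0.5.
mask_argmax = np.argmax(probs, axis=-1).astype(np.uint8)
mask_thresh = (probs[..., 1] > 0.5).astype(np.uint8)
assert np.array_equal(mask_argmax, mask_thresh)
```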