hermancollin opened this issue 1 year ago
As a side note, SAM was trained with a compound loss: Focal + Dice + MSE. The MSE term is not computed on the mask itself, but between the predicted IoU and the actual IoU of the predicted mask. Using a loss function similar to the original one could be useful. From the SAM paper:
> We supervise mask prediction with a linear combination of focal loss [65] and dice loss [73] in a 20:1 ratio of focal loss to dice loss, following [20, 14]. [...] The IoU prediction head is trained with mean-square-error loss between the IoU prediction and the predicted mask's IoU with the ground truth mask. It is added to the mask loss with a constant scaling factor of 1.0.
This could be implemented with `monai.losses.DiceFocalLoss` with `lambda_focal=20`.
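For illustration, here is a minimal NumPy sketch of what that compound loss computes; in a real training loop one would use `monai.losses.DiceFocalLoss(lambda_focal=20)` on tensors plus an MSE term, but the arithmetic is the same. All function names here are mine, and the 20:1 ratio and 1.0 IoU scaling follow the SAM paper quote above.

```python
import numpy as np

def focal_loss(p, t, gamma=2.0, eps=1e-7):
    """Binary focal loss averaged over pixels (focusing parameter gamma)."""
    p = np.clip(p, eps, 1 - eps)
    pt = np.where(t == 1, p, 1 - p)  # probability assigned to the true class
    return float(np.mean(-((1 - pt) ** gamma) * np.log(pt)))

def dice_loss(p, t, eps=1e-7):
    """Soft Dice loss: 1 - 2|P∩T| / (|P| + |T|)."""
    inter = np.sum(p * t)
    return float(1 - (2 * inter + eps) / (np.sum(p) + np.sum(t) + eps))

def iou(mask, t):
    """IoU between a binarized prediction and the ground-truth mask."""
    inter = np.sum(mask * t)
    union = np.sum(np.maximum(mask, t))
    return float(inter / union) if union > 0 else 1.0

def sam_style_loss(p, t, predicted_iou):
    """20:1 focal:dice mask loss, plus MSE between predicted and actual IoU
    (added with a constant scaling factor of 1.0, as in SAM)."""
    actual_iou = iou((p > 0.5).astype(float), t)
    iou_mse = (predicted_iou - actual_iou) ** 2
    return 20.0 * focal_loss(p, t) + 1.0 * dice_loss(p, t) + 1.0 * iou_mse
```

The IoU MSE term only supervises the model's confidence estimate; the mask itself is driven by the focal and dice components.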
I would like to add a regularization term to the loss function to discourage the model from producing "glitchy" segmentations. For a perfect illustration, see the image below, taken from the validation set of https://github.com/brainhack-school2023/collin_project/tree/main (first iteration of this project).
I am not yet entirely sure how to regularize the myelin prediction, but I will update this issue later.