axondeepseg / sam_myelin_seg_TEM

Axon and myelin segmentation using FAIR's Segment Anything Model (SAM)
MIT License

Add regularization to loss #4

Open · hermancollin opened this issue 1 year ago

hermancollin commented 1 year ago

I would like to add some regularization to the loss function to make the model more robust and to discourage it from producing "glitchy" segmentations. For a perfect illustration, see the image below, taken from the validation set of https://github.com/brainhack-school2023/collin_project/tree/main (the first iteration of this project).

[Image: Screenshot_20230722_142152, a validation-set prediction showing the artifacts listed below]

  1. The pink axon has discontinuities.
  2. The brown axon is incomplete.

I am not yet entirely sure how to regularize the myelin prediction, but I will update this issue later.

hermancollin commented 1 year ago

As a side note, SAM was trained with a compound loss: Focal + Dice + MSE. The MSE term is not computed on the mask itself, but between the predicted IoU and the actual IoU of the predicted mask. Using a loss function similar to the original one could be useful.

We supervise mask prediction with a linear combination of focal loss [65] and dice loss [73] in a 20:1 ratio of focal loss to dice loss, following [20, 14]. [...] The IoU prediction head is trained with mean-square-error loss between the IoU prediction and the predicted mask’s IoU with the ground truth mask. It is added to the mask loss with a constant scaling factor of 1.0.
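To make the quoted recipe concrete, here is a minimal NumPy sketch of a SAM-style compound loss for a single binary mask: 20 × focal + 1 × dice + 1.0 × MSE between the predicted IoU and the actual IoU. All function names are hypothetical, and a few details are assumptions not stated in the quote: the focal term uses gamma = 2 with no alpha class weighting, the dice term is a soft dice on sigmoid probabilities, and the "actual" IoU is measured on the mask thresholded at logit 0.

```python
import numpy as np


def _sigmoid(logits):
    # Numerically stable sigmoid: 1 / (1 + exp(-x)) = exp(-log(1 + exp(-x))).
    return np.exp(-np.logaddexp(0.0, -logits))


def sigmoid_focal_loss(logits, target, gamma=2.0):
    """Mean focal loss over pixels (gamma=2, no alpha weighting: an assumption)."""
    # Stable binary cross-entropy computed directly from logits.
    bce = np.maximum(logits, 0.0) - logits * target + np.log1p(np.exp(-np.abs(logits)))
    p = _sigmoid(logits)
    # p_t is the predicted probability assigned to the true class of each pixel.
    p_t = p * target + (1.0 - p) * (1.0 - target)
    # (1 - p_t)^gamma down-weights easy, already well-classified pixels.
    return float(np.mean((1.0 - p_t) ** gamma * bce))


def soft_dice_loss(logits, target, eps=1e-6):
    """1 - soft Dice coefficient between sigmoid probabilities and the target."""
    p = _sigmoid(logits)
    intersection = np.sum(p * target)
    return float(1.0 - (2.0 * intersection + eps) / (np.sum(p) + np.sum(target) + eps))


def mask_iou(pred_mask, target):
    """IoU between two binary masks; two empty masks count as IoU = 1 (assumption)."""
    pred_mask = np.asarray(pred_mask, dtype=bool)
    target = np.asarray(target, dtype=bool)
    union = np.logical_or(pred_mask, target).sum()
    if union == 0:
        return 1.0
    return float(np.logical_and(pred_mask, target).sum() / union)


def sam_style_loss(mask_logits, target, pred_iou,
                   focal_weight=20.0, dice_weight=1.0, iou_weight=1.0):
    """Hypothetical compound loss: 20*focal + 1*dice + 1.0*MSE(IoU)."""
    focal = sigmoid_focal_loss(mask_logits, target)
    dice = soft_dice_loss(mask_logits, target)
    # Regression target for the IoU head: IoU of the thresholded prediction
    # with the ground truth (thresholding at logit 0 is an assumption).
    actual_iou = mask_iou(mask_logits > 0, target)
    iou_mse = (pred_iou - actual_iou) ** 2
    return focal_weight * focal + dice_weight * dice + iou_weight * iou_mse


target = np.array([[1, 1, 0], [0, 1, 0], [0, 0, 0]], dtype=float)
confident = np.where(target == 1, 8.0, -8.0)  # nearly perfect mask logits
print(sam_style_loss(confident, target, pred_iou=1.0))   # close to 0
print(sam_style_loss(-confident, target, pred_iou=1.0))  # large: every pixel is wrong
```

In a PyTorch training loop the actual IoU would be computed without gradients, so that only the IoU head learns to regress it.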

The mask part could be implemented with monai.losses.DiceFocalLoss by setting lambda_focal=20 (lambda_dice defaults to 1.0), which gives the 20:1 focal-to-dice ratio.