HiLab-git / SSL4MIS

Semi Supervised Learning for Medical Image Segmentation, a collection of literature reviews and code implementations.

DiceLoss uses 'sum' reduction, but CrossEntropyLoss uses 'mean' reduction #74

Open jwc-rad opened 1 year ago

jwc-rad commented 1 year ago

Thank you for the awesome repository!

I've noticed that torch.nn.CrossEntropyLoss is used for the cross-entropy loss and a custom loss from utils.losses is used for the Dice loss, as in the following lines: https://github.com/HiLab-git/SSL4MIS/blob/30e05d80d13e093f50b07e413467f796c039a86f/code/train_uncertainty_aware_mean_teacher_3D.py#L124-L125
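For context, this is roughly what I understand those two lines to instantiate (paraphrased, not copied verbatim; the `num_classes` value and the exact `DiceLoss` signature are assumptions on my part):

```python
from torch.nn import CrossEntropyLoss
from utils import losses

num_classes = 2  # assumed; whatever the training script actually uses

ce_loss = CrossEntropyLoss()              # default reduction='mean'
dice_loss = losses.DiceLoss(num_classes)  # custom Dice loss from utils/losses.py
```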

The Dice loss appears to use a 'sum' reduction: https://github.com/HiLab-git/SSL4MIS/blob/30e05d80d13e093f50b07e413467f796c039a86f/code/utils/losses.py#L169-L177 However, the default reduction for torch.nn.CrossEntropyLoss is 'mean', so the Dice loss ends up roughly H*W(*D) times larger than the CE loss. As a result, the direct mean of the two losses taken in the following line is not the balanced average it seems intended to be: https://github.com/HiLab-git/SSL4MIS/blob/30e05d80d13e093f50b07e413467f796c039a86f/code/train_uncertainty_aware_mean_teacher_3D.py#L171
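To make the scale gap concrete, here is a minimal, self-contained sketch with toy shapes. It uses CrossEntropyLoss under both reductions purely for illustration (it is not the repository's actual DiceLoss), but it shows how a sum-reduced term swamps a mean-reduced one when the two are averaged:

```python
import torch
import torch.nn as nn

# Toy 3D segmentation batch: batch 1, 2 classes, 16x16x16 volume (hypothetical shapes).
B, C, D, H, W = 1, 2, 16, 16, 16
logits = torch.randn(B, C, D, H, W)
target = torch.randint(0, C, (B, D, H, W))

ce_mean = nn.CrossEntropyLoss(reduction='mean')(logits, target)  # default behaviour
ce_sum = nn.CrossEntropyLoss(reduction='sum')(logits, target)    # summed over all voxels

# The ratio is B*D*H*W (= 4096 here), so 0.5 * (sum_reduced + mean_reduced)
# is dominated almost entirely by the sum-reduced term.
print((ce_sum / ce_mean).item())
```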

I am sure this has minimal effect on most of your SSL methods, because it effectively amounts to using Dice alone instead of Dice + CE for the supervised loss, but I still think it should be checked.