Open jwc-rad opened 1 year ago
Thank you for the awesome repository!
I've noticed that `torch.nn.CrossEntropyLoss` is used for the cross-entropy loss and a custom loss from `utils.losses` is used for the Dice loss, as shown here: https://github.com/HiLab-git/SSL4MIS/blob/30e05d80d13e093f50b07e413467f796c039a86f/code/train_uncertainty_aware_mean_teacher_3D.py#L124-L125
The Dice loss seems to use a 'sum' reduction: https://github.com/HiLab-git/SSL4MIS/blob/30e05d80d13e093f50b07e413467f796c039a86f/code/utils/losses.py#L169-L177
However, the default reduction for `torch.nn.CrossEntropyLoss` is 'mean', so the Dice loss would always be roughly H*W(*D) times larger than the CE loss. A direct mean of the two losses, as in the following line, would therefore not be the intended average: https://github.com/HiLab-git/SSL4MIS/blob/30e05d80d13e093f50b07e413467f796c039a86f/code/train_uncertainty_aware_mean_teacher_3D.py#L171
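To make the scale argument concrete, here is a minimal sketch in plain Python. It does not reproduce the repository's Dice implementation. It only illustrates the general point that the same per-voxel loss is N times larger under 'sum' reduction than under 'mean' reduction, where N is the number of voxels. The toy volume size, the probabilities, and the helper names `per_voxel_ce` and `reduce_loss` are all made up for illustration.

```python
import math


def per_voxel_ce(probs, labels):
    """Negative log-likelihood of the true class for each voxel."""
    return [-math.log(p[y]) for p, y in zip(probs, labels)]


def reduce_loss(values, reduction="mean"):
    """Collapse per-voxel losses with either 'sum' or 'mean' reduction."""
    total = sum(values)
    if reduction == "sum":
        return total
    if reduction == "mean":
        return total / len(values)
    raise ValueError(f"unknown reduction: {reduction!r}")


# Hypothetical toy volume: H*W*D = 4*4*4 = 64 voxels, two classes,
# every voxel predicted as class 0 with probability 0.8 (assumed values).
n_voxels = 4 * 4 * 4
probs = [(0.8, 0.2)] * n_voxels
labels = [0] * n_voxels

values = per_voxel_ce(probs, labels)
loss_mean = reduce_loss(values, "mean")
loss_sum = reduce_loss(values, "sum")

# The two reductions differ by a factor of the voxel count.
ratio = loss_sum / loss_mean

# A plain 0.5 * (a + b) average of a mean-reduced term and a
# sum-reduced term is dominated by the sum-reduced one.
combined = 0.5 * (loss_mean + loss_sum)
share_of_sum_term = 0.5 * loss_sum / combined

print(f"mean-reduced: {loss_mean:.4f}, sum-reduced: {loss_sum:.4f}")
print(f"ratio: {ratio:.1f}, share of sum term: {share_of_sum_term:.3f}")
```

If a loss term really is sum-reduced, dividing it by the number of voxels (or using a matching reduction on both terms) would put the two on the same scale before averaging.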
I expect this has minimal effect on most of your SSL methods, since the supervised loss would then effectively behave as Dice alone rather than Dice + CE, but I still think it should be checked.