MIC-DKFZ / nnUNet


Focal Loss Implementation Error #705

Closed: joeranbosma closed this issue 2 years ago

joeranbosma commented 3 years ago

When investigating training with different loss functions, I encountered a problem with the focal loss that resulted in no network training at all. I was able to pinpoint the issue to an implementation error in the FocalLossMultiClass class in nnunet.training.network_training.nnUNet_variants.loss_function.nnUNetTrainerV2_focalLoss.

Issue:
The line logits = logits.view(-1, num_classes) shuffles the logits in an unexpected way, and as a result, they no longer align with the ground truth labels. If helpful, I can include plots of the logits and the voxel-wise loss, which clearly show that the logits are not reshaped as intended.

Potential solution:
Changing two lines seems to resolve the unexpected reshuffling:

  1. Change the reshaping to logits = logits.view(num_classes, -1) in this line: https://github.com/MIC-DKFZ/nnUNet/blob/627c2473a0c3372e39c8eb27f28fdd82d46831df/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_focalLoss.py#L187
  2. Change the subsequent selection of logits to cls_label_input = logits[cls] in this line: https://github.com/MIC-DKFZ/nnUNet/blob/627c2473a0c3372e39c8eb27f28fdd82d46831df/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_focalLoss.py#L195 (a minimal sketch of the misalignment follows below)
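
To make the misalignment concrete, here is a minimal sketch (my own illustration, not code from the repo), assuming the logits are laid out class-major, i.e. with shape (num_classes, num_voxels), at the point of the view call, which is what the fix above implies:

```python
import torch

# Assumed class-major layout: row c holds the scores of class c for all voxels.
num_classes, num_voxels = 3, 4
logits = torch.arange(num_classes * num_voxels, dtype=torch.float32)
logits = logits.view(num_classes, num_voxels)

cls = 1

# Buggy: view(-1, num_classes) reinterprets the same memory row by row,
# interleaving voxels and classes, so column `cls` mixes scores from
# several classes.
buggy = logits.view(-1, num_classes)[:, cls]

# Fixed: view(num_classes, -1) preserves the class-major layout, so row
# `cls` really contains the per-voxel scores of class `cls`.
fixed = logits.view(num_classes, -1)[cls]

print(buggy)  # tensor([ 1.,  4.,  7., 10.]) -> mixes classes 0, 1 and 2
print(fixed)  # tensor([4., 5., 6., 7.])     -> all scores of class 1
```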
FabianIsensee commented 3 years ago

Hi @joeranbosma, thank you for finding this! I have never really used this loss implementation and therefore never noticed that something was off. Can you please compare the implementation with the one in https://github.com/JunMa11/SegLoss? If they are similar in terms of segmentation performance, you could do a PR so that your contribution is visible to everyone. I would be happy to accept it! Best, Fabian

joeranbosma commented 3 years ago

Hi @FabianIsensee, thanks for your reply and for linking the repo. Regarding the unexpected reshuffling, the approach in the linked repo is similar to what I proposed. As for segmentation performance, I am currently running one experiment with the patched version of the current implementation and another with the implementation from SegLoss. I will update this thread once these runs finish (i.e., in a couple of days).

The implementation you referenced seems like a more general form than the one currently in the nnUNet repo, so if performance is similar, it's probably best to adopt the implementation from SegLoss.

FabianIsensee commented 3 years ago

Thank you!

The implementation you referenced seems like a more general form than the one currently in the nnUNet repo, so if performance is similar, it's probably best to adopt the implementation from SegLoss.

Makes sense. If you want, you can do a PR and I would be happy to accept it. If you do, please make sure to reference the SegLoss repo so that Jun gets his credit as well.

joeranbosma commented 3 years ago

Hi @FabianIsensee. The evaluation took longer than expected but is finished now.

I considered two implementations:

  1. fixing the existing implementation (this does not support label smoothing);
  2. the implementation from SegLoss (this uses label smoothing of 1e-5; sketched below).
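
For context, here is a minimal sketch of a multi-class focal loss with optional label smoothing (my own illustration in the spirit of the SegLoss version, not the exact code under comparison; the function name and default values are assumptions):

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, target, gamma=2.0, alpha=0.25, smooth=1e-5):
    """Illustrative multi-class focal loss with optional label smoothing.

    logits: raw network output of shape (B, C, ...); target: integer
    labels of shape (B, ...). alpha is a scalar here for simplicity,
    although per-class weights are also common.
    """
    num_classes = logits.shape[1]
    probs = F.softmax(logits, dim=1)  # per-voxel class probabilities
    # One-hot encode and smooth the targets: smooth=0 gives the plain
    # focal loss, smooth=1e-5 matches the SegLoss default mentioned above.
    one_hot = F.one_hot(target.long(), num_classes).movedim(-1, 1).float()
    one_hot = one_hot * (1.0 - smooth) + smooth / num_classes
    # Probability assigned to the (smoothed) true class for each voxel.
    pt = (probs * one_hot).sum(dim=1).clamp_min(1e-8)
    # The (1 - pt)^gamma term down-weights well-classified voxels.
    return (-alpha * (1.0 - pt) ** gamma * torch.log(pt)).mean()
```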

The performance in terms of (estimated) average global DSC was very close across all five of my folds. I have included the progress.png plots, but the differences are difficult to see. In terms of AUROC/FROC performance, both implementations were also very close.

I have created a pull request for the version from SegLoss, because that implementation also supports label smoothing. In my experiments, the AUROC/FROC performance was slightly better without label smoothing; however, this is easy to adapt in the SegLoss implementation.
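
With the illustrative focal_loss sketched above (a hypothetical helper, not the merged code), disabling smoothing is just a matter of the smooth argument:

```python
# Hypothetical usage: smooth=0.0 disables label smoothing, which gave
# slightly better AUROC/FROC in these experiments.
loss = focal_loss(net_output, seg, gamma=2.0, alpha=0.25, smooth=0.0)
```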

[Attached image: combined progress.png plots for the five folds]

FabianIsensee commented 2 years ago

Thanks for the PR, I have merged it already ;-) Good work!

hreso110100 commented 2 years ago

Hi @FabianIsensee @joeranbosma, has the implementation of nnUNetTrainerV2_focalLoss for multi-class segmentation been fixed? My network is not learning at all, even though the loss is decreasing.