MIC-DKFZ / nnUNet

Apache License 2.0

Dice score validation and test discrepancy #2045

Closed: Peaceandmaths closed this issue 5 months ago

Peaceandmaths commented 7 months ago

Thank you for nnU-Net. I am training it on 1400 brain CT images to segment aneurysms (very small vascular malformations).

I have a problem: during training the pseudo-dice reaches about 85%, but when I test the model on the test set I only get a 45% Dice score. What could cause this discrepancy? Is there anything I can do to investigate it? Any advice would help 🙏


My hypotheses:

Thanks in advance !

TaWald commented 7 months ago

Hey @Peaceandmaths, there are a variety of reasons why this may happen.

As you mentioned, one reason could be the change of hospital between the train/val data and the test set. A few things can go wrong here:

  1. The clinician from the other clinic could annotate differently than your local annotator. If the labeling protocol is not well defined, your test set may have a slight ground-truth shift, which would degrade performance even if your model worked perfectly on it. (There is hardly anything to be done about this, as it is just a matter of inter-rater variability.)
  2. You could also have different image acquisition protocols (e.g., other scanners) that change the images and degrade your model's performance. This could be addressed with stronger augmentation. You should probably try the DA5 augmentation scheme, as it has proved useful in a similar case before.
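To get a feel for how large such a label shift is, one possible check (an assumption on my part, not something proposed in the thread) is to have a few cases annotated by both clinics and compute the Dice between the two annotations. A minimal NumPy sketch of the Dice coefficient, with toy masks standing in for real annotations:

```python
import numpy as np


def dice(a: np.ndarray, b: np.ndarray) -> float:
    """Dice coefficient between two binary masks: 2|A & B| / (|A| + |B|)."""
    a = a.astype(bool)
    b = b.astype(bool)
    denom = a.sum() + b.sum()
    if denom == 0:
        # Both masks empty: treated as perfect agreement here (a convention;
        # other evaluation tools may return NaN or skip such cases instead).
        return 1.0
    return float(2.0 * np.logical_and(a, b).sum() / denom)


# Toy example: two annotators outline the same small lesion, one slightly larger.
annotator_a = np.zeros((10, 10, 10), dtype=np.uint8)
annotator_b = np.zeros((10, 10, 10), dtype=np.uint8)
annotator_a[4:6, 4:6, 4:6] = 1   # 8 voxels
annotator_b[4:7, 4:6, 4:6] = 1   # 12 voxels, 8 of them shared with annotator_a

print(round(dice(annotator_a, annotator_b), 3))  # 2*8 / (8 + 12) = 0.8
```

For very small structures like aneurysms, a disagreement of a single voxel layer already costs a lot of Dice, which is why inter-rater variability can matter so much here.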

The shift between hospitals is not necessarily the only possible reason!


Q: Do you have a small subset of cases with annotations from the same clinic? What is the final validation performance on this data?

It may be the case that, if the lesions are small, nnU-Net's foreground sampling strategy makes the model learn that foreground is more likely in the center of the patch. During validation, however, cases are predicted in a tiled fashion, which no longer provides the centering the model gets during training.

To check for this, look at the model's predictions. Do you see a grid-style pattern in the outputs? If so, this is likely the reason for the observed performance drop, and you should adapt the way nnU-Net samples foreground patches during training.
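To see why the centering is lost at inference, here is a 1D sketch of sliding-window tiling with overlapping tiles spaced half a patch apart (an assumed step size) and the last tile flush with the image border. `tile_starts` and `offsets_from_tile_center` are hypothetical helpers written for illustration, not nnU-Net functions; they report where a lesion voxel lands relative to the center of each tile that contains it.

```python
import numpy as np


def tile_starts(image_len: int, patch_len: int, step_fraction: float = 0.5) -> list[int]:
    """Evenly spaced tile start positions along one axis (hypothetical helper).

    Consecutive tiles are at most step_fraction * patch_len apart, and the
    last tile ends exactly at the image border.
    """
    if image_len <= patch_len:
        return [0]
    max_start = image_len - patch_len
    n_steps = int(np.ceil(max_start / (patch_len * step_fraction))) + 1
    return [int(round(s)) for s in np.linspace(0, max_start, n_steps)]


def offsets_from_tile_center(lesion_pos: int, image_len: int, patch_len: int) -> list[int]:
    """Offset of a lesion voxel from the center of every tile that contains it."""
    offsets = []
    for start in tile_starts(image_len, patch_len):
        if start <= lesion_pos < start + patch_len:
            offsets.append(lesion_pos - (start + patch_len // 2))
    return offsets


print(tile_starts(200, 96))                    # [0, 35, 69, 104]
print(offsets_from_tile_center(70, 200, 96))   # [22, -13, -47]
```

With foreground-centered training crops the offset would always be close to 0; under tiling, the same lesion ends up 13 to 47 voxels away from the tile centers, depending only on where it happens to lie in the image.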

Peaceandmaths commented 7 months ago

Hi again,

I checked again and noticed that the validation Dice and the test Dice are consistent (45%). I ran the prediction and evaluation on the training data itself and got 47% Dice, not 85%! If I understand correctly, the pseudo-dice is not representative and is not supposed to be comparable with the val/test Dice because of the 5-fold structure. We're using the training data for both training and validation, so it's not really unseen data. What is the pseudo-dice really capturing?

A: I have annotations for all my images, including the test images. The final validation Dice is 45%. I don't see a grid-style pattern in the outputs. The problem is that sometimes extra aneurysms are predicted in places where they should not be (outside the brain), or aneurysms are not detected, or they are falsely detected.
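Since the errors described here are per lesion (spurious aneurysms, missed ones), a lesion-wise count can say more than a single voxel-level Dice. A hedged sketch using `scipy.ndimage.label` for connected components; `lesion_counts` is a hypothetical helper, and the "any overlap counts as a detection" rule is an assumption for illustration (stricter matching criteria exist):

```python
import numpy as np
from scipy import ndimage


def lesion_counts(pred: np.ndarray, gt: np.ndarray) -> dict[str, int]:
    """Count detected, missed and spurious lesions via connected components.

    A ground-truth lesion counts as detected if any predicted voxel overlaps it;
    a predicted component counts as a false positive if it touches no GT voxel.
    """
    pred = pred.astype(bool)
    gt = gt.astype(bool)
    gt_labels, n_gt = ndimage.label(gt)
    pred_labels, n_pred = ndimage.label(pred)

    detected = sum(1 for i in range(1, n_gt + 1) if pred[gt_labels == i].any())
    false_pos = sum(1 for j in range(1, n_pred + 1) if not gt[pred_labels == j].any())
    return {"gt_lesions": int(n_gt), "detected": detected,
            "missed": int(n_gt) - detected, "false_positive_components": false_pos}


gt = np.zeros((20, 20, 20), dtype=np.uint8)
gt[2:4, 2:4, 2:4] = 1          # lesion 1
gt[10:12, 10:12, 10:12] = 1    # lesion 2
pred = np.zeros_like(gt)
pred[2:4, 2:4, 2:5] = 1        # overlaps lesion 1
pred[16:18, 16:18, 16:18] = 1  # spurious blob, e.g. outside the brain

print(lesion_counts(pred, gt))
# {'gt_lesions': 2, 'detected': 1, 'missed': 1, 'false_positive_components': 1}
```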

Q: My current hypothesis for this low score is that the default nnU-Net architecture would need to be adapted for this particular problem. As you mentioned, maybe the foreground sampling strategy needs to be adapted? How would I go about adapting the way nnU-Net samples foreground patches during training? Can you give me a short step-by-step guide or a link showing how to do it?

Thanks !

TaWald commented 7 months ago

> If I understand correctly, the pseudo-dice is not representative and is not supposed to be comparable with the val/test Dice because of the 5-fold structure. We're using the training data for both training and validation, so it's not really unseen data. What is the pseudo-dice really capturing?

The main difference between the pseudo-dice and the final Dice is the context in which each is computed: the pseudo-dice is computed on crops sampled during the training process, whereas the final Dice is computed on entire cases, which are predicted by tiling them into crops.
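A toy illustration of why the two numbers can diverge: a Dice computed on a patch centered on the lesion never sees false positives elsewhere in the case. The helpers are hypothetical, the case is 1D for readability, and the sketch ignores how nnU-Net actually aggregates the pseudo-dice across patches and epochs.

```python
import numpy as np


def tp_fp_fn(pred: np.ndarray, gt: np.ndarray) -> tuple[int, int, int]:
    """True positive, false positive and false negative voxel counts."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    tp = int(np.logical_and(pred, gt).sum())
    fp = int(np.logical_and(pred, ~gt).sum())
    fn = int(np.logical_and(~pred, gt).sum())
    return tp, fp, fn


def dice_from_counts(tp: int, fp: int, fn: int) -> float:
    return 2 * tp / (2 * tp + fp + fn) if (tp + fp + fn) > 0 else float("nan")


# One toy case: a 4-voxel lesion predicted correctly, plus a 4-voxel
# false positive far away from it.
gt = np.zeros(200, dtype=bool)
pred = np.zeros(200, dtype=bool)
gt[100:104] = True
pred[100:104] = True
pred[10:14] = True  # spurious prediction, e.g. outside the brain

# Pseudo-dice style: evaluate only a 48-voxel patch centered on the lesion.
patch = slice(102 - 24, 102 + 24)
print(dice_from_counts(*tp_fp_fn(pred[patch], gt[patch])))  # 1.0

# Final-Dice style: evaluate the entire case.
print(dice_from_counts(*tp_fp_fn(pred, gt)))  # 2*4 / (2*4 + 4 + 0) ≈ 0.667
```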

For you, the most important difference between the two is probably nnU-Net's foreground oversampling and the way it is used in the training pipeline. Here is nnU-Net's sampling process during training, to highlight the issue:

```
# Training loop (simplified)
if sample_foreground:
    # 1. Choose a random voxel of a random foreground class
    # 2. Create a crop of the desired patch size that is *centered* on that voxel
    ...
else:
    # Choose a random crop from anywhere else in the patient
    ...
```
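The centering in step 2 can be made concrete with a small simulation. This is a simplified sketch, not nnU-Net's loader: `centered_crop_lower_corner` is a hypothetical helper that clips the crop to stay inside the image and ignores padding. It measures how far a tiny lesion ever gets from the patch center across many foreground samples.

```python
import numpy as np

rng = np.random.default_rng(seed=0)


def centered_crop_lower_corner(fg_voxels, image_shape, patch_size):
    """Pick a random foreground voxel and return the lower corner of a crop
    centered on it (clipped so the crop stays inside the image)."""
    voxel = fg_voxels[rng.integers(len(fg_voxels))]
    return [int(min(max(0, v - p // 2), s - p))
            for v, p, s in zip(voxel, patch_size, image_shape)]


image_shape, patch_size = (128, 128, 128), (64, 64, 64)
mask = np.zeros(image_shape, dtype=bool)
mask[60:63, 70:73, 50:53] = True          # a tiny 3x3x3 "aneurysm"
fg_voxels = np.argwhere(mask)             # (27, 3) foreground coordinates
lesion_center = fg_voxels.mean(axis=0)    # (61, 71, 51)

# Over many foreground samples, how far is the lesion from the patch center?
max_offset = 0.0
for _ in range(1000):
    lb = np.array(centered_crop_lower_corner(fg_voxels, image_shape, patch_size))
    patch_center = lb + np.array(patch_size) // 2
    max_offset = max(max_offset, float(np.abs(lesion_center - patch_center).max()))
print(max_offset)  # 1.0: the lesion always sits within one voxel of the center
```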

The issue is that your model learns that the object of interest is located in the center of the crop, because whenever a foreground crop is sampled, it is centered on the object. The smaller your instances, the stronger this bias. If your model learns to rely on this bias, it will fail at test time, because cases are then predicted with the tiled approach, which takes away the crutch your model relies on.

My recommendation: rework nnU-Net's foreground sampling and introduce a random shift of the center point by some margin, ensuring the foreground voxel is still inside the crop but no longer always close to the center.

You can start modifying the code right after nnU-Net has selected the foreground class to be oversampled: https://github.com/MIC-DKFZ/nnUNet/blob/c7f85b729145d8b8ebf1786aa87c45398f7fbc76/nnunetv2/training/dataloading/base_data_loader.py#L124

```python
if voxels_of_that_class is not None and len(voxels_of_that_class) > 0:
    selected_voxel = voxels_of_that_class[np.random.choice(len(voxels_of_that_class))]
    # selected voxel is center voxel. Subtract half the patch size to get lower bbox voxel.
    # Make sure it is within the bounds of lb and ub
    # i + 1 because we have first dimension 0!

    # YOUR CODE -- Insert some logic here to move the bounding box (used for cropping) in some way

    bbox_lbs = [max(lbs[i], selected_voxel[i + 1] - self.patch_size[i] // 2) for i in range(dim)]
```
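A minimal sketch of the random shift recommended above, written as a standalone function rather than a patch to nnU-Net. It assumes the loader's `lbs`, `ubs` and `patch_size`, and a `selected_voxel` without its leading channel index (which is why the excerpt uses `selected_voxel[i + 1]`). `shifted_bbox_lbs` and `max_shift_fraction` are hypothetical names. Unlike the excerpt, it also clips against `ubs`; the voxel is guaranteed to stay inside the crop only if `lbs <= 0` and `ubs >= image_size - patch_size`, which is an assumption here.

```python
import numpy as np

rng = np.random.default_rng(seed=0)


def shifted_bbox_lbs(selected_voxel, patch_size, lbs, ubs, max_shift_fraction=0.4):
    """Lower bbox corner with the selected voxel randomly displaced from the center.

    max_shift_fraction (hypothetical parameter): how far, as a fraction of half
    the patch size, the voxel may move away from the patch center.
    """
    bbox_lbs = []
    for v, p, lo, hi in zip(selected_voxel, patch_size, lbs, ubs):
        max_shift = int(max_shift_fraction * (p // 2))
        shift = int(rng.integers(-max_shift, max_shift + 1))  # inclusive range
        # The centered corner would be v - p // 2; the shift moves the crop window.
        lb = v - p // 2 + shift
        bbox_lbs.append(int(min(max(lb, lo), hi)))
    return bbox_lbs


selected_voxel, patch_size = (61, 71, 51), (64, 64, 64)
lbs, ubs = (0, 0, 0), (64, 64, 64)   # 128^3 image, crop kept inside it
positions = np.array([
    np.array(selected_voxel) - np.array(shifted_bbox_lbs(selected_voxel, patch_size, lbs, ubs))
    for _ in range(1000)
])
print(positions.min(), positions.max())  # 20 44: no longer always at index 32
```

Inside the loader, the `bbox_lbs = [...]` line could then be replaced by a call like `shifted_bbox_lbs(selected_voxel[1:], self.patch_size, lbs, ubs)`. The shift range is a tuning knob; whether it improves the final Dice on a given dataset would have to be tested.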

BEFORE YOU DO THIS, THOUGH: check your predictions visually! Q: Do you see a grid pattern in the predictions? Q: If yes, does it occur often? Q: Do all cases have similar Dice scores, or do some fail catastrophically?

If the scores are similar across cases, do the above.
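For the last question, once you have per-case Dice values (for example from the cross-validation results; extracting them from nnU-Net's output files is left out here), comparing mean, median and the number of near-zero cases quickly shows whether a few catastrophic failures drag the average down. A hedged sketch with made-up numbers; `summarize_case_dice` and the 0.1 threshold are hypothetical choices for illustration:

```python
import statistics


def summarize_case_dice(per_case: dict[str, float], fail_threshold: float = 0.1) -> dict:
    """Summarize per-case Dice and flag cases below fail_threshold."""
    values = list(per_case.values())
    failures = sorted(case for case, d in per_case.items() if d < fail_threshold)
    return {
        "mean": statistics.mean(values),
        "median": statistics.median(values),
        "failures": failures,
    }


# Made-up values: the mean hides that two cases fail completely.
example = {"case_001": 0.82, "case_002": 0.79, "case_003": 0.0,
           "case_004": 0.85, "case_005": 0.02}
print(summarize_case_dice(example))
```

Here the mean is about 0.50 while the median is 0.79, which points to a few failing cases rather than a uniform drop.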


In your case there may also be other issues, but I think the sampling strategy is the likeliest hypothesis.

Chasel-Chen commented 6 months ago

@Peaceandmaths Hello, I've encountered a similar issue. Did you resolve it in the end?

TaWald commented 6 months ago

If any of you is interested in creating a pull request, e.g., a "SmallInstanceTrainer" that uses a different sampling strategy, it would be greatly appreciated 😄 🚀

Peaceandmaths commented 6 months ago

@Chasel-Chen From my understanding, the pseudo-dice score should not be taken as the expected final validation or test Dice. Instead, look at the results of your cross-validation and check the summary there to see roughly which Dice score to expect for your final test performance. In my case they are actually quite close (47% validation, 45% testing). Hope that helps.

TaWald commented 6 months ago

Also, as previously mentioned, the inflated pseudo-dice numbers can be fixed (and your overall model performance may even improve) by changing the foreground sampling strategy for the patches. If the small instances are no longer always centered, your model's pseudo-dice will likely decrease, and your model may even learn the task better.

Chasel-Chen commented 6 months ago

@Peaceandmaths @TaWald Thank you both for sharing. I'll try adjusting the sampling strategy for my own dataset. If I find any effective methods, I'll share them with everyone. Best wishes!