MIC-DKFZ / nnUNet

Apache License 2.0

Inference fails dramatically but unpredictably when a class is missing or small #2365

Status: open. Opened by pwrightkcl 4 months ago.

pwrightkcl commented 4 months ago

I have a multi-class segmentation model that seems to have trained on the data but fails dramatically on a subset of test cases. I am hoping someone can recognise the failure mode and help me fix it.

I am training nnUNet to segment four classes. One class is spatially heterogeneous and not present in every image. The training curves look reasonable, and inference on my unseen evaluation set works on about half the set, even for some images that are quite difficult for a human. On the other half of the set it fails dramatically, with classes 2-4 covering most of the image, seemingly just following an intensity split.

The failure cases are mostly those where the variable class is absent or small. The working cases, however, include some images with missing or small class 4 as well. (The residual encoder plan does better at this than the vanilla plan but both fail similarly.) Clearly the model has learnt something. I am wondering if there is some way to recover the failure cases, either at training or inference.
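One way to check the pattern described above (failures concentrating in cases where class 4 is absent or small) is to tabulate, for each evaluation case, the reference class 4 volume next to a crude failure score. The sketch below is hypothetical: the function names, the "ratio of predicted to reference voxels in classes 2-4" score and the threshold of 5 are my own assumptions, not anything nnUNet provides. It uses the label values listed below (2 = calcification, 3 = ventricular CSF, 4 = ventricular blood) and assumes the ground truth and prediction have already been loaded as integer NumPy arrays of the same shape.

```python
import numpy as np

# Label values as listed in this issue (assumed integer-coded label maps).
VENTRICULAR_BLOOD = 4
SUSPECT_LABELS = (2, 3, 4)  # the classes that "look weird" in failed cases


def summarise_case(gt: np.ndarray, pred: np.ndarray) -> dict:
    """Per-case numbers for comparing working and failed predictions."""
    gt_class4_voxels = int(np.count_nonzero(gt == VENTRICULAR_BLOOD))
    gt_suspect = int(np.count_nonzero(np.isin(gt, SUSPECT_LABELS)))
    pred_suspect = int(np.count_nonzero(np.isin(pred, SUSPECT_LABELS)))
    # Predicted / reference voxel count for classes 2-4; a failure where these
    # classes cover most of the image gives a very large ratio.
    ratio = pred_suspect / max(gt_suspect, 1)
    return {"gt_class4_voxels": gt_class4_voxels, "classes2to4_ratio": ratio}


def flag_failure(summary: dict, max_ratio: float = 5.0) -> bool:
    # max_ratio is an arbitrary, hypothetical threshold; tune it on real data.
    return summary["classes2to4_ratio"] > max_ratio


# Toy example: a tiny ventricle (class 3) containing one voxel of blood (class 4).
gt = np.zeros((4, 10, 10), dtype=np.uint8)
gt[1, 2:4, 2:4] = 3
gt[1, 2, 2] = 4

failed = gt.copy()
failed[2:, :, :] = 3  # simulated failure: class 3 floods two whole slices

s = summarise_case(gt, failed)
print(s, flag_failure(s))
```

Sorting the evaluation cases by `gt_class4_voxels` and looking at where the flagged cases fall would show whether the "absent or small class 4" pattern holds across the whole set, rather than only in the examples inspected by eye.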

For more detail on my use case, I am segmenting CT brain images with intracerebral haemorrhage into four classes:

  1. parenchymal haemorrhage (blue)
  2. calcification (green)
  3. ventricular CSF (yellow)
  4. ventricular blood (red)

Label 4, ventricular blood, is inherently quite visually heterogeneous and does not occur in all cases.

Here is an example showing the GT labels and the prediction for a failure case. There are just a few tiny bits of ventricular blood. Note that class 1 (haemorrhage, in blue) is predicted correctly. This is the same for all failures: only classes 2-4 end up looking weird.

[Image: sub-0d0ba99_gt (ground-truth labels)]

[Image: sub-0d0ba99_pred (prediction)]

I have looked at the image intensities for the test images, as I'm told this failure mode suggests an issue with intensity scaling. They all have appropriate Hounsfield unit (HU) values and do not vary systematically between successful and failed inference cases. As is normal for CT, the background values are -1024 HU. My dataset.json specifies the CT modality, and I know nnUNet scales accordingly. I am not sure how scaling is applied at inference.
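On the question of how scaling is applied at inference: as far as I understand nnUNet v2's CT scheme (the `CTNormalization` class), intensities are clipped to the 0.5th and 99.5th percentiles of the foreground (labelled) voxels of the training set and then z-scored with the mean and standard deviation of those same voxels. The statistics are computed once during dataset fingerprinting and stored in the plans file, so every test image is normalised with the same fixed numbers rather than its own. The sketch below is a re-implementation for illustration, not nnUNet's code; the function name `ct_normalize` and the statistics are made up.

```python
import numpy as np


def ct_normalize(image: np.ndarray, props: dict) -> np.ndarray:
    """Sketch of nnUNet-style CT normalisation with fixed dataset statistics.

    `props` stands in for the per-channel foreground intensity statistics
    computed once from the training set; the same values are reused for
    every test case.
    """
    out = image.astype(np.float32, copy=True)
    # Clip to the training-set foreground percentiles (background at -1024 HU
    # simply ends up at the lower bound).
    np.clip(out, props["percentile_00_5"], props["percentile_99_5"], out=out)
    # Z-score with the training-set foreground mean and standard deviation.
    out -= props["mean"]
    out /= max(props["std"], 1e-8)
    return out


# Hypothetical statistics, invented for illustration (not from this dataset).
props = {"mean": 40.0, "std": 20.0, "percentile_00_5": 0.0, "percentile_99_5": 100.0}

# Example voxels: background, CSF-like, brain-like, blood-like, bone.
ct = np.array([-1024.0, 0.0, 40.0, 80.0, 1500.0])
print(ct_normalize(ct, props))  # [-2. -2.  0.  2.  3.]
```

Because nothing is rescaled per image, two scans with matching HU values receive identical network inputs. That makes comparing the test-set HU ranges with the training data, as done above, the relevant check for this step. The actual numbers used can be read from the plans file of the trained model (in v2, I believe under `foreground_intensity_properties_per_channel`).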

I would welcome any advice and am happy to dig into logs or rerun stuff with verbose / debug logging, if someone can advise me where to look.

gy-xinchen commented 4 months ago

I think I have a similar problem. I used the residual encoder plan with BN and the data augmentation from the paper "Studying Robustness of Semantic Segmentation Under Domain Shift in Cardiac MRI". Compared with the default nnUNet, segmentation failed on some samples.

[Images: improved_nnUNet, default_nnUNet_failed, improve_nnUNet_failed]

bruniss commented 1 day ago

Try these with region-based training. You may be doing this already, but it appears you have a class assigned to a region that in other slices would reasonably be segmented as a single class, which I think confuses the model. If you marked the area with the blood as blood AND the other class, and trained the model on that region, this may work better.
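One possible reading of this suggestion, sketched under assumptions: define an overlapping "ventricles" region that contains both ventricular CSF (3) and ventricular blood (4), so blood inside the ventricles is also learnt as ventricle. My understanding of nnUNet v2's region-based training is that regions are declared in dataset.json by giving a label a list of integer values, plus a `regions_class_order` entry saying which label each region writes back at export; check the region-based training documentation before relying on this. The region names and the two helper functions are hypothetical and only illustrate the label-to-region conversion that nnUNet performs internally.

```python
import json

import numpy as np

# Hypothetical dataset.json fragment. "ventricles" merges CSF (3) and blood (4).
labels = {
    "background": 0,
    "haemorrhage": 1,
    "calcification": 2,
    "ventricles": [3, 4],
    "ventricular_blood": 4,
}
# One entry per region (background excluded): the integer label each region
# writes back; later regions overwrite earlier ones where they overlap.
regions_class_order = [1, 2, 3, 4]

print(json.dumps({"labels": labels, "regions_class_order": regions_class_order}, indent=2))


def to_region_masks(seg: np.ndarray) -> np.ndarray:
    """Integer label map -> one binary channel per region (illustration only)."""
    regions = [v for k, v in labels.items() if k != "background"]
    return np.stack([np.isin(seg, np.atleast_1d(r)) for r in regions])


def from_region_masks(masks: np.ndarray) -> np.ndarray:
    """Binary region channels -> integer label map, following regions_class_order."""
    seg = np.zeros(masks.shape[1:], dtype=np.uint8)
    for channel, label in zip(masks, regions_class_order):
        seg[channel] = label
    return seg


seg = np.array([0, 1, 2, 3, 4, 3])
masks = to_region_masks(seg)
print(masks[2])  # ventricles channel: True for both CSF and blood voxels
print(from_region_masks(masks))  # recovers [0 1 2 3 4 3]
```

With this setup, a case with little or no ventricular blood still gives the network a well-defined ventricle target, and the blood channel only has to separate blood within that region.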