MIC-DKFZ / nnUNet

Apache License 2.0

2D nnUNet predicting only one class #2263

Open JackRio opened 3 weeks ago

JackRio commented 3 weeks ago

I am training a 2D nnUNet on medical imaging data with 6 class labels. The dataset consists of CT scans, and the labels are segmentation masks for calcium in different arteries. I cross-checked my data and labels in the preprocessed directory and ran a debugger to confirm that the train_step function does receive data containing labels from the other classes during training. I cannot figure out what exactly is happening, because this is what the pseudo-dice looks like. I trained one fold for all 1000 epochs and the model is still only learning class 4. I am aware the pseudo-dice is computed on patches rather than the entire dataset, but the other classes stay at 0 throughout.

This is the same for all epochs (I just picked a few to show the training):
......
2024-06-02 07:05:42.874473: Epoch 339
2024-06-02 07:05:43.096008: Current learning rate: 0.00689
2024-06-02 07:12:19.457671: train_loss -0.5111
2024-06-02 07:12:19.919749: val_loss -0.508
2024-06-02 07:12:19.938606: Pseudo dice [0.0, 0.0, 0.0, 0.8057, 0.0, 0.0]
2024-06-02 07:12:20.241984: Epoch time: 396.83 s
2024-06-02 07:12:22.293276:
2024-06-02 07:12:22.452150: Epoch 340
2024-06-02 07:12:22.472649: Current learning rate: 0.00688
2024-06-02 07:18:45.405543: train_loss -0.4954
2024-06-02 07:18:45.831223: val_loss -0.5073
2024-06-02 07:18:45.978944: Pseudo dice [0.0, 0.0, 0.0, 0.8139, 0.0, 0.0]
2024-06-02 07:18:46.109472: Epoch time: 383.11 s
2024-06-02 07:18:47.965937:
2024-06-02 07:18:47.997068: Epoch 341
2024-06-02 07:18:48.051786: Current learning rate: 0.00687
2024-06-02 07:25:12.848920: train_loss -0.5149
2024-06-02 07:25:13.089326: val_loss -0.4665
2024-06-02 07:25:13.369660: Pseudo dice [0.0, 0.0, 0.0, 0.8218, 0.0, 0.0]
2024-06-02 07:25:13.466279: Epoch time: 384.88 s
2024-06-02 07:25:15.081511:
2024-06-02 07:25:15.148523: Epoch 342
2024-06-02 07:25:15.368651: Current learning rate: 0.00686
2024-06-02 07:32:09.941649: train_loss -0.4962
2024-06-02 07:32:10.280238: val_loss -0.4776
2024-06-02 07:32:10.298611: Pseudo dice [0.0, 0.0, 0.0, 0.7984, 0.0, 0.0]
2024-06-02 07:32:10.340392: Epoch time: 414.86 s
........
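For reference, the per-class pseudo-dice values can be pulled out of a log like the one above programmatically, which makes it easy to spot classes that never leave 0. A minimal sketch (the log line format shown above is assumed):

```python
import re

def parse_pseudo_dice(log_text):
    """Extract the per-class pseudo-dice list from each epoch of an nnUNet training log."""
    pattern = re.compile(r"Pseudo dice \[([^\]]+)\]")
    return [[float(v) for v in m.group(1).split(",")] for m in pattern.finditer(log_text)]

log = """2024-06-02 07:12:19.938606: Pseudo dice [0.0, 0.0, 0.0, 0.8057, 0.0, 0.0]
2024-06-02 07:18:45.978944: Pseudo dice [0.0, 0.0, 0.0, 0.8139, 0.0, 0.0]"""

for epoch_dice in parse_pseudo_dice(log):
    # Classes whose pseudo-dice stays at 0 are the ones the model ignores.
    dead = [i for i, d in enumerate(epoch_dice) if d == 0.0]
    print(dead, epoch_dice)
```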

Let me know if more information is needed. I will debug a bit more and see if there is anything wrong somewhere. Note: I am using nnUNet out of the box for now and haven't changed anything.
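One quick way to rule out a label problem is to count the voxels per class across the preprocessed segmentations. A sketch over plain NumPy arrays (how the segmentation of each case is loaded depends on the nnUNet version and its preprocessed file layout, so the toy arrays here are placeholders):

```python
import numpy as np
from collections import Counter

def class_voxel_counts(segmentations, num_classes):
    """Sum voxel counts per label value over a list of segmentation arrays."""
    counts = Counter()
    for seg in segmentations:
        labels, n = np.unique(seg, return_counts=True)
        counts.update(dict(zip(labels.astype(int), n.astype(int))))
    return [counts.get(c, 0) for c in range(num_classes)]

# Toy example with two tiny masks; in practice, load each case's
# segmentation array from the preprocessed dataset instead.
seg_a = np.array([[0, 0, 4], [4, 4, 1]])
seg_b = np.array([[0, 4, 4], [2, 0, 0]])
print(class_voxel_counts([seg_a, seg_b], num_classes=7))
```

A heavily skewed distribution here (one class dwarfing the rest) would be consistent with the pseudo-dice pattern in the log.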

JackRio commented 3 weeks ago

I think this has to do with the number of voxels available for training per class. Class 4 has the highest number of voxels in my dataset, so maybe that is why it is learned best. I also realized the patch size was too big for the model to learn any foreground class, so I reduced it from 384x512 to 128x128. At first the model started to learn all classes, but after about 50 epochs it collapsed back to predicting only class 4.

Any suggestions on what I should do in this case?
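If the voxel counts confirm a strong imbalance, one common mitigation is to weight the cross-entropy term of the loss by inverse class frequency. A hedged sketch of computing such weights (this is not nnUNet's built-in behavior; using them would require a custom trainer, and the counts below are hypothetical):

```python
import numpy as np

def inverse_frequency_weights(voxel_counts, smoothing=1.0):
    """Per-class weights inversely proportional to voxel frequency, normalized to mean 1."""
    counts = np.asarray(voxel_counts, dtype=float) + smoothing  # smoothing avoids div-by-zero
    weights = counts.sum() / counts
    return weights / weights.mean()

# Hypothetical counts with class 4 dominating, as suspected for this dataset.
counts = [500_000, 2_000, 1_500, 1_000, 800_000, 3_000, 2_500]
w = inverse_frequency_weights(counts)
print(np.round(w, 2))  # rare classes receive much larger weights than classes 0 and 4
```

The resulting vector could then be passed to e.g. `torch.nn.CrossEntropyLoss(weight=...)` inside a custom loss.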