MIC-DKFZ / nnUNet


How does nnUNet performance scale with more training data #1171

Closed charlesmoatti closed 1 year ago

charlesmoatti commented 2 years ago

Just a small remark/question, since I did not find it in previous issues. This is what I have observed so far while working with nnUNet.

Giving more training data to an nnUNet model, e.g. going from 80 to 800 3D training images, should not have much impact, given the strategy of a fixed number of batches per epoch (250) and the multitude of data augmentation techniques used. The fact that nnUNet is not guaranteed to see every image during training, but instead randomly samples a case for each patch, reinforces this thought for me. I think nnUNet reaches its performance plateau quickly with a small number of samples (say in the ~50 3D-image range, depending on the exact task), and that past this point adding more data is useless (at least for performance metrics such as Dice score).
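To illustrate what I mean, here is a minimal sketch of the fixed-iteration epoch with random case sampling (not nnUNet's actual code; `sample_patch` and `train_step` are hypothetical stand-ins):

```python
import random

NUM_BATCHES_PER_EPOCH = 250  # nnUNet default
BATCH_SIZE = 2               # typical for 3D full-resolution configurations

def run_epoch(dataset, sample_patch, train_step):
    """One epoch = a fixed number of batches, regardless of dataset size."""
    for _ in range(NUM_BATCHES_PER_EPOCH):
        # Cases are drawn at random with replacement: nothing guarantees
        # that every image is seen, whether there are 80 or 800 volumes.
        cases = random.choices(dataset, k=BATCH_SIZE)
        # Each case contributes one randomly cropped, augmented patch.
        batch = [sample_patch(case) for case in cases]
        train_step(batch)
```

Going from 80 to 800 images only changes the pool being sampled from, not the number of gradient updates per epoch.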

Please correct me if I am wrong and any thoughts are welcome.

Any thoughts on how nnUNet performance/robustness would scale with a lot more training data? Is the performance plateau of nnUNet reached quickly with relatively few training cases?

thijsgelton commented 2 years ago

I do not think this is an nnUNet-specific thing. It could be that in your case the development dataset perfectly resembles the dataset you are testing on (no shift). Additionally, if you have an easy problem, you do not require much data. If your dataset contains a lot of noisy labels or poses a generally difficult problem, adding more data should definitely improve results. This is all regardless of nnUNet.

Joeycho commented 2 years ago

I observed a similar phenomenon, and I dealt with only 11-12 3D images. I think the random patch sampling in the dataloader could be one reason for the high performance with relatively few training images. Especially if we have a few high-quality (large enough) 3D images, we can get quite a nice final set of patches after patch cropping and data augmentation. We can even increase the 250 batches per epoch to 500, 750, and so on; if these few images are representative of the whole data distribution (which only exists ideally), this increase might be helpful.
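For example, assuming nnUNet v1's trainer layout, a small custom trainer could raise the number of batches per epoch (the subclass name is made up; check the attribute in your installed version before relying on it):

```python
from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2

class nnUNetTrainerV2_moreBatches(nnUNetTrainerV2):
    """Hypothetical trainer variant with more batches per epoch."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_batches_per_epoch = 500  # default is 250
```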

So, I agree that increasing the amount of data is sometimes not helpful if you can already achieve similar performance with few training images. Plus, if the GT labels are noisy, it becomes hard to verify which label is actually better than the others (predicted labels vs. noisy GT labels). So I think increasing the amount of data makes sense once performance with few training images is clearly poor.

FabianIsensee commented 2 years ago

It's all about covering the natural distribution of your problem.

If your target is a regular structure, like an organ, it will always look the same and be located at the same place in the image. You don't need many training cases for that and increasing your dataset size won't improve the results.

If your target is a regular structure like an organ, but you want to be able to handle all the different CT variants and scanners there are in the world, you want the algorithm to work just as well for people of Asian, African or Caucasian origin, and you also want it to be robust with respect to diseases, then you need more training data to cover all these aspects.

If you are interested in sporadically appearing small tumors that can crop up in any body region and are really hard to see you need copious amounts of training data covering all the different locations and contexts in which the target structure can be found.

So really, the number of training cases alone tells you nothing. It's all about the required diversity of the training dataset. If the required diversity is low, few cases suffice. If it's high, you need more.

Do you have enough? Only one way to find out: measure the inter-rater variability on your data. If nnU-Net is just as good, then you have enough cases. If not, then you need more. Or a better segmentation method :-)
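A quick sketch of that comparison (the masks here are random stand-ins; in practice you would load two raters' segmentations and the model's predictions for the same cases):

```python
import numpy as np

def dice(a: np.ndarray, b: np.ndarray) -> float:
    """Dice coefficient between two binary masks."""
    a, b = a.astype(bool), b.astype(bool)
    denom = a.sum() + b.sum()
    return 2.0 * np.logical_and(a, b).sum() / denom if denom > 0 else 1.0

# Hypothetical stand-ins for the same cases segmented by two human raters
# and by the trained model; replace with masks loaded from disk.
rng = np.random.default_rng(0)
rater1 = [rng.integers(0, 2, (64, 64, 64)) for _ in range(5)]
rater2 = [m.copy() for m in rater1]
model_preds = [m.copy() for m in rater1]

inter_rater = np.mean([dice(a, b) for a, b in zip(rater1, rater2)])
model_score = np.mean([dice(p, r) for p, r in zip(model_preds, rater1)])
print(f"inter-rater Dice: {inter_rater:.3f}  model Dice: {model_score:.3f}")
# If the model matches inter-rater agreement, more data is unlikely to help.
```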