Project-MONAI / research-contributions

Implementations of recent research prototypes/demonstrations using MONAI.
https://monai.io/
Apache License 2.0

UNETR: Coping with limited data #32

Closed: bwittmann closed this issue 2 years ago

bwittmann commented 2 years ago

Hello, first of all, thank you very much for your great work!

I have a question about how you coped with having only very limited training data available (spleen segmentation: 41 CT scans). Transformer-based architectures like ViT, and detectors like DETR, have been shown to perform well only when a huge amount of labeled data is available (for DETR, a lower bound of roughly 15k 2D images to train from scratch), and they are known to converge very slowly. I would therefore expect a 3D transformer-based architecture like UNETR to be even more data hungry and prone to overfitting, given the slow convergence and the limited data.

So my question is basically: what are, in your opinion, the key factors behind the success of your approach with limited data? Is the random sampling of 96x96x96 crops the main factor that tackles this issue? Wouldn't the performance increase if you skipped random sampling and instead used the whole CT scan as input, so that attention has complete global information?

Furthermore, I would be interested in why your transformer encoder converges so quickly (10h) compared to the original ViT.

I would be very happy if you could answer my questions. BR Bastian

overbestfitting commented 2 years ago

I think they did a 10-model ensemble and produced the best accuracy.

overbestfitting commented 2 years ago

I think you raised a very good point. According to the paper, when they trained with only 30 cases for the standard challenge, the model reached around 0.85 accuracy. When they simply increased the training data to 80 cases, UNETR reached 0.89. Of course, in all cases they used a 10-model ensemble at inference. I am more curious about their Swin UNETR results, though.

bwittmann commented 2 years ago

@overbestfitting Thanks for your message! So do you think that the ensemble inference is what allows us to cope with a dataset this limited? In the code I didn't find the part referring to ensembles so far. I would think that an ensemble inference only leads to a slight improvement.
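For readers unfamiliar with the term, ensemble inference can be sketched as follows. The thread does not say how the individual models' predictions were combined, so this minimal example assumes the common choice of averaging per-class probabilities and then taking the argmax per voxel. The function names and the toy numbers are hypothetical and are not taken from the UNETR code.

```python
from typing import List, Sequence


def ensemble_average(
    prob_maps: Sequence[Sequence[Sequence[float]]],
) -> List[List[float]]:
    """Average class probabilities over models.

    prob_maps[m][v][c] is the probability that model m assigns
    to class c at (flattened) voxel v.
    """
    n_models = len(prob_maps)
    n_voxels = len(prob_maps[0])
    n_classes = len(prob_maps[0][0])
    return [
        [
            sum(prob_maps[m][v][c] for m in range(n_models)) / n_models
            for c in range(n_classes)
        ]
        for v in range(n_voxels)
    ]


def argmax_labels(mean_probs: Sequence[Sequence[float]]) -> List[int]:
    """Pick the most probable class for every voxel."""
    return [max(range(len(p)), key=p.__getitem__) for p in mean_probs]


# Toy example (assumed values): 3 models, 2 voxels, 2 classes.
model_a = [[0.9, 0.1], [0.4, 0.6]]
model_b = [[0.8, 0.2], [0.6, 0.4]]
model_c = [[0.7, 0.3], [0.7, 0.3]]

mean_probs = ensemble_average([model_a, model_b, model_c])
labels = argmax_labels(mean_probs)
```

In this toy case model_a alone would assign class 1 to the second voxel, while the averaged prediction assigns class 0. Averaging reduces the variance of individual models, which is compatible with the expectation that it yields a modest rather than a decisive gain.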

overbestfitting commented 2 years ago

I am not sure! But I would guess their additional 50 patients dominate the accuracy, according to the UNETR paper.

ahatamiz commented 2 years ago

Hi @bwittmann,

Thanks for the insightful comments. I'll try to address each item below:

1) I believe the use of a Conv-based decoder, as in UNETR, introduces desirable properties such as inductive image bias that allow the model to cope with different dataset sizes. For instance, BTCV is a small dataset, yet the performance is very competitive.

2) You are right in the sense that using the entire CT increases the performance. However, it also increases memory consumption significantly, which hinders the training process. For this reason, randomly cropped samples are used. Even with the cropped input, UNETR with its ViT encoder has a larger receptive field than CNNs with limited kernel sizes (e.g. 3 x 3 x 3). Hence, it is still beneficial to use a ViT-based encoder for feature extraction. I'd like to mention that memory-aware ViT for 3D medical image analysis seems to be a very nice research topic that requires further work.

3) Similar to 1, the Conv-based decoder plays an integral role in faster convergence. However, pre-training the standalone ViT still requires lots of data, which is expected.
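
To illustrate the memory argument in point 2, here is a rough back-of-the-envelope sketch. It assumes a ViT patch size of 16 per axis (not stated in this thread) and a hypothetical whole-scan size of 512 x 512 x 128 voxels; it counts tokens and the entries of one self-attention matrix, and shows one simple way a random crop origin could be drawn. The helper names are made up for this example and are not the MONAI API.

```python
import random
from typing import Tuple

Shape = Tuple[int, int, int]


def random_crop_origin(volume: Shape, crop: Shape, rng: random.Random) -> Shape:
    """Draw the corner of a crop that fits fully inside the volume."""
    if any(c > v for v, c in zip(volume, crop)):
        raise ValueError("crop must not exceed the volume along any axis")
    # randint is inclusive on both ends, so v - c is the last valid origin.
    z, y, x = (rng.randint(0, v - c) for v, c in zip(volume, crop))
    return (z, y, x)


def num_tokens(input_shape: Shape, patch_size: int = 16) -> int:
    """Number of non-overlapping cubic patches (tokens) in the input."""
    if any(s % patch_size for s in input_shape):
        raise ValueError("each axis must be divisible by the patch size")
    tokens = 1
    for s in input_shape:
        tokens *= s // patch_size
    return tokens


def attention_entries(tokens: int) -> int:
    """Entries of one token-by-token attention matrix (per head, per layer)."""
    return tokens * tokens


crop_shape: Shape = (96, 96, 96)
scan_shape: Shape = (512, 512, 128)  # hypothetical whole-scan size

origin = random_crop_origin(scan_shape, crop_shape, random.Random(0))
crop_tokens = num_tokens(crop_shape)   # 6 * 6 * 6 = 216
scan_tokens = num_tokens(scan_shape)   # 32 * 32 * 8 = 8192
ratio = attention_entries(scan_tokens) / attention_entries(crop_tokens)
```

Under these assumptions the crop yields 216 tokens and the whole scan 8192, so a single attention matrix grows by a factor of roughly 1400. This only counts attention entries; activations in the rest of the network grow as well.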

I hope I was able to answer some of these questions.

Kind Regards

ahatamiz commented 2 years ago

Hi @overbestfitting,

Thanks for the comment. As with all previous state-of-the-art models for BTCV, the use of additional data is important, and we followed the same trend when submitting to the leaderboard. However, our models (i.e. UNETR or Swin UNETR) demonstrate state-of-the-art performance even on our limited internal dataset, compared to other approaches such as nnUnet. Hence, the success of our work is not dependent on the use of extra data.

Kind Regards