Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

ViT pre-trained for 3D images #3947

Closed: phcerdan closed this issue 2 years ago

phcerdan commented 2 years ago

Is your feature request related to a problem? Please describe. After reading the original paper on vision transformers (ViT, link below), I see that they seem to excel when trained on large datasets. That makes sense, because they have to learn the structure of the image from scratch (which patches are neighbors of which other patches, etc.).
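To make the "which patches are neighbors" point concrete for 3D, here is a minimal sketch of how a volume is cut into non-overlapping cubic patches and flattened into a token sequence. The helpers `patchify_3d` and `token_index` are hypothetical, and the sizes (a 96³ single-channel volume, 16³ patches) are illustrative assumptions; this is not MONAI's patch embedding code. The point is that two patches touching along depth end up far apart in the sequence, so the model only recovers that adjacency from positional embeddings and training data.

```python
import torch


def patchify_3d(volume: torch.Tensor, patch: int) -> torch.Tensor:
    """Split a (C, D, H, W) volume into non-overlapping cubic patches.

    Returns a (num_patches, C * patch**3) tensor with one flattened token per
    patch, ordered by depth first, then height, then width.
    """
    c, d, h, w = volume.shape
    assert d % patch == 0 and h % patch == 0 and w % patch == 0, "size must be divisible by patch"
    gd, gh, gw = d // patch, h // patch, w // patch
    x = volume.reshape(c, gd, patch, gh, patch, gw, patch)
    # Move the grid axes to the front and keep each patch's (C, p, p, p) block together.
    x = x.permute(1, 3, 5, 0, 2, 4, 6)
    return x.reshape(gd * gh * gw, c * patch ** 3)


def token_index(i: int, j: int, k: int, grid: tuple) -> int:
    """Sequence position of the patch at grid coordinates (i, j, k)."""
    _, gh, gw = grid
    return i * gh * gw + j * gw + k


vol = torch.randn(1, 96, 96, 96)  # toy single-channel volume
tokens = patchify_3d(vol, patch=16)
grid = (6, 6, 6)  # 96 / 16 = 6 patches per axis
print(tokens.shape)  # torch.Size([216, 4096])
# Patches that touch along depth are 36 positions apart in the sequence.
print(token_index(1, 0, 0, grid) - token_index(0, 0, 0, grid))  # 36
```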

Describe the solution you'd like I would like to find out whether there is any pre-trained ViT for 3D images and, if so, how it can be reused in MONAI.

Describe alternatives you've considered I have searched the web for this same question, without much luck. This https://github.com/lucidrains/vit-pytorch/issues/125 suggests that a pre-trained 2D ViT could be adapted to 3D, but I assume that implementation differs from MONAI's. Any hints on how to do this so it can be reused in MONAI?

Additional context Original paper on ViT, for reference: https://arxiv.org/abs/2010.11929

EDIT: pinging @ahatamiz as the implementer of Swin UNETR (thanks!)

phcerdan commented 2 years ago

https://github.com/Project-MONAI/tutorials/tree/master/self_supervised_pretraining

provides a ViTAutoEnc model pre-trained on 3D CT images. This should do!

The only missing piece is how to transfer these weights to UNETR, which uses the regular ViT. I will post the solution as soon as I find it.
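One way to do that transfer, sketched under assumptions: copy every tensor whose parameter name and shape match between the pre-trained autoencoder and the ViT encoder, and skip the rest (for example the autoencoder's decoder). The helper `load_matching_weights` and the toy modules below are hypothetical stand-ins, not MONAI classes; they only mimic the idea that both models share encoder parameter names such as `patch_embedding.*` and `blocks.*`.

```python
import torch
from torch import nn


def load_matching_weights(target: nn.Module, source_state: dict) -> list:
    """Copy every tensor from `source_state` whose name and shape match `target`.

    Keys missing from `target` (e.g. an autoencoder's decoder) or with a
    different shape are skipped. Returns the sorted list of copied keys.
    """
    target_state = target.state_dict()
    matched = {
        k: v
        for k, v in source_state.items()
        if k in target_state and target_state[k].shape == v.shape
    }
    target_state.update(matched)
    target.load_state_dict(target_state)
    return sorted(matched)


# Hypothetical toy stand-ins for a pre-trained autoencoder and a ViT encoder.
class ToyAutoEnc(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embedding = nn.Linear(8, 4)
        self.blocks = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
        self.decoder = nn.Linear(4, 8)  # has no counterpart in the encoder


class ToyViT(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embedding = nn.Linear(8, 4)
        self.blocks = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
        self.norm = nn.LayerNorm(4)  # keeps its freshly initialized values


pretrained, vit = ToyAutoEnc(), ToyViT()
copied = load_matching_weights(vit, pretrained.state_dict())
print(copied)
```

For MONAI itself, the analogous call would target the transformer encoder inside `UNETR` (as far as I know, its `vit` attribute) and pass the state dict loaded from the tutorial's checkpoint; whether that dict sits under a `"state_dict"` key depends on how the checkpoint was saved, so inspect the file first. Matching on shape avoids load errors, but it also silently skips tensors such as positional embeddings when `img_size` or `patch_size` differ between pre-training and fine-tuning, so check the returned list.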

Nic-Ma commented 2 years ago

Thanks for sharing. CC @ahatamiz

ahatamiz commented 2 years ago

Hi @phcerdan

Thanks for the comments. ViTs that are pre-trained on a large corpus of data perform better. Self-supervised learning creates an opportunity to boost model performance further without requiring annotated datasets, which are sometimes hard to obtain in the medical domain.

As you have pointed out, this tutorial is dedicated specifically to self-supervised learning of the UNETR backbone, which is a 3D ViT model. A 2D ViT model needs further modifications to its patch embedding layer to extend it to 3D. However, model performance may suffer because the model has no access to 3D spatial context during pre-training.
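A minimal sketch of one common way to extend a 2D patch embedding to 3D, "inflation" as used in I3D (Carreira and Zisserman), assuming the 2D embedding is a `Conv2d` with kernel size equal to stride, no padding, and `groups=1`. The helper `inflate_patch_conv` is hypothetical. The 2D kernel is repeated along the new depth axis and divided by the depth, so a volume that is constant along depth yields the same features as the original 2D layer. Positional embeddings would also need to be extended (not shown), and, as noted above, this does not give the model any 3D context it never saw during pre-training.

```python
import torch
from torch import nn


def inflate_patch_conv(conv2d: nn.Conv2d, depth: int) -> nn.Conv3d:
    """Build a Conv3d patch embedding from a pre-trained Conv2d (I3D-style inflation).

    Assumes no padding and groups=1. The 2D kernel is repeated `depth` times
    along the new axis and scaled by 1/depth.
    """
    kh, kw = conv2d.kernel_size
    sh, sw = conv2d.stride
    conv3d = nn.Conv3d(
        conv2d.in_channels,
        conv2d.out_channels,
        kernel_size=(depth, kh, kw),
        stride=(depth, sh, sw),
        bias=conv2d.bias is not None,
    )
    with torch.no_grad():
        # (out, in, kh, kw) -> (out, in, depth, kh, kw), averaged over depth.
        w = conv2d.weight.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth
        conv3d.weight.copy_(w)
        if conv2d.bias is not None:
            conv3d.bias.copy_(conv2d.bias)
    return conv3d


conv2d = nn.Conv2d(3, 8, kernel_size=4, stride=4)  # toy stand-in for a pre-trained 2D embedding
conv3d = inflate_patch_conv(conv2d, depth=4)
img = torch.randn(1, 3, 16, 16)
vol = img.unsqueeze(2).repeat(1, 1, 8, 1, 1)  # depth-constant volume with D = 8
out2d = conv2d(img)  # shape (1, 8, 4, 4)
out3d = conv3d(vol)  # shape (1, 8, 2, 4, 4)
print(torch.allclose(out3d[:, :, 0], out2d, atol=1e-5))  # True
```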

Regarding your second question, once the 3D ViT is pre-trained, you can use the fine-tuning script for a downstream segmentation task on the BTCV dataset (similar to the UNETR tutorial). As indicated here, we achieve a considerable relative improvement in Dice score across varying numbers of training samples, which validates the effectiveness of the pre-training pipeline.

Lastly, I will be working on integrating Swin UNETR into MONAI core very soon. Please stay tuned.

fleecedragoon commented 1 year ago

Hi, are the pre-trained ViT weights available without having to run the training script? Thank you!