Project-MONAI / research-contributions

Implementations of recent research prototypes/demonstrations using MONAI.
https://monai.io/
Apache License 2.0

How to use SwinTransformer pre-train ckpt (Supervised pre-trained on ImageNet-1k) #141

Closed: jessie-chen99 closed this issue 1 year ago

jessie-chen99 commented 2 years ago

Hi! Thanks for your great work! When I use this Swin UNETR model (https://github.com/Project-MONAI/MONAI/blob/1.0.0/monai/networks/nets/swin_unetr.py), there is a checkpoint size-mismatch problem: I cannot load swin_tiny_patch4_window7_224.pth into the self.swinViT part, because embed_dim in SwinTransformer is 96/128/..., while in Swin UNETR it is 48 (called feature_size). Could you tell me how to load this kind of pretrained checkpoint? Only when using Swin UNETR with feature_size=96 can I load swin_tiny_patch4_window7_224.pth.
So, does that mean: "When using Swin UNETR with feature_size=48 I can only train from scratch"? Thanks!

I use this SwinUNETR model in the recommended way, and the dataset I use is BraTS2019 (3D input). Following the docstring example for a 3D 4-channel input with size (128,128,128), 3-channel output, and (2,4,2,2) layers in each stage:

    net = SwinUNETR(img_size=(128,128,128), in_channels=4, out_channels=3, depths=(2,4,2,2))
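For context, a minimal sketch of how the width mismatch shows up (the local path is assumed, and the "model" key handling assumes the checkpoint nests its weights the way the official Swin Transformer releases usually do); it compares the patch-embedding shapes of the 2D ImageNet checkpoint against the 3D SwinUNETR encoder:

    import torch
    from monai.networks.nets import SwinUNETR

    # Assumed local path to the supervised ImageNet-1k Swin-Tiny checkpoint
    ckpt = torch.load("swin_tiny_patch4_window7_224.pth", map_location="cpu")
    ckpt = ckpt.get("model", ckpt)  # official Swin releases usually nest weights under "model"

    net = SwinUNETR(img_size=(128, 128, 128), in_channels=4, out_channels=3,
                    depths=(2, 4, 2, 2), feature_size=48)

    # Swin-Tiny uses embed_dim=96, so its patch-embedding tensors are 96-wide ...
    print({k: tuple(v.shape) for k, v in ckpt.items() if "patch_embed" in k})
    # ... while the SwinUNETR encoder (self.swinViT) is feature_size=48 wide
    print({k: tuple(v.shape) for k, v in net.state_dict().items() if "patch_embed" in k})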

Screenshots: (screenshot of the size-mismatch error when loading the checkpoint)


tangy5 commented 2 years ago

Hi @jessie-chen99, thanks for your interest in the work.

There should be a way to load the Swin UNETR encoder: see these models, https://github.com/Project-MONAI/research-contributions/tree/main/SwinUNETR/BTCV#models. The Swin UNETR/Base checkpoint should be the one with feature_size=48.

Given a 3D model with feature_size=48:

    # prepare the 3D model (args.roi_x/roi_y/roi_z and NUM_CLASS are user-defined:
    # the ROI/patch size per axis and the number of output classes)
    model = SwinUNETR(
        img_size=(args.roi_x, args.roi_y, args.roi_z),
        in_channels=1,
        out_channels=NUM_CLASS,
        feature_size=48,
    )

The following code should be able to load that base Swin ViT checkpoint:

    # Load pre-trained weights (requires `import torch`; pretrained_path is the downloaded checkpoint)
    store_dict = model.state_dict()
    model_dict = torch.load(pretrained_path)["state_dict"]
    for key in model_dict.keys():
        # copy everything except the output head ("out" keys), which depends on out_channels
        if "out" not in key:
            store_dict[key] = model_dict[key]

    model.load_state_dict(store_dict)
    print("Use pretrained weights")
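Not part of the original snippet, but a possibly useful variant: since the BTCV checkpoint was trained with in_channels=1 while a BraTS model typically uses in_channels=4, a shape-aware copy loop (a sketch under those assumptions) skips any tensor that does not fit instead of raising a size-mismatch error:

    # Shape-aware variant (sketch): only copy tensors whose shapes match the target model,
    # which also skips the patch embedding when in_channels differs (1 for BTCV vs. 4 for BraTS)
    store_dict = model.state_dict()
    model_dict = torch.load(pretrained_path)["state_dict"]
    loaded, skipped = [], []
    for key, value in model_dict.items():
        if key in store_dict and store_dict[key].shape == value.shape:
            store_dict[key] = value
            loaded.append(key)
        else:
            skipped.append(key)

    model.load_state_dict(store_dict)
    print(f"Loaded {len(loaded)} tensors, skipped {len(skipped)} (missing key or shape mismatch)")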

This should at least work with a 96x96x96 patch size. Also, please note the pre-trained weights were trained on CT scans, so there will be a domain gap since the BraTS data are MRIs. It would be better to pretrain on your own MRI scans or train from scratch. Thanks
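As a quick illustration (not from the original reply), the following sketch checks that a feature_size=48 SwinUNETR runs on 96x96x96 patches with the 4-channel input / 3-channel output setting from the question:

    import torch
    from monai.networks.nets import SwinUNETR

    # BraTS-style channels from the question, at the suggested 96x96x96 patch size
    net = SwinUNETR(img_size=(96, 96, 96), in_channels=4, out_channels=3, feature_size=48)
    x = torch.randn(1, 4, 96, 96, 96)  # dummy 4-channel MRI patch
    with torch.no_grad():
        y = net(x)
    print(y.shape)  # expected: torch.Size([1, 3, 96, 96, 96])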