Luffy03 / VoCo

[CVPR 2024] VoCo: A Simple-yet-Effective Volume Contrastive Learning Framework for 3D Medical Image Analysis
Apache License 2.0

About the Pre-trained weight #7

Closed · zzzjjj98 closed this issue 21 hours ago

zzzjjj98 commented 4 months ago

Hello, thanks for your great work! Currently the publicly available pre-trained weights were trained on CT data with a single input channel, but the downstream BraTS21 task uses 4-modality data with 4-channel inputs. Loading them with the "Load Pre-trained weight" code you provided raises a parameter-mismatch error. How can the CT pre-trained weights be applied to MRI data? Thank you!

Luffy03 commented 4 months ago

Hi, many thanks for your kind attention to our work! For MRI training, please refer to these two closed issues: https://github.com/Luffy03/VoCo/issues/2 and https://github.com/Luffy03/VoCo/issues/3. Feel free to raise further concerns here!

zzzjjj98 commented 4 months ago

> Hi, many thanks for your kind attention to our work! For MRI training, please refer to these two closed issues: #2 and #3. Feel free to raise further concerns here!

Thank you for your answer. I would like to test the segmentation performance (without fine-tuning) by applying the VoCo_10k.pt you provided directly to 4-modality MRI data. The code to load the weights is as follows:

```python
model_dict = torch.load('./pretrained_models/VoCo_10k.pt', map_location=torch.device('cpu'))
state_dict = model_dict
state_dict = delete_patch_embed(state_dict)
if "module." in list(state_dict.keys())[0]:
    print("Tag 'module.' found in state dict - fixing!")
    for key in list(state_dict.keys()):
        state_dict[key.replace("module.", "")] = state_dict.pop(key)
if "swin_vit" in list(state_dict.keys())[0]:
    print("Tag 'swin_vit' found in state dict - fixing!")
    for key in list(state_dict.keys()):
        state_dict[key.replace("swin_vit", "swinViT")] = state_dict.pop(key)
model.load_state_dict(state_dict, strict=False)
print("Using pretrained voco ema self-supervised Swin UNETR backbone weights !")
```

After using `delete_patch_embed`, I still get the following error:

```
RuntimeError: Error(s) in loading state_dict for SwinUNETR:
    size mismatch for encoder1.layer.conv1.conv.weight: copying a param with shape torch.Size([48, 1, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([48, 4, 3, 3, 3]).
    size mismatch for encoder1.layer.conv3.conv.weight: copying a param with shape torch.Size([48, 1, 1, 1, 1]) from checkpoint, the shape in current model is torch.Size([48, 4, 1, 1, 1]).
```

Luffy03 commented 4 months ago

Many thanks for pointing out our mistake! The previous version of the delete function was written for SwinUNETR V1 and our 1.6k pre-training. I have updated it here: https://github.com/Luffy03/VoCo/blob/5cb1297095f0fc55dc81b4e40640e4ee80e2c4e0/Finetune/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer_swin.py#L153
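The idea behind the fix is to drop checkpoint entries whose shapes no longer match the current model (here, the 1-channel `encoder1` convolutions) so that they are re-initialized, while the rest of the backbone still loads with `strict=False`. A minimal sketch of that filtering step; `drop_mismatched_keys` is a hypothetical helper name, not the repo's exact implementation:

```python
import torch

def drop_mismatched_keys(state_dict, model):
    """Keep only checkpoint entries whose shapes match the current model.

    Entries such as encoder1.layer.conv1.conv.weight (1-channel in the CT
    checkpoint, 4-channel in a BraTS model) are skipped and re-initialized.
    """
    model_state = model.state_dict()
    filtered = {}
    for key, value in state_dict.items():
        if key in model_state and model_state[key].shape == value.shape:
            filtered[key] = value
        else:
            expected = tuple(model_state[key].shape) if key in model_state else "not in model"
            print(f"Skipping {key}: checkpoint {tuple(value.shape)} vs model {expected}")
    return filtered

# Usage, continuing from the loading code above:
# state_dict = drop_mismatched_keys(state_dict, model)
# model.load_state_dict(state_dict, strict=False)
```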

zzzjjj98 commented 4 months ago

> Many thanks for pointing out our mistake! The previous version of the delete function was written for SwinUNETR V1 and our 1.6k pre-training. I have updated it here: https://github.com/Luffy03/VoCo/blob/5cb1297095f0fc55dc81b4e40640e4ee80e2c4e0/Finetune/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer_swin.py#L153

Thanks for the answer, the code now runs through. Much appreciated!

zzzjjj98 commented 3 months ago

Hello, thanks for your great work! I tried to pre-train VoCo on the multimodal BraTS dataset with the train_transforms shown in the attached screenshot, but it seems that only one modality (one channel) is loaded. What processing is needed to use multimodal MRI for pre-training?

[screenshot of train_transforms]

Luffy03 commented 3 months ago

Sorry for my late reply. To pre-train on MRI data, you need to change 'in_channels' to 4 if you have 4 sequences (modalities), as in https://github.com/Luffy03/VoCo/blob/94ed426bec328b7b9b5ddcf25b43fc14f27672ab/voco_train.py#L131

We have also collected large-scale MRI data for pre-training, but we have not yet achieved good results with it. If you have further problems, feel free to raise them here.
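For reference, a minimal sketch of what a 4-channel setup can look like, assuming MONAI-style components (a SwinUNETR backbone and dictionary transforms); the exact constructor and key names in voco_train.py may differ, and the modality keys below are placeholders:

```python
from monai.networks.nets import SwinUNETR
from monai.transforms import (
    Compose, ConcatItemsd, EnsureChannelFirstd, LoadImaged, NormalizeIntensityd,
)

# Backbone with 4 input channels for BraTS-style multi-modal MRI
# (out_channels is just a placeholder head size here).
model = SwinUNETR(
    img_size=(96, 96, 96),
    in_channels=4,
    out_channels=3,
    feature_size=48,
)

# Load the four BraTS sequences and stack them into a single 4-channel "image".
train_transforms = Compose([
    LoadImaged(keys=["t1", "t1ce", "t2", "flair"]),
    EnsureChannelFirstd(keys=["t1", "t1ce", "t2", "flair"]),
    ConcatItemsd(keys=["t1", "t1ce", "t2", "flair"], name="image", dim=0),
    NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
])
```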