Project-MONAI / research-contributions

Implementations of recent research prototypes/demonstrations using MONAI.
https://monai.io/
Apache License 2.0

Swin-UNETR pretraining model weights do not load into the model #346

Status: Open. Saqeeb95 opened this issue 11 months ago

Saqeeb95 commented 11 months ago

Describe the bug
The provided pre-trained Swin-UNETR weights do not load into a newly instantiated SSLHead model object. The naming scheme for the model state_dict keys differs between the provided weights and the instantiated SSLHead. Even after renaming the dict keys, some layers still mismatch: the weights file names its MLP sublayers fc1/fc2, whereas the instantiated model expects them to be named linear1/linear2. I think the architecture of SSLHead in ssl_head.py differs from the architecture of the model that produced the provided model_swinvit.pt file.
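
A minimal way to surface the mismatch is to diff the key sets directly. This is just a sketch: the checkpoint path, the "state_dict" wrapper key, and `model` (an SSLHead built as in main.py) are assumptions, not verified repo API:

```python
import torch

def diff_state_dict_keys(model, ckpt_path="./models/weights/model_swinvit.pt"):
    """Print which state_dict keys differ between `model` and a checkpoint."""
    ckpt = torch.load(ckpt_path, map_location="cpu")
    # Some checkpoints wrap the weights under a "state_dict" key; unwrap if so.
    state_dict = ckpt.get("state_dict", ckpt)
    ckpt_keys = set(state_dict)
    model_keys = set(model.state_dict())
    print("expected by model, missing from file:", sorted(model_keys - ckpt_keys)[:5])
    print("in file, unexpected by model:", sorted(ckpt_keys - model_keys)[:5])
```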

To Reproduce
Steps to reproduce the behavior:

  1. Clone the repo from here
  2. Download the model_swinvit.pt file from here
  3. Place the model_swinvit.pt file in the /models/weights/ directory inside the /SwinUNETR/Pretrain/ directory (create the dirs if needed)
  4. Set up the datasets as per the instructions in the same repo
  5. cd into the /SwinUNETR/Pretrain/ directory and run the following command: python main.py --use_checkpoint --batch_size=3 --num_steps=450 --lrdecay --eval_num=100 --logdir="test8" --lr=0.000004 --roi_x=96 --roi_y=96 --roi_z=96 --lr_schedule="poly" --noamp --epochs=15000 --resume="./models/weights/model_swinvit.pt"

Expected behavior
I expected the model to train when given the --resume argument pointing to the model_swinvit.pt file obtained from the Pre-trained Models section of the Swin-UNETR pretraining page.

Screenshots
Here is the mismatched-keys error I get when trying to load the weights into a newly initialized SSLHead model: (screenshot: full key-mismatch error). It is hard to read, but the gist is that the naming scheme differs between the provided weights and the initialized SSLHead model. There were also more "Unexpected" weights listed that I couldn't fit into the screenshot.

I noticed that many of the layers just had the "swinViT" part replaced with "module", so I tried renaming the dict keys to match as well as I could; a sketch of that attempt follows below. Here is a screenshot of the remaining mismatched keys afterwards: (screenshot: remaining mismatched keys). The remaining mismatches seem to be due to the fact that the SSLHead model expects two linear layers per encoder level (e.g. "swinViT.layers1.0.blocks.0.mlp.linear1.weight"), whereas the provided weights have two fully connected layers instead (e.g. "swinViT.layers1.0.blocks.0.mlp.fc1.weight"). The last mismatches are "swinViT.norm.weight" and "swinViT.norm.bias", which don't appear to be needed.
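
Roughly, the renaming looked like this (a minimal sketch; `model` stands for the SSLHead instance from main.py, and the checkpoint path and wrapper key are assumptions):

```python
import torch

ckpt = torch.load("./models/weights/model_swinvit.pt", map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)  # unwrap if nested

# Rename checkpoint keys to the scheme the instantiated SSLHead expects.
renamed = {}
for key, value in state_dict.items():
    new_key = key.replace("module.", "swinViT.")         # prefix difference
    new_key = new_key.replace("mlp.fc1", "mlp.linear1")  # fc* vs linear* naming
    new_key = new_key.replace("mlp.fc2", "mlp.linear2")
    renamed[new_key] = value

# strict=False reports, rather than raises on, keys that still don't line up
# (e.g. swinViT.norm.weight / swinViT.norm.bias).
result = model.load_state_dict(renamed, strict=False)
print("still missing:", result.missing_keys)
print("still unexpected:", result.unexpected_keys)
```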

Additional context
The model can train from scratch just fine, and I can also load checkpoints from my own previous runs. I just can't load the state_dict from the provided weights file.

prateekgrover-in commented 5 months ago

Hi Saqeeb, were you able to find a solution to this issue? I'm facing the same problem!

Saqeeb95 commented 5 months ago

Unfortunately not. I tried renaming the dict keys in the loaded weights to match the ones the model expects, but some of them still don't line up. I think the provided weights are actually for a slightly different architecture, judging by the differences between the expected dict keys and the ones in the weights file.

prateekgrover-in commented 5 months ago

> Unfortunately not. I tried renaming the dict keys in the loaded weights to match the ones the model expects, but some of them still don't line up. I think the provided weights are actually for a slightly different architecture, judging by the differences between the expected dict keys and the ones in the weights file.

Yeah, it copies 126 of the 159 weights from the file for me. With those weights missing, I think it's very hard to fine-tune. Or were you able to find another way?
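
I counted the overlap with something like this (illustrative only; `renamed` and `model` are from the renaming sketch earlier in the thread):

```python
# Keep only tensors whose key and shape both match the model, then load those.
model_sd = model.state_dict()
matched = {k: v for k, v in renamed.items()
           if k in model_sd and v.shape == model_sd[k].shape}
print(f"loadable: {len(matched)} of {len(renamed)} tensors")  # 126 of 159 for me
model.load_state_dict(matched, strict=False)
```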

Saqeeb95 commented 5 months ago

No, sorry, I wasn't. I had my own dataset of ~1500 images to pretrain on, so I just moved on and did my own pretraining.