junyuchen245 / TransMorph_Transformer_for_Medical_Image_Registration

TransMorph: Transformer for Unsupervised Medical Image Registration (PyTorch)
MIT License
432 stars, 71 forks

Pre-trained Model Weights #17

golnooshAbd closed this issue 2 years ago

golnooshAbd commented 2 years ago

I was trying to run inference with the pre-trained TransMorph (simple) model weights, but after running infer_TransMorph.py in a terminal, I got these errors:

```
Traceback (most recent call last):
  File "./infer_TransMorph.py", line 99, in <module>
    main()
  File "./infer_TransMorph.py", line 35, in main
    model.load_state_dict(best_model)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1483, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for TransMorph:
    size mismatch for up0.conv1.0.weight: copying a param with shape torch.Size([384, 1152, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([384, 768, 3, 3, 3]).
    size mismatch for up1.conv1.0.weight: copying a param with shape torch.Size([192, 576, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([192, 384, 3, 3, 3]).
    size mismatch for up2.conv1.0.weight: copying a param with shape torch.Size([96, 288, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([96, 192, 3, 3, 3]).
```
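When `load_state_dict` fails like this, it can help to list every mismatch at once before touching the configuration. The following is a minimal sketch in plain Python, assuming you first turn each state dict into a mapping from parameter name to shape tuple (in PyTorch, something like `{k: tuple(v.shape) for k, v in sd.items()}`). The helper `compare_state_dict_shapes` is hypothetical; it is not part of the TransMorph repo or of PyTorch, and it only approximates the checks that a strict load performs.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def compare_state_dict_shapes(
    checkpoint: Dict[str, Shape], model: Dict[str, Shape]
) -> Tuple[Dict[str, Tuple[Shape, Shape]], List[str], List[str]]:
    """Hypothetical helper: compare two name -> shape mappings.

    Returns (mismatched, missing, unexpected):
      mismatched: name -> (checkpoint_shape, model_shape) for shared names whose shapes differ
      missing:    names the model expects but the checkpoint lacks
      unexpected: names in the checkpoint that the model does not have
    """
    shared = checkpoint.keys() & model.keys()
    mismatched = {k: (checkpoint[k], model[k]) for k in shared if checkpoint[k] != model[k]}
    missing = sorted(model.keys() - checkpoint.keys())
    unexpected = sorted(checkpoint.keys() - model.keys())
    return mismatched, missing, unexpected


# Shapes copied from the error message (only the three layers it reports).
ckpt_shapes = {
    "up0.conv1.0.weight": (384, 1152, 3, 3, 3),
    "up1.conv1.0.weight": (192, 576, 3, 3, 3),
    "up2.conv1.0.weight": (96, 288, 3, 3, 3),
}
model_shapes = {
    "up0.conv1.0.weight": (384, 768, 3, 3, 3),
    "up1.conv1.0.weight": (192, 384, 3, 3, 3),
    "up2.conv1.0.weight": (96, 192, 3, 3, 3),
}

mismatched, missing, unexpected = compare_state_dict_shapes(ckpt_shapes, model_shapes)
for name in sorted(mismatched):
    ckpt_shape, model_shape = mismatched[name]
    # Conv weights are (out_channels, in_channels, kD, kH, kW), so index 1 is the input width.
    extra = ckpt_shape[1] - model_shape[1]
    print(f"{name}: checkpoint {ckpt_shape} vs model {model_shape} (+{extra} input channels)")
```

Only the input-channel dimension differs, and the checkpoint has 384, 192 and 96 extra input channels, each equal to that layer's output channel count. That pattern suggests the checkpoint's decoder concatenates an additional feature map that the currently built model does not, i.e. the model was constructed from a different configuration than the one used for training.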

It seems to be a problem with the pre-trained weights. Could you please check and help me fix it? Thanks in advance.

junyuchen245 commented 2 years ago

Hi @golnooshAbd,

You need to change this configuration line so that it matches the pre-trained weights: https://github.com/junyuchen245/TransMorph_Transformer_for_Medical_Image_Registration/blob/f89c4e46ae77e17952da5350a86ac95b1ff0351a/IXI/TransMorph/infer_TransMorph.py#L31

If you change it to the line below, the pre-trained weights should load correctly:

```python
config = CONFIGS_TM['TransMorph']
```

Each of these strings corresponds to a specific TransMorph configuration: https://github.com/junyuchen245/TransMorph_Transformer_for_Medical_Image_Registration/blob/ec422e6b0beb7989340c24659089c3e889ad9b66/IXI/TransMorph/models/TransMorph.py#L883-L894

However, please note that we only provided pre-trained weights for the base model, 'TransMorph'.
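The channel counts in the error are consistent with the checkpoint having been trained with decoder skip connections that the selected configuration leaves out. The sketch below reproduces that arithmetic under stated assumptions: that each decoder block's first convolution receives `in_channels + skip_channels` inputs (my reading of the repo's `DecoderBlock`, not confirmed in this thread), that the base configuration uses an embedding dimension of 96, and that each skip feature map has the same width as the block's output. `decoder_in_channels` is a hypothetical helper, not a function in the repo.

```python
from typing import Dict


def decoder_in_channels(embed_dim: int, use_skip: bool) -> Dict[str, int]:
    """Hypothetical: input width of conv1 in the first three decoder blocks.

    Assumed layout (in_channels, skip_channels) per block:
      up0: embed_dim*8 in, skip of width embed_dim*4
      up1: embed_dim*4 in, skip of width embed_dim*2
      up2: embed_dim*2 in, skip of width embed_dim
    """
    layout = {
        "up0": (embed_dim * 8, embed_dim * 4),
        "up1": (embed_dim * 4, embed_dim * 2),
        "up2": (embed_dim * 2, embed_dim),
    }
    # With skips enabled, the skip feature map is concatenated along the channel axis.
    return {
        name: in_ch + (skip_ch if use_skip else 0)
        for name, (in_ch, skip_ch) in layout.items()
    }


with_skip = decoder_in_channels(96, use_skip=True)
without_skip = decoder_in_channels(96, use_skip=False)
print(with_skip)     # {'up0': 1152, 'up1': 576, 'up2': 288}
print(without_skip)  # {'up0': 768, 'up1': 384, 'up2': 192}
```

Under these assumptions, the checkpoint's widths (1152, 576, 288) match the "with skip" case and the current model's widths (768, 384, 192) match the "without skip" case, which fits the advice to build the model from the base 'TransMorph' entry. The thread does not say which configuration string line 31 originally selected.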

Thanks, Junyu