MzeroMiko / VMamba

VMamba: Visual State Space Models. Code is based on Mamba.
MIT License
2.17k stars 135 forks

Cannot load model #282

Open KLKb7 opened 3 months ago

KLKb7 commented 3 months ago

Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3553, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "", line 1, in
    runfile('/root/autodl-tmp/project/ACC-UNet-main/Experiments/test_model.py', wdir='/root/autodl-tmp/project/ACC-UNet-main/Experiments')
  File "/root/.pycharm_helpers/pydev/_pydev_bundle/pydev_umd.py", line 198, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/root/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/root/autodl-tmp/project/ACC-UNet-main/Experiments/test_model.py", line 270, in
    model.load_state_dict(checkpoint['state_dict'])
  File "/root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Mamba2Unet:
    Missing key(s) in state_dict: "mamba_unet.norm.weight", "mamba_unet.norm.bias", "mamba_unet.norm_up.weight", "mamba_unet.norm_up.bias".
    Unexpected key(s) in state_dict: "mamba_unet.classifier.norm.weight", "mamba_unet.classifier.norm.bias", "mamba_unet.classifier.norm_up.weight", "mamba_unet.classifier.norm_up.bias".

Process finished with exit code -1

When I trained your model and then tested it by loading the model, an error occurred during loading. What could be the cause of this error, and how can it be resolved?

MzeroMiko commented 2 months ago

Though I am not familiar with that project, I guess that you can just change the keys in state_dict from "mamba_unet.classifier." to "mamba_unet." to solve the problem.

That is:

state_dict = {
    ("mamba_unet." + k[len("mamba_unet.classifier."):] if k.startswith("mamba_unet.classifier.") else k): v
    for k, v in state_dict.items()
}
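For anyone who wants to check the result before loading, the same renaming can be written as a small helper with a key comparison. This is only a sketch: `rename_prefix` and `diff_keys` are hypothetical names, not part of VMamba or PyTorch, and the key lists are copied from the error message above with placeholder values standing in for the real tensors.

```python
from collections import OrderedDict


def rename_prefix(state_dict, old_prefix, new_prefix):
    """Return a copy of state_dict in which keys starting with old_prefix
    start with new_prefix instead. Other keys are left unchanged."""
    renamed = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(old_prefix):
            key = new_prefix + key[len(old_prefix):]
        renamed[key] = value
    return renamed


def diff_keys(expected_keys, actual_keys):
    """Return (missing, unexpected) key lists, mirroring the two lists
    that load_state_dict reports in its error message."""
    expected, actual = set(expected_keys), set(actual_keys)
    return sorted(expected - actual), sorted(actual - expected)


# Keys taken from the error message; the integer values are placeholders
# for the real weight tensors.
saved = OrderedDict([
    ("mamba_unet.classifier.norm.weight", 0),
    ("mamba_unet.classifier.norm.bias", 1),
    ("mamba_unet.classifier.norm_up.weight", 2),
    ("mamba_unet.classifier.norm_up.bias", 3),
])
model_keys = [
    "mamba_unet.norm.weight",
    "mamba_unet.norm.bias",
    "mamba_unet.norm_up.weight",
    "mamba_unet.norm_up.bias",
]

fixed = rename_prefix(saved, "mamba_unet.classifier.", "mamba_unet.")
missing, unexpected = diff_keys(model_keys, fixed.keys())
```

With a real checkpoint, the equivalent call would be model.load_state_dict(rename_prefix(checkpoint['state_dict'], "mamba_unet.classifier.", "mamba_unet.")), and model.state_dict().keys() can be passed to diff_keys as the expected keys. If the two lists are still not empty after renaming, the model definition used for testing probably differs from the one used for training in more than the prefix.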