MzeroMiko / VMamba

VMamba: Visual State Space Models, code is based on mamba
MIT License
2.19k stars, 142 forks

About "KeyError: 'MM_VSSM is not in the mmseg::model registry.'" and how to extract the middle layers #325

Open LogSSim opened 2 weeks ago

LogSSim commented 2 weeks ago

When I used the pretrained weights for the segmentation task, I got this error:

    KeyError: 'MM_VSSM is not in the mmseg::model registry. Please check whether the value of MM_VSSM is correct or it was registered as expected. More details can be found at https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module'
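An mmengine-style registry only knows a type name such as `MM_VSSM` after the Python module that defines and decorates that class has actually been imported; building a config that refers to it before then raises exactly this kind of KeyError. The sketch below uses a toy registry (`ToyRegistry` is hypothetical and is not mmengine's implementation) to show that ordering. With real mmengine, the page linked in the error describes listing the defining module in the config via `custom_imports = dict(imports=[...], allow_failed_imports=False)`; the exact module path of `MM_VSSM` inside the VMamba segmentation code is not given in this thread, so treat any name you put there as an assumption to verify.

```python
# Toy registry (NOT mmengine's implementation) illustrating why a type name is
# missing until the module that registers it has been imported.
class ToyRegistry:
    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register_module(self):
        # Decorator that records the class under its own name.
        def decorator(cls):
            self._modules[cls.__name__] = cls
            return cls
        return decorator

    def build(self, cfg):
        cfg = dict(cfg)              # copy so the caller's config is untouched
        type_name = cfg.pop("type")
        if type_name not in self._modules:
            raise KeyError(f"{type_name} is not in the {self.name} registry.")
        return self._modules[type_name](**cfg)


MODELS = ToyRegistry("mmseg::model")
backbone_cfg = dict(type="MM_VSSM", depths=[2, 2, 18, 2])

# 1) Before the defining module is imported, building fails with a KeyError.
try:
    MODELS.build(backbone_cfg)
    failed_before = False
except KeyError as err:
    failed_before = True
    print("before import:", err)

# 2) Importing the defining module runs its decorator and registers the class.
#    Here the class body stands in for that import; with mmengine, listing the
#    module under `custom_imports` in the config is what triggers the import.
@MODELS.register_module()
class MM_VSSM:
    def __init__(self, depths):
        self.depths = depths


model = MODELS.build(backbone_cfg)
print("after import:", type(model).__name__)
```

The design point is that registration is a side effect of importing, so the fix is usually to make sure the right module gets imported, not to change the `type` string.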

Also, how can I get the outputs of the middle (intermediate) layers?
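One generic way to read intermediate features from any PyTorch model is a forward hook (`nn.Module.register_forward_hook`), which receives each submodule's output during the forward pass. The sketch below uses a hypothetical 4-stage `ToyBackbone` whose stages live in a `ModuleList` named `layers`; that VMamba's model stores its stages the same way is an assumption to check with `print(model)`. `Backbone_VSSM` may also accept an `out_indices` argument that returns per-stage features directly; confirm that in `vmamba.py` before relying on it.

```python
import torch
import torch.nn as nn


# Hypothetical stand-in for a 4-stage backbone (not VMamba's real architecture).
class ToyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Conv2d(3 if i == 0 else 8, 8, 3, stride=2, padding=1) for i in range(4)]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def capture_stage_outputs(model, stage_names):
    """Attach forward hooks that store each named submodule's output."""
    features, handles = {}, []
    for name in stage_names:
        module = model.get_submodule(name)

        # Bind `name` as a default argument so each hook keeps its own key.
        def hook(mod, inputs, output, name=name):
            features[name] = output.detach()

        handles.append(module.register_forward_hook(hook))
    return features, handles


model = ToyBackbone().eval()
stage_names = ["layers.0", "layers.1", "layers.2", "layers.3"]
features, handles = capture_stage_outputs(model, stage_names)

with torch.no_grad():
    model(torch.randn(1, 3, 64, 64))

for h in handles:  # remove hooks once the features have been collected
    h.remove()

for name, feat in features.items():
    print(name, tuple(feat.shape))
```

Removing the handles afterwards keeps later forward passes free of the extra bookkeeping.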

LogSSim commented 2 weeks ago

I encountered an issue while loading the pretrained model for segmentation.

    import torch
    import vmamba  # classification/models/vmamba.py from the VMamba repo; adjust the import path to your checkout

    channel_first = True
    model = vmamba.Backbone_VSSM(
        depths=[2, 2, 18, 2],
        dims=15,
        drop_path_rate=0.6,
        patch_size=4,
        in_chans=3,
        num_classes=150,
        ssm_d_state=1,
        ssm_ratio=2.0,
        ssm_dt_rank="auto",
        ssm_act_layer="silu",
        ssm_conv=3,
        ssm_conv_bias=False,
        ssm_drop_rate=0.0,
        ssm_init="v0",
        forward_type="v05_noz",
        mlp_ratio=4.0,
        mlp_act_layer="gelu",
        mlp_drop_rate=0.0,
        gmlp=False,
        patch_norm=True,
        norm_layer=("ln2d" if channel_first else "ln"),
        downsample_version="v3",
        patchembed_version="v2",
        use_checkpoint=False,
        posembed=False,
        imgsize=224,
    )

    pretrained_weights = torch.load('/data3/urs/code/RVT/rvt/base_vmamba.pth')
    model.load_state_dict(pretrained_weights)

      File "test_mamba.py", line 90, in <module>
        model.load_state_dict(pretrained_weights)
      File "/home/urs/ENTER/envs/rvt/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2201, in load_state_dict
        load(self, state_dict)
      File "/home/urs/ENTER/envs/rvt/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2189, in load
        load(child, child_state_dict, child_prefix)  # noqa: F821
      File "/home/urs/ENTER/envs/rvt/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2189, in load
        load(child, child_state_dict, child_prefix)  # noqa: F821
      File "/home/urs/ENTER/envs/rvt/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2189, in load
        load(child, child_state_dict, child_prefix)  # noqa: F821
      [Previous line repeated 3 more times]
      File "/home/urs/ENTER/envs/rvt/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2183, in load
        module._load_from_state_dict(
      File "/data3/urs/code/RVT/rvt/libs/VMamba/classification/models/vmamba.py", line 48, in _load_from_state_dict
        state_dict[prefix + "weight"] = state_dict[prefix + "weight"].view(self.weight.shape)
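The traceback stops inside `_load_from_state_dict` in `vmamba.py`, at a line that indexes `state_dict[prefix + "weight"]` and reshapes it with `.view(self.weight.shape)`. The final exception line is not shown, so the cause is unconfirmed, but two plausible candidates are a missing key (for example, if the file wraps the weights in an outer dict) or a tensor whose element count differs from the model's (for example, if the constructor arguments such as `dims` or `depths` differ from the ones the checkpoint was trained with), since `.view` cannot change the number of elements. The sketch below compares a checkpoint against a model before calling `load_state_dict`. The helpers `unwrap_checkpoint` and `report_mismatches` are hypothetical, and the `"model"` wrapper key is an assumption, so inspect `ckpt.keys()` on the real file first. The demo uses two small toy models rather than VMamba.

```python
import torch.nn as nn


def unwrap_checkpoint(ckpt):
    # Assumption: some training scripts save {"model": state_dict, ...}.
    # If that key is absent, the object is treated as a plain state_dict.
    if isinstance(ckpt, dict) and isinstance(ckpt.get("model"), dict):
        return ckpt["model"]
    return ckpt


def report_mismatches(model, state_dict):
    """Return (missing keys, unexpected keys, element-count mismatches)."""
    expected = model.state_dict()
    missing = sorted(k for k in expected if k not in state_dict)
    unexpected = sorted(k for k in state_dict if k not in expected)
    # .view() only works when element counts agree, so flag numel mismatches.
    bad_numel = sorted(
        (k, tuple(state_dict[k].shape), tuple(v.shape))
        for k, v in expected.items()
        if k in state_dict and state_dict[k].numel() != v.numel()
    )
    return missing, unexpected, bad_numel


# Toy demo: the checkpoint comes from a wider model than the one being built.
src = nn.Sequential(nn.Linear(4, 16), nn.Linear(16, 2))
dst = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
ckpt = {"model": src.state_dict(), "epoch": 10}

missing, unexpected, bad_numel = report_mismatches(dst, unwrap_checkpoint(ckpt))
print("missing:", missing)
print("unexpected:", unexpected)
for key, got, want in bad_numel:
    print(f"{key}: checkpoint {got} vs model {want}")
```

A non-empty element-count list points at a configuration mismatch, while a long missing/unexpected list usually points at a key-prefix or wrapping difference.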