Sssssuperior / VSCode

Code release for "VSCode: General Visual Salient and Camouflaged Object Detection with 2D Prompt Learning"
MIT License

Problem during testing #7

Open · SilentWhiteRabbit opened this issue 4 months ago

SilentWhiteRabbit commented 4 months ago

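Hello, while running your code for testing, I ran into a problem at the weight-loading step. Since I don't have the pretrained weights, I changed the pretrained parameter of swin_transformer to False and left all other hyperparameters unchanged. Running the following code raises an error: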

    import torch
    from collections import OrderedDict

    # ImageDepthNet and args come from the VSCode repository code
    net = ImageDepthNet(args)
    net.cuda()
    net.eval()
    # load model (saved from a multi-GPU run, so keys start with `module.`)
    model_path = './checkpoint/RGB_VST_T.pth'
    state_dict = torch.load(model_path)
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove `module.`
        new_state_dict[name] = v
    # load params
    net.load_state_dict(new_state_dict)
    print('Model loaded from {}'.format(model_path))
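The loop above drops the first seven characters of every key, which is only correct when each key starts with `module.` (the prefix nn.DataParallel adds when a wrapped model is saved). A minimal sketch of a more defensive variant, which strips the prefix only where it is present; the helper name strip_module_prefix is hypothetical and not part of the repository:

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from keys that have it.

    Keys without the prefix are kept unchanged, so the same code also works
    for checkpoints that were saved without nn.DataParallel.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


# Hypothetical usage with the snippet above (torch objects omitted here):
#   new_state_dict = strip_module_prefix(torch.load(model_path))
#   net.load_state_dict(new_state_dict)
```

Note that this only changes how keys are renamed; it does not make a checkpoint fit a model whose architecture differs.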

The error is as follows:

Traceback (most recent call last):
  File "G:/software_package/VSCode-main/train_test_eval.py", line 80, in
    net.load_state_dict(new_state_dict)
  File "D:\anaconda3\envs\shiyanshi\lib\site-packages\torch\nn\modules\module.py", line 1223, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ImageDepthNet:
    Missing key(s) in state_dict: "rgb_backbone.head.weight", "rgb_backbone.head.bias".
    Unexpected key(s) in state_dict: "rgb_backbone.layers.2.blocks.6.norm1.weight", "rgb_backbone.layers.2.blocks.6.norm1.bias", "rgb_backbone.layers.2.blocks.6.attn.relative_position_bias_table", "rgb_backbone.layers.2.blocks.6.attn.relative_position_index", "rgb_backbone.layers.2.blocks.6.attn.qkv.weight", "rgb_backbone.layers.2.blocks.6.attn.qkv.bias", "rgb_backbone.layers.2.blocks.6.attn.proj.weight", "rgb_backbone.layers.2.blocks.6.attn.proj.bias", "rgb_backbone.layers.2.blocks.6.norm2.weight", "rgb_backbone.layers.2.blocks.6.norm2.bias", "rgb_backbone.layers.2.blocks.6.mlp.fc1.weight", "rgb_backbone.layers.2.blocks.6.mlp.fc1.bias", "rgb_backbone.layers.2.blocks.6.mlp.fc2.weight", "rgb_backbone.layers.2.blocks.6.mlp.fc2.bias", "rgb_backbone.layers.2.blocks.7.attn_mask", "rgb_backbone.layers.2.blocks.7.norm1.weight", "rgb_backbone.layers.2.blocks.7.norm1.bias", "rgb_backbone.layers.2.blocks.7.attn.relative_position_bias_table", "rgb_backbone.layers.2.blocks.7.attn.relative_position_index", "rgb_backbone.layers.2.blocks.7.attn.qkv.weight", "rgb_backbone.layers.2.blocks.7.attn.qkv.bias", "rgb_backbone.layers.2.blocks.7.attn.proj.weight", "rgb_backbone.layers.2.blocks.7.attn.proj.bias", "rgb_backbone.layers.2.blocks.7.norm2.weight", "rgb_backbone.layers.2.blocks.7.norm2.bias", "rgb_backbone.layers.2.blocks.7.mlp.fc1.weight", "rgb_backbone.layers.2.blocks.7.mlp.fc1.bias", "rgb_backbone.layers.2.blocks.7.mlp.fc2.weight", "rgb_backbone.layers.2.blocks.7.mlp.fc2.bias", 
"rgb_backbone.layers.2.blocks.8.norm1.weight", "rgb_backbone.layers.2.blocks.8.norm1.bias", "rgb_backbone.layers.2.blocks.8.attn.relative_position_bias_table", "rgb_backbone.layers.2.blocks.8.attn.relative_position_index", "rgb_backbone.layers.2.blocks.8.attn.qkv.weight", "rgb_backbone.layers.2.blocks.8.attn.qkv.bias", "rgb_backbone.layers.2.blocks.8.attn.proj.weight", "rgb_backbone.layers.2.blocks.8.attn.proj.bias", "rgb_backbone.layers.2.blocks.8.norm2.weight", "rgb_backbone.layers.2.blocks.8.norm2.bias", "rgb_backbone.layers.2.blocks.8.mlp.fc1.weight", "rgb_backbone.layers.2.blocks.8.mlp.fc1.bias", "rgb_backbone.layers.2.blocks.8.mlp.fc2.weight", "rgb_backbone.layers.2.blocks.8.mlp.fc2.bias", "rgb_backbone.layers.2.blocks.9.attn_mask", "rgb_backbone.layers.2.blocks.9.norm1.weight", "rgb_backbone.layers.2.blocks.9.norm1.bias", "rgb_backbone.layers.2.blocks.9.attn.relative_position_bias_table", "rgb_backbone.layers.2.blocks.9.attn.relative_position_index", "rgb_backbone.layers.2.blocks.9.attn.qkv.weight", "rgb_backbone.layers.2.blocks.9.attn.qkv.bias", "rgb_backbone.layers.2.blocks.9.attn.proj.weight", "rgb_backbone.layers.2.blocks.9.attn.proj.bias", "rgb_backbone.layers.2.blocks.9.norm2.weight", "rgb_backbone.layers.2.blocks.9.norm2.bias", "rgb_backbone.layers.2.blocks.9.mlp.fc1.weight", "rgb_backbone.layers.2.blocks.9.mlp.fc1.bias", "rgb_backbone.layers.2.blocks.9.mlp.fc2.weight", "rgb_backbone.layers.2.blocks.9.mlp.fc2.bias", "rgb_backbone.layers.2.blocks.10.norm1.weight", "rgb_backbone.layers.2.blocks.10.norm1.bias", "rgb_backbone.layers.2.blocks.10.attn.relative_position_bias_table", "rgb_backbone.layers.2.blocks.10.attn.relative_position_index", "rgb_backbone.layers.2.blocks.10.attn.qkv.weight", "rgb_backbone.layers.2.blocks.10.attn.qkv.bias", "rgb_backbone.layers.2.blocks.10.attn.proj.weight", "rgb_backbone.layers.2.blocks.10.attn.proj.bias", "rgb_backbone.layers.2.blocks.10.norm2.weight", "rgb_backbone.layers.2.blocks.10.norm2.bias", 
"rgb_backbone.layers.2.blocks.10.mlp.fc1.weight", "rgb_backbone.layers.2.blocks.10.mlp.fc1.bias", "rgb_backbone.layers.2.blocks.10.mlp.fc2.weight", "rgb_backbone.layers.2.blocks.10.mlp.fc2.bias", "rgb_backbone.layers.2.blocks.11.attn_mask", "rgb_backbone.layers.2.blocks.11.norm1.weight", "rgb_backbone.layers.2.blocks.11.norm1.bias", "rgb_backbone.layers.2.blocks.11.attn.relative_position_bias_table", "rgb_backbone.layers.2.blocks.11.attn.relative_position_index", "rgb_backbone.layers.2.blocks.11.attn.qkv.weight", "rgb_backbone.layers.2.blocks.11.attn.qkv.bias", "rgb_backbone.layers.2.blocks.11.attn.proj.weight", "rgb_backbone.layers.2.blocks.11.attn.proj.bias", "rgb_backbone.layers.2.blocks.11.norm2.weight", "rgb_backbone.layers.2.blocks.11.norm2.bias", "rgb_backbone.layers.2.blocks.11.mlp.fc1.weight", "rgb_backbone.layers.2.blocks.11.mlp.fc1.bias", "rgb_backbone.layers.2.blocks.11.mlp.fc2.weight", "rgb_backbone.layers.2.blocks.11.mlp.fc2.bias", "rgb_backbone.layers.2.blocks.12.norm1.weight", "rgb_backbone.layers.2.blocks.12.norm1.bias", "rgb_backbone.layers.2.blocks.12.attn.relative_position_bias_table", "rgb_backbone.layers.2.blocks.12.attn.relative_position_index", "rgb_backbone.layers.2.blocks.12.attn.qkv.weight", "rgb_backbone.layers.2.blocks.12.attn.qkv.bias", "rgb_backbone.layers.2.blocks.12.attn.proj.weight", "rgb_backbone.layers.2.blocks.12.attn.proj.bias", "rgb_backbone.layers.2.blocks.12.norm2.weight", "rgb_backbone.layers.2.blocks.12.norm2.bias", "rgb_backbone.layers.2.blocks.12.mlp.fc1.weight", "rgb_backbone.layers.2.blocks.12.mlp.fc1.bias", "rgb_backbone.layers.2.blocks.12.mlp.fc2.weight", "rgb_backbone.layers.2.blocks.12.mlp.fc2.bias", "rgb_backbone.layers.2.blocks.13.attn_mask", "rgb_backbone.layers.2.blocks.13.norm1.weight", "rgb_backbone.layers.2.blocks.13.norm1.bias", "rgb_backbone.layers.2.blocks.13.attn.relative_position_bias_table", "rgb_backbone.layers.2.blocks.13.attn.relative_position_index", 
"rgb_backbone.layers.2.blocks.13.attn.qkv.weight", "rgb_backbone.layers.2.blocks.13.attn.qkv.bias", "rgb_backbone.layers.2.blocks.13.attn.proj.weight", "rgb_backbone.layers.2.blocks.13.attn.proj.bias", "rgb_backbone.layers.2.blocks.13.norm2.weight", "rgb_backbone.layers.2.blocks.13.norm2.bias", "rgb_backbone.layers.2.blocks.13.mlp.fc1.weight", "rgb_backbone.layers.2.blocks.13.mlp.fc1.bias", "rgb_backbone.layers.2.blocks.13.mlp.fc2.weight", "rgb_backbone.layers.2.blocks.13.mlp.fc2.bias", "rgb_backbone.layers.2.blocks.14.norm1.weight", "rgb_backbone.layers.2.blocks.14.norm1.bias", "rgb_backbone.layers.2.blocks.14.attn.relative_position_bias_table", "rgb_backbone.layers.2.blocks.14.attn.relative_position_index", "rgb_backbone.layers.2.blocks.14.attn.qkv.weight", "rgb_backbone.layers.2.blocks.14.attn.qkv.bias", "rgb_backbone.layers.2.blocks.14.attn.proj.weight", "rgb_backbone.layers.2.blocks.14.attn.proj.bias", "rgb_backbone.layers.2.blocks.14.norm2.weight", "rgb_backbone.layers.2.blocks.14.norm2.bias", "rgb_backbone.layers.2.blocks.14.mlp.fc1.weight", "rgb_backbone.layers.2.blocks.14.mlp.fc1.bias", "rgb_backbone.layers.2.blocks.14.mlp.fc2.weight", "rgb_backbone.layers.2.blocks.14.mlp.fc2.bias", "rgb_backbone.layers.2.blocks.15.attn_mask", "rgb_backbone.layers.2.blocks.15.norm1.weight", "rgb_backbone.layers.2.blocks.15.norm1.bias", "rgb_backbone.layers.2.blocks.15.attn.relative_position_bias_table", "rgb_backbone.layers.2.blocks.15.attn.relative_position_index", "rgb_backbone.layers.2.blocks.15.attn.qkv.weight", "rgb_backbone.layers.2.blocks.15.attn.qkv.bias", "rgb_backbone.layers.2.blocks.15.attn.proj.weight", "rgb_backbone.layers.2.blocks.15.attn.proj.bias", "rgb_backbone.layers.2.blocks.15.norm2.weight", "rgb_backbone.layers.2.blocks.15.norm2.bias", "rgb_backbone.layers.2.blocks.15.mlp.fc1.weight", "rgb_backbone.layers.2.blocks.15.mlp.fc1.bias", "rgb_backbone.layers.2.blocks.15.mlp.fc2.weight", "rgb_backbone.layers.2.blocks.15.mlp.fc2.bias", 
"rgb_backbone.layers.2.blocks.16.norm1.weight", "rgb_backbone.layers.2.blocks.16.norm1.bias", "rgb_backbone.layers.2.blocks.16.attn.relative_position_bias_table", "rgb_backbone.layers.2.blocks.16.attn.relative_position_index", "rgb_backbone.layers.2.blocks.16.attn.qkv.weight", "rgb_backbone.layers.2.blocks.16.attn.qkv.bias", "rgb_backbone.layers.2.blocks.16.attn.proj.weight", "rgb_backbone.layers.2.blocks.16.attn.proj.bias", "rgb_backbone.layers.2.blocks.16.norm2.weight", "rgb_backbone.layers.2.blocks.16.norm2.bias", "rgb_backbone.layers.2.blocks.16.mlp.fc1.weight", "rgb_backbone.layers.2.blocks.16.mlp.fc1.bias", "rgb_backbone.layers.2.blocks.16.mlp.fc2.weight", "rgb_backbone.layers.2.blocks.16.mlp.fc2.bias", "rgb_backbone.layers.2.blocks.17.attn_mask", "rgb_backbone.layers.2.blocks.17.norm1.weight", "rgb_backbone.layers.2.blocks.17.norm1.bias", "rgb_backbone.layers.2.blocks.17.attn.relative_position_bias_table", "rgb_backbone.layers.2.blocks.17.attn.relative_position_index", "rgb_backbone.layers.2.blocks.17.attn.qkv.weight", "rgb_backbone.layers.2.blocks.17.attn.qkv.bias", "rgb_backbone.layers.2.blocks.17.attn.proj.weight", "rgb_backbone.layers.2.blocks.17.attn.proj.bias", "rgb_backbone.layers.2.blocks.17.norm2.weight", "rgb_backbone.layers.2.blocks.17.norm2.bias", "rgb_backbone.layers.2.blocks.17.mlp.fc1.weight", "rgb_backbone.layers.2.blocks.17.mlp.fc1.bias", "rgb_backbone.layers.2.blocks.17.mlp.fc2.weight", "rgb_backbone.layers.2.blocks.17.mlp.fc2.bias".

It looks like the weights don't match the model. Did I change something incorrectly?
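One way to read an error like this is to diff the two key sets directly: "missing" keys are ones the model expects but the checkpoint lacks, and "unexpected" keys are ones the checkpoint has but the model lacks. In the log above every unexpected key lies in stage layers.2 at block indices 6 to 17, which suggests the checkpoint's third stage holds 18 blocks while the instantiated backbone holds only 6 (for reference, the standard Swin-T configuration uses depths (2, 2, 6, 2), whereas Swin-S and Swin-B use (2, 2, 18, 2)). A minimal sketch that does this comparison, assuming keys follow the `layers.<stage>.blocks.<block>.` naming seen in the traceback; the function names are hypothetical:

```python
import re
from collections import defaultdict

# Matches key fragments such as "layers.2.blocks.6." inside full key names.
BLOCK_KEY = re.compile(r"layers\.(\d+)\.blocks\.(\d+)\.")


def blocks_per_stage(keys):
    """Map stage index -> number of distinct block indices found in `keys`."""
    blocks = defaultdict(set)
    for key in keys:
        match = BLOCK_KEY.search(key)
        if match:
            stage, block = int(match.group(1)), int(match.group(2))
            blocks[stage].add(block)
    return {stage: len(indices) for stage, indices in sorted(blocks.items())}


def compare_keys(model_keys, ckpt_keys):
    """Return (missing, unexpected, model block counts, checkpoint block counts)."""
    model_keys, ckpt_keys = set(model_keys), set(ckpt_keys)
    missing = sorted(model_keys - ckpt_keys)
    unexpected = sorted(ckpt_keys - model_keys)
    return missing, unexpected, blocks_per_stage(model_keys), blocks_per_stage(ckpt_keys)


# With PyTorch objects, the inputs would be net.state_dict().keys() and
# new_state_dict.keys() from the loading snippet in this issue.
```

If the per-stage block counts differ, the backbone the code builds and the backbone the checkpoint was trained with are configured differently, and renaming keys alone cannot fix the mismatch.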

Sssssuperior commented 4 months ago

Hello, this is most likely caused by the backbone weights not being loaded. The backbone weights can be downloaded from https://github.com/microsoft/Swin-Transformer
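A sketch of preparing such a download for the backbone, assuming the file follows the layout of the official Swin-Transformer releases (weights stored under a "model" entry) and that the ImageNet classification head (head.*) is not needed for detection. The helper name is hypothetical, and which head prefixes to drop is an assumption; the submodule name rgb_backbone is taken from the key names in the traceback. The Swin variant downloaded should match the backbone configuration the model is built with.

```python
def backbone_state_from_checkpoint(checkpoint, drop_prefixes=("head.",)):
    """Unwrap a Swin-Transformer checkpoint and drop classifier-head weights.

    Assumes the weights live under checkpoint["model"]; if that entry is
    absent, the dict itself is treated as the state dict.
    """
    state = checkpoint.get("model", checkpoint)
    return {
        key: value
        for key, value in state.items()
        if not key.startswith(tuple(drop_prefixes))
    }


# Hypothetical usage with PyTorch (the file name is a placeholder):
#   ckpt = torch.load("swin_tiny_patch4_window7_224.pth", map_location="cpu")
#   result = net.rgb_backbone.load_state_dict(
#       backbone_state_from_checkpoint(ckpt), strict=False)
#   print(result.missing_keys, result.unexpected_keys)
```

Loading with strict=False tolerates the dropped head keys, but printing the returned missing_keys and unexpected_keys is still worthwhile: a long list there usually means the wrong variant was downloaded.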