Closed clh-b closed 3 years ago
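Please try loading the trained weights with the snippet below. The checkpoint file is a dictionary, and the model weights are stored under its 'model' key. Also make sure you build the architecture that matches the checkpoint.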
from models import M_SFANet
import torch

save_path = './models/checkpoint_best_MSFANet_B.pth'  # Path to the trained weights.

# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the model. It must be the same architecture the checkpoint was trained with.
m = M_SFANet.Model().to(device)

# map_location remaps tensors saved on a GPU, so the checkpoint also loads on a
# CPU-only machine.
# NOTE(review): newer torch versions default torch.load to weights_only=True. If
# the checkpoint stores non-tensor objects (e.g. numpy scalars for 'mae'/'mse'),
# loading may then fail. Only pass weights_only=False for checkpoints from a
# trusted source, because that path unpickles arbitrary objects.
checkpoint = torch.load(save_path, map_location=device)

# The released checkpoints are training snapshots: a dict with the keys 'epoch',
# 'model', 'optimizer', 'mae' and 'mse'. Only the 'model' entry is the
# state_dict. Passing the whole snapshot to load_state_dict causes
# "Missing key(s) ... Unexpected key(s): epoch, model, optimizer, mae, mse".
# If the file is already a bare state_dict, use it as-is.
state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
m.load_state_dict(state_dict)

# Switch to inference mode so BatchNorm layers use their stored running statistics.
m.eval()
I downloaded the ShanghaiTech A & B pretrained weights ("checkpoint_best_MSFANet_A.pth" & "checkpoint_best_MSFANet_B.pth") from the provided link, but loading them fails with the error below. How can I use them?
RuntimeError: Error(s) in loading state_dict for Model: Missing key(s) in state_dict: "vgg.conv1_1.conv.weight", "vgg.conv1_1.conv.bias", "vgg.conv1_1.bn.weight", "vgg.conv1_1.bn.bias", "vgg.conv1_1.bn.running_mean", "vgg.conv1_1.bn.running_var", "vgg.conv1_2.conv.weight", "vgg.conv1_2.conv.bias", "vgg.conv1_2.bn.weight", "vgg.conv1_2.bn.bias", "vgg.conv1_2.bn.running_mean", "vgg.conv1_2.bn.running_var", "vgg.conv2_1.conv.weight", "vgg.conv2_1.conv.bias", "vgg.conv2_1.bn.weight", "vgg.conv2_1.bn.bias", "vgg.conv2_1.bn.running_mean", "vgg.conv2_1.bn.running_var", "vgg.conv2_2.conv.weight", "vgg.conv2_2.conv.bias", "vgg.conv2_2.bn.weight", "vgg.conv2_2.bn.bias", "vgg.conv2_2.bn.running_mean", "vgg.conv2_2.bn.running_var", "vgg.conv3_1.conv.weight", "vgg.conv3_1.conv.bias", "vgg.conv3_1.bn.weight", "vgg.conv3_1.bn.bias", "vgg.conv3_1.bn.running_mean", "vgg.conv3_1.bn.running_var", "vgg.conv3_2.conv.weight", "vgg.conv3_2.conv.bias", "vgg.conv3_2.bn.weight", "vgg.conv3_2.bn.bias", "vgg.conv3_2.bn.running_mean", "vgg.conv3_2.bn.running_var", "vgg.conv3_3.conv.weight", "vgg.conv3_3.conv.bias", "vgg.conv3_3.bn.weight", "vgg.conv3_3.bn.bias", "vgg.conv3_3.bn.running_mean", "vgg.conv3_3.bn.running_var", "vgg.conv4_1.conv.weight", "vgg.conv4_1.conv.bias", "vgg.conv4_1.bn.weight", "vgg.conv4_1.bn.bias", "vgg.conv4_1.bn.running_mean", "vgg.conv4_1.bn.running_var", "vgg.conv4_2.conv.weight", "vgg.conv4_2.conv.bias", "vgg.conv4_2.bn.weight", "vgg.conv4_2.bn.bias", "vgg.conv4_2.bn.running_mean", "vgg.conv4_2.bn.running_var", "vgg.conv4_3.conv.weight", "vgg.conv4_3.conv.bias", "vgg.conv4_3.bn.weight", "vgg.conv4_3.bn.bias", "vgg.conv4_3.bn.running_mean", "vgg.conv4_3.bn.running_var", "vgg.conv5_1.conv.weight", "vgg.conv5_1.conv.bias", "vgg.conv5_1.bn.weight", "vgg.conv5_1.bn.bias", "vgg.conv5_1.bn.running_mean", "vgg.conv5_1.bn.running_var", "vgg.conv5_2.conv.weight", "vgg.conv5_2.conv.bias", "vgg.conv5_2.bn.weight", "vgg.conv5_2.bn.bias", "vgg.conv5_2.bn.running_mean", 
"vgg.conv5_2.bn.running_var", "vgg.conv5_3.conv.weight", "vgg.conv5_3.conv.bias", "vgg.conv5_3.bn.weight", "vgg.conv5_3.bn.bias", "vgg.conv5_3.bn.running_mean", "vgg.conv5_3.bn.running_var", "spm.assp.aspp1.atrous_conv.weight", "spm.assp.aspp1.bn.weight", "spm.assp.aspp1.bn.bias", "spm.assp.aspp1.bn.running_mean", "spm.assp.aspp1.bn.running_var", "spm.assp.aspp2.atrous_conv.weight", "spm.assp.aspp2.bn.weight", "spm.assp.aspp2.bn.bias", "spm.assp.aspp2.bn.running_mean", "spm.assp.aspp2.bn.running_var", "spm.assp.aspp3.atrous_conv.weight", "spm.assp.aspp3.bn.weight", "spm.assp.aspp3.bn.bias", "spm.assp.aspp3.bn.running_mean", "spm.assp.aspp3.bn.running_var", "spm.assp.aspp4.atrous_conv.weight", "spm.assp.aspp4.bn.weight", "spm.assp.aspp4.bn.bias", "spm.assp.aspp4.bn.running_mean", "spm.assp.aspp4.bn.running_var", "spm.assp.global_avg_pool.1.weight", "spm.assp.global_avg_pool.2.weight", "spm.assp.global_avg_pool.2.bias", "spm.assp.global_avg_pool.2.running_mean", "spm.assp.global_avg_pool.2.running_var", "spm.assp.conv1.weight", "spm.assp.bn1.weight", "spm.assp.bn1.bias", "spm.assp.bn1.running_mean", "spm.assp.bn1.running_var", "spm.can.scales.0.1.weight", "spm.can.scales.1.1.weight", "spm.can.scales.2.1.weight", "spm.can.scales.3.1.weight", "spm.can.bottleneck.weight", "spm.can.bottleneck.bias", "spm.can.weight_net.weight", "spm.can.weight_net.bias", "amp.conv1.conv.weight", "amp.conv1.conv.bias", "amp.conv1.bn.weight", "amp.conv1.bn.bias", "amp.conv1.bn.running_mean", "amp.conv1.bn.running_var", "amp.conv2.conv.weight", "amp.conv2.conv.bias", "amp.conv2.bn.weight", "amp.conv2.bn.bias", "amp.conv2.bn.running_mean", "amp.conv2.bn.running_var", "amp.conv3.conv.weight", "amp.conv3.conv.bias", "amp.conv3.bn.weight", "amp.conv3.bn.bias", "amp.conv3.bn.running_mean", "amp.conv3.bn.running_var", "amp.conv4.conv.weight", "amp.conv4.conv.bias", "amp.conv4.bn.weight", "amp.conv4.bn.bias", "amp.conv4.bn.running_mean", "amp.conv4.bn.running_var", "amp.conv5.conv.weight", 
"amp.conv5.conv.bias", "amp.conv5.bn.weight", "amp.conv5.bn.bias", "amp.conv5.bn.running_mean", "amp.conv5.bn.running_var", "amp.conv6.conv.weight", "amp.conv6.conv.bias", "amp.conv6.bn.weight", "amp.conv6.bn.bias", "amp.conv6.bn.running_mean", "amp.conv6.bn.running_var", "amp.conv7.conv.weight", "amp.conv7.conv.bias", "amp.conv7.bn.weight", "amp.conv7.bn.bias", "amp.conv7.bn.running_mean", "amp.conv7.bn.running_var", "dmp.conv1.conv.weight", "dmp.conv1.conv.bias", "dmp.conv1.bn.weight", "dmp.conv1.bn.bias", "dmp.conv1.bn.running_mean", "dmp.conv1.bn.running_var", "dmp.conv2.conv.weight", "dmp.conv2.conv.bias", "dmp.conv2.bn.weight", "dmp.conv2.bn.bias", "dmp.conv2.bn.running_mean", "dmp.conv2.bn.running_var", "dmp.conv3.conv.weight", "dmp.conv3.conv.bias", "dmp.conv3.bn.weight", "dmp.conv3.bn.bias", "dmp.conv3.bn.running_mean", "dmp.conv3.bn.running_var", "dmp.conv4.conv.weight", "dmp.conv4.conv.bias", "dmp.conv4.bn.weight", "dmp.conv4.bn.bias", "dmp.conv4.bn.running_mean", "dmp.conv4.bn.running_var", "dmp.conv5.conv.weight", "dmp.conv5.conv.bias", "dmp.conv5.bn.weight", "dmp.conv5.bn.bias", "dmp.conv5.bn.running_mean", "dmp.conv5.bn.running_var", "dmp.conv6.conv.weight", "dmp.conv6.conv.bias", "dmp.conv6.bn.weight", "dmp.conv6.bn.bias", "dmp.conv6.bn.running_mean", "dmp.conv6.bn.running_var", "dmp.conv7.conv.weight", "dmp.conv7.conv.bias", "dmp.conv7.bn.weight", "dmp.conv7.bn.bias", "dmp.conv7.bn.running_mean", "dmp.conv7.bn.running_var", "conv_att.conv.weight", "conv_att.conv.bias", "conv_att.bn.weight", "conv_att.bn.bias", "conv_att.bn.running_mean", "conv_att.bn.running_var", "conv_out.conv.weight", "conv_out.conv.bias", "conv_out.bn.weight", "conv_out.bn.bias", "conv_out.bn.running_mean", "conv_out.bn.running_var". Unexpected key(s) in state_dict: "epoch", "model", "optimizer", "mae", "mse".