When running inference.py, must the model architecture used to produce the checkpoint be exactly the same as the architecture that loads it? After I modified the network and trained a model, running inference.py for detection raises this error:

Traceback (most recent call last):
  File "E:\Multi-label-Sewer-Classification-main\inference.py", line 187, in <module>
    run_inference(args)
  File "E:\Multi-label-Sewer-Classification-main\inference.py", line 129, in run_inference
    model.load_state_dict(updated_state_dict)
  File "C:\ProgramData\anaconda3\envs\torch\lib\site-packages\torch\nn\modules\module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ResNet:
	Unexpected key(s) in state_dict: "attention.in_proj_weight", "attention.in_proj_bias", "attention.out_proj.weight", "attention.out_proj.bias", "conv_reduce.weight", "conv_reduce.bias", "features.0.weight", "features.0.bias", "classifier.weight", "classifier.bias".
	size mismatch for conv1.weight: copying a param with shape torch.Size([64, 11, 7, 7]) from checkpoint, the shape in current model is torch.Size([64, 3, 7, 7]).
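For context: `load_state_dict` defaults to `strict=True`, so every key name must match. Passing `strict=False` only tolerates missing or unexpected keys; a size mismatch is still raised. The error above suggests that the checkpoint contains modules (`attention`, `conv_reduce`, `features.0`, `classifier`) which the `ResNet` built in inference.py does not have, and that the trained `conv1` takes 11 input channels rather than 3. inference.py would therefore presumably need to build the same modified network, and feed it matching input, before loading. To see exactly which entries differ, a small helper like the one below can compare the two state_dicts by name and shape. This is a minimal sketch. The names `shapes_of`, `diff_state_dicts` and `report` are hypothetical and are not part of inference.py or PyTorch. The toy shapes in the demo are illustrative, not taken from the real checkpoint.

```python
from typing import Dict, List, Mapping, Tuple

Shape = Tuple[int, ...]


def shapes_of(state_dict: Mapping[str, object]) -> Dict[str, Shape]:
    """Map each entry name to its shape as a plain tuple.

    Works on any mapping whose values expose `.shape`, e.g. a PyTorch
    state_dict of tensors (torch.Size converts cleanly to a tuple).
    """
    return {name: tuple(value.shape) for name, value in state_dict.items()}


def diff_state_dicts(
    ckpt: Mapping[str, Shape], model: Mapping[str, Shape]
) -> Tuple[List[str], List[str], List[Tuple[str, Shape, Shape]]]:
    """Return (missing, unexpected, mismatched), mirroring load_state_dict's report.

    missing    - keys the model expects but the checkpoint lacks
    unexpected - keys in the checkpoint that the model has no slot for
    mismatched - (key, checkpoint_shape, model_shape) for shared keys whose shapes differ
    """
    missing = sorted(set(model) - set(ckpt))
    unexpected = sorted(set(ckpt) - set(model))
    mismatched = [
        (k, tuple(ckpt[k]), tuple(model[k]))
        for k in sorted(set(ckpt) & set(model))
        if tuple(ckpt[k]) != tuple(model[k])
    ]
    return missing, unexpected, mismatched


def report(ckpt: Mapping[str, Shape], model: Mapping[str, Shape]) -> str:
    """Human-readable summary of every difference between the two shape maps."""
    missing, unexpected, mismatched = diff_state_dicts(ckpt, model)
    lines = [f"missing:    {k}" for k in missing]
    lines += [f"unexpected: {k}" for k in unexpected]
    lines += [f"mismatch:   {k} checkpoint={a} model={b}" for k, a, b in mismatched]
    return "\n".join(lines) if lines else "state_dicts match"


if __name__ == "__main__":
    # Toy shapes loosely modelled on the error above (hypothetical values).
    ckpt = {
        "conv1.weight": (64, 11, 7, 7),
        "bn1.weight": (64,),
        "attention.in_proj_weight": (1536, 512),
    }
    model = {"conv1.weight": (64, 3, 7, 7), "bn1.weight": (64,)}
    print(report(ckpt, model))

    # Assumed usage inside inference.py, just before the failing call:
    #   print(report(shapes_of(updated_state_dict), shapes_of(model.state_dict())))
```

Comparing shapes rather than tensors keeps the helper independent of PyTorch, so it can be pointed at any state_dict. An empty report means `load_state_dict` should accept the checkpoint.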