Open hcl013 opened 2 years ago
# Remove the spatial-transformer sampling-grid entries from the model weights
# stored under the checkpoint's 'net' key before they are loaded.
# pop(key, None) is used instead of `del` so that a checkpoint saved without
# one of these entries does not raise KeyError.
# NOTE(review): assumes `state_dict` is the raw torch.load() result, i.e. a
# dict shaped like {'net': <model weights>, 'opt': ...} — confirm for your
# checkpoint.
net_weights = state_dict['net']
for grid_key in ('spatial_transform_f.grid', 'spatial_transform.grid'):
    net_weights.pop(grid_key, None)
print("删除成功")
You can try this.
Oh, that solved the problem, but now I get a new error T-T
Loading pre-trained FuseNet checkpoint fus_0280.pth
Traceback (most recent call last):
File "E:/Study/DeepLearn/DPlearn/fusing/UMF-CMGR/Trainer/train_reg_fusion.py", line 218, in <module>
请问您最后是怎么解决的?我也被这个问题困扰。
Oh, that solved the problem, but now I get a new error T-T
Loading pre-trained FuseNet checkpoint fus_0280.pth
Traceback (most recent call last):
  File "E:/Study/DeepLearn/DPlearn/fusing/UMF-CMGR/Trainer/train_reg_fusion.py", line 218, in <module>
    main(args, visdom)
  File "E:/Study/DeepLearn/DPlearn/fusing/UMF-CMGR/Trainer/train_reg_fusion.py", line 115, in main
    FuseNet.load_state_dict(state)
  File "C:\ProgramData\Anaconda3\envs\umi\lib\site-packages\torch\nn\modules\module.py", line 1045, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for FusionNet:
    Missing key(s) in state_dict: "conv1_1.0.weight", "conv1_1.0.bias", "conv2_1.0.weight", "conv2_1.0.bias", "ir_path.0.dense_layers.0.conv.weight", "ir_path.0.dense_layers.1.conv.weight", "ir_path.0.dense_layers.2.conv.weight", "ir_path.0.conv_1x1.weight", "vi_path.0.dense_layers.0.conv.weight", "vi_path.0.dense_layers.1.conv.weight", "vi_path.0.dense_layers.2.conv.weight", "vi_path.0.conv_1x1.weight", "fuse.query_conv.weight", "fuse.query_conv.bias", "fuse.key_conv.weight", "fuse.key_conv.bias", "fuse.gamma1.weight", "fuse.gamma1.bias", "fuse.gamma2.weight", "fuse.gamma2.bias", "fuse_res.weight", "fuse_res.bias", "out_conv.weight", "out_conv.bias".
    Unexpected key(s) in state_dict: "net", "opt".
你好,请问你这个问题解决了吗?我也遇到了这个问题,想请教一下
用下面的代码替换对应的原始代码即可解决:
# print("===> Building model")
# net = DeformableNet().to(device)
#
# print("===> loading trained model '{}'".format(args.ckpt))
# model_state_dict = torch.load(args.ckpt)
# net.load_state_dict(model_state_dict)
#
#
# print("===> Starting Testing")
# test(net, test_data_loader, args.dst, device)
print("===> Building model")
net = DeformableNet().to(device)

print("===> loading trained model '{}'".format(args.ckpt))
# map_location remaps tensors saved on another device (e.g. a GPU checkpoint
# opened on a CPU-only machine) onto the device the model lives on.
checkpoint = torch.load(args.ckpt, map_location=device)

# Unwrap the model weights from their container dict. The checkpoints in this
# project store them under 'net' (next to optimizer state under 'opt');
# 'state_dict' is another common convention. A bare state dict is used as-is.
# Passing the container itself with strict=False would silently load nothing.
model_state_dict = checkpoint
if isinstance(checkpoint, dict):
    for wrapper_key in ('net', 'state_dict'):
        if wrapper_key in checkpoint:
            model_state_dict = checkpoint[wrapper_key]
            break

# Drop both spatial-transformer sampling grids from the checkpoint; the
# freshly built model keeps the grid buffers it created itself.
# NOTE(review): presumably these grids depend on the training image size,
# which is why loading them is avoided — confirm in deformable_net.py.
GRID_KEY_SUFFIXES = ('spatial_transform.grid', 'spatial_transform_f.grid')
model_state_dict = {
    key: value
    for key, value in model_state_dict.items()
    if not key.endswith(GRID_KEY_SUFFIXES)
}

# Call the base nn.Module implementation directly. DeformableNet overrides
# load_state_dict and, per the KeyError traceback from deformable_net.py
# line 74, pops 'spatial_transform.grid' without a default, which fails once
# the grids have been filtered out above.
# NOTE(review): confirm the override does nothing else that is required.
load_result = torch.nn.Module.load_state_dict(net, model_state_dict, strict=False)

# strict=False is only meant to tolerate the grid buffers dropped on purpose.
# Any other mismatch means the checkpoint does not fit this model; testing
# with partially random weights would silently produce meaningless results.
missing = [k for k in load_result.missing_keys if not k.endswith(GRID_KEY_SUFFIXES)]
unexpected = list(load_result.unexpected_keys)
if missing or unexpected:
    raise RuntimeError(
        "Checkpoint '{}' does not match DeformableNet.\n"
        "Missing keys: {}\nUnexpected keys: {}".format(args.ckpt, missing, unexpected)
    )
print("Model weights loaded successfully.")

print("===> Starting Testing")
# Run the test loop with the loaded weights.
test(net, test_data_loader, args.dst, device)
Hi, when I run train_reg_fusion.py, I get the error KeyError: 'spatial_transform.grid'. The full output is as follows:
Loading pre-trained RegNet checkpoint ../reg_0280.pth
Traceback (most recent call last):
File "E:/Study/fusing/UMF-CMGR/Trainer/train_reg_fusion.py", line 218, in <module>
main(args, visdom)
File "E:/Study/fusing/UMF-CMGR/Trainer/train_reg_fusion.py", line 107, in main
RegNet.load_state_dict(state)
File "E:/Study/fusing/UMF-CMGR/models/deformable_net.py", line 74, in load_state_dict
state_dict.pop('spatial_transform.grid')
KeyError: 'spatial_transform.grid'
Can you help me? T-T
My Anaconda environment is as follows: Kornia 0.5.11, PyTorch 1.6.0, CUDA 10.2, opencv-contrib-python 3.4.2.16, visdom 0.1.5, torchvision 0.7.0.