wdhudiekou / UMF-CMGR

[IJCAI2022 Oral] Unsupervised Misaligned Infrared and Visible Image Fusion via Cross-Modality Image Generation and Registration
MIT License
180 stars 18 forks

BUG:KeyError: 'spatial_transform.grid' #19

Open hcl013 opened 2 years ago

hcl013 commented 2 years ago

Hi teacher, when I run train_reg_fusion.py for training, I get an error: KeyError: 'spatial_transform.grid'. The full output is:

    Loading pre-trained RegNet checkpoint ../reg_0280.pth
    Traceback (most recent call last):
      File "E:/Study/fusing/UMF-CMGR/Trainer/train_reg_fusion.py", line 218, in <module>
        main(args, visdom)
      File "E:/Study/fusing/UMF-CMGR/Trainer/train_reg_fusion.py", line 107, in main
        RegNet.load_state_dict(state)
      File "E:/Study/fusing/UMF-CMGR/models/deformable_net.py", line 74, in load_state_dict
        state_dict.pop('spatial_transform.grid')
    KeyError: 'spatial_transform.grid'

Can you help me? T-T

My Anaconda environment is as follows:
- Kornia 0.5.11
- PyTorch 1.6.0
- CUDA 10.2
- opencv-contrib-python 3.4.2.16
- visdom 0.1.5
- torchvision 0.7.0
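The KeyError above is consistent with the checkpoint being a wrapper dict rather than a bare state_dict: the later FuseNet error in this thread lists the unexpected top-level keys "net" and "opt". A minimal torch-free sketch of that situation, assuming the RegNet checkpoint uses the same nesting (the dict contents are made-up placeholders, not real weights):

```python
# Made-up stand-in for the loaded checkpoint: real files hold tensors, but only
# the nesting ({'net': weights, 'opt': optimizer state}) matters for this error.
checkpoint = {
    "net": {
        "conv.weight": [0.1, 0.2],
        "spatial_transform.grid": [0.0],
        "spatial_transform_f.grid": [0.0],
    },
    "opt": {"lr": 1e-4},
}


def pop_grid(state_dict):
    # Same lookup as the failing line in models/deformable_net.py: top level only
    return state_dict.pop("spatial_transform.grid")


try:
    pop_grid(checkpoint)
    raised = False
except KeyError as err:
    raised = True
    print("KeyError:", err)  # prints: KeyError: 'spatial_transform.grid'

# The grid entry is one level down, inside checkpoint["net"]
print("spatial_transform.grid" in checkpoint["net"])  # prints: True
```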

wdhudiekou commented 2 years ago
    # The weights are stored under the checkpoint's 'net' key, so delete the grid entries there
    del state_dict['net']['spatial_transform_f.grid']
    del state_dict['net']['spatial_transform.grid']
    print("Deleted successfully")
    # print(state_dict)
    # state_dict.pop('spatial_transform_f.grid')
    # state_dict.pop('spatial_transform.grid')

You can try it.
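The `del` statements above raise a KeyError of their own if a checkpoint lacks either grid entry. A slightly more defensive variant, sketched as a hypothetical helper `drop_grid_keys` (not part of the repository), uses `dict.pop(key, None)`, which returns None instead of raising when the key is absent, and also accepts a bare state_dict:

```python
GRID_KEYS = ("spatial_transform.grid", "spatial_transform_f.grid")


def drop_grid_keys(checkpoint, keys=GRID_KEYS):
    """Hypothetical helper: return the model weights with the grid entries removed.

    Accepts a wrapped checkpoint ({'net': ..., 'opt': ...}) or a bare state_dict.
    The weights dict is modified in place, like the del statements above.
    """
    weights = checkpoint["net"] if "net" in checkpoint else checkpoint
    for key in keys:
        weights.pop(key, None)  # no KeyError if the entry is missing
    return weights


wrapped = {"net": {"conv.weight": 1, "spatial_transform.grid": 2, "spatial_transform_f.grid": 3}, "opt": {}}
bare = {"conv.weight": 1}
print(drop_grid_keys(wrapped))  # prints: {'conv.weight': 1}
print(drop_grid_keys(bare))     # prints: {'conv.weight': 1}
```

Modifying in place rather than copying keeps any `_metadata` attribute that PyTorch attaches to saved state_dicts.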

hcl013 commented 2 years ago

Oh, that solved the problem, but now I get a new error T-T:

    Loading pre-trained FuseNet checkpoint fus_0280.pth
    Traceback (most recent call last):
      File "E:/Study/DeepLearn/DPlearn/fusing/UMF-CMGR/Trainer/train_reg_fusion.py", line 218, in <module>
        main(args, visdom)
      File "E:/Study/DeepLearn/DPlearn/fusing/UMF-CMGR/Trainer/train_reg_fusion.py", line 115, in main
        FuseNet.load_state_dict(state)
      File "C:\ProgramData\Anaconda3\envs\umi\lib\site-packages\torch\nn\modules\module.py", line 1045, in load_state_dict
        self.__class__.__name__, "\n\t".join(error_msgs)))
    RuntimeError: Error(s) in loading state_dict for FusionNet:
        Missing key(s) in state_dict: "conv1_1.0.weight", "conv1_1.0.bias", "conv2_1.0.weight", "conv2_1.0.bias", "ir_path.0.dense_layers.0.conv.weight", "ir_path.0.dense_layers.1.conv.weight", "ir_path.0.dense_layers.2.conv.weight", "ir_path.0.conv_1x1.weight", "vi_path.0.dense_layers.0.conv.weight", "vi_path.0.dense_layers.1.conv.weight", "vi_path.0.dense_layers.2.conv.weight", "vi_path.0.conv_1x1.weight", "fuse.query_conv.weight", "fuse.query_conv.bias", "fuse.key_conv.weight", "fuse.key_conv.bias", "fuse.gamma1.weight", "fuse.gamma1.bias", "fuse.gamma2.weight", "fuse.gamma2.bias", "fuse_res.weight", "fuse_res.bias", "out_conv.weight", "out_conv.bias".
        Unexpected key(s) in state_dict: "net", "opt".
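The unexpected keys "net" and "opt" indicate that the whole checkpoint dict, rather than the weights inside it, was passed to `FuseNet.load_state_dict`, which is why every model key is reported missing. The sketch below is a simplified, torch-free model of the strict key comparison that `load_state_dict` performs (it ignores shape checks and `_metadata`); `key_mismatch` is a hypothetical name, and the key names are a few taken from the error message:

```python
def key_mismatch(model_keys, state_dict):
    """Simplified model of the strict key check in torch.nn.Module.load_state_dict."""
    model_keys, loaded_keys = set(model_keys), set(state_dict)
    missing = sorted(model_keys - loaded_keys)     # expected by the model, not supplied
    unexpected = sorted(loaded_keys - model_keys)  # supplied, not expected by the model
    return missing, unexpected


model_keys = ["conv1_1.0.weight", "conv1_1.0.bias", "out_conv.weight", "out_conv.bias"]
checkpoint = {"net": {k: 0.0 for k in model_keys}, "opt": {}}

# Passing the wrapper: every model key is missing, 'net' and 'opt' are unexpected
print(key_mismatch(model_keys, checkpoint))
# Passing the inner weights dict: no mismatch
print(key_mismatch(model_keys, checkpoint["net"]))  # prints: ([], [])
```

By the same reasoning, a likely fix at line 115 of train_reg_fusion.py is `FuseNet.load_state_dict(state['net'])`, assuming `state` there is the dict returned by `torch.load` for fus_0280.pth; this has not been checked against the repository.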

123abxc commented 10 months ago

How did you eventually solve this? I'm stuck on the same problem.

kuailexiaohunzi commented 8 months ago


Hi, have you solved this problem? I've run into it too and would appreciate some advice.

yidamyth commented 20 hours ago

Replacing the corresponding original code with the following solves the problem:

# print("===> Building model")
# net = DeformableNet().to(device)
#
# print("===> loading trained model '{}'".format(args.ckpt))
# model_state_dict = torch.load(args.ckpt)
# net.load_state_dict(model_state_dict)
#
#
# print("===> Starting Testing")
# test(net, test_data_loader, args.dst, device)

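print("===> Building model")
net = DeformableNet().to(device)

print("===> loading trained model '{}'".format(args.ckpt))
# Load the checkpoint file, mapping tensors onto the model's device
checkpoint = torch.load(args.ckpt, map_location=device)

# The checkpoints wrap the weights in a dict with 'net' (and 'opt') keys, so unwrap
# them first; otherwise fall back to 'state_dict' or to the raw dict
if 'net' in checkpoint:
    model_state_dict = checkpoint['net']
elif 'state_dict' in checkpoint:
    model_state_dict = checkpoint['state_dict']
else:
    model_state_dict = checkpoint  # use the whole checkpoint as the state_dict

# Remove the sampling-grid entries that cause loading problems
grid_keys = ('spatial_transform.grid', 'spatial_transform_f.grid')
model_state_dict = {k: v for k, v in model_state_dict.items() if k not in grid_keys}

# Call the base nn.Module.load_state_dict directly, bypassing the override in
# models/deformable_net.py (which pops the grid keys itself and fails if they are
# absent). strict=False tolerates the removed grid entries; the returned key lists
# are printed so that a checkpoint which loaded nothing does not go unnoticed.
result = torch.nn.Module.load_state_dict(net, model_state_dict, strict=False)
print("Missing keys:", result.missing_keys)
print("Unexpected keys:", result.unexpected_keys)

print("===> Starting Testing")
# Run the test
test(net, test_data_loader, args.dst, device)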