hongyuanyu / SPAN

Swift Parameter-free Attention Network for Efficient Super-Resolution
Apache License 2.0
135 stars 6 forks

try to run the model with pre-trained file but get error #10

Open tomcan2015 opened 2 months ago

tomcan2015 commented 2 months ago

With the code:

    device = torch.device(args.device)
    net = SPAN(3, 3, upscale=2, feature_channels=48).to(device)
    loaded = torch.load(args.model_file, map_location=device)
    net.load_state_dict(loaded, strict=True)

I got this error when loading the spanx2_ch48.pth file: File "/home/amd/anaconda3/envs/SPAN/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2189, in load_state_dict raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for SPAN: Missing key(s) in state_dict: "conv_1.sk.weight", "conv_1.sk.bias", "conv_1.conv.0.weight", "conv_1.conv.0.bias", "conv_1.conv.1.weight", "conv_1.conv.1.bias", "conv_1.conv.2.weight", "conv_1.conv.2.bias", "conv_1.eval_conv.weight", "conv_1.eval_conv.bias", "block_1.c1_r.sk.weight", "block_1.c1_r.sk.bias", "block_1.c1_r.conv.0.weight", "block_1.c1_r.conv.0.bias", "block_1.c1_r.conv.1.weight", "block_1.c1_r.conv.1.bias", "block_1.c1_r.conv.2.weight", "block_1.c1_r.conv.2.bias", "block_1.c1_r.eval_conv.weight", "block_1.c1_r.eval_conv.bias", "block_1.c2_r.sk.weight", "block_1.c2_r.sk.bias", "block_1.c2_r.conv.0.weight", "block_1.c2_r.conv.0.bias", "block_1.c2_r.conv.1.weight", "block_1.c2_r.conv.1.bias", "block_1.c2_r.conv.2.weight", "block_1.c2_r.conv.2.bias", "block_1.c2_r.eval_conv.weight", "block_1.c2_r.eval_conv.bias", "block_1.c3_r.sk.weight", "block_1.c3_r.sk.bias", "block_1.c3_r.conv.0.weight", "block_1.c3_r.conv.0.bias", "block_1.c3_r.conv.1.weight", "block_1.c3_r.conv.1.bias", "block_1.c3_r.conv.2.weight", "block_1.c3_r.conv.2.bias", "block_1.c3_r.eval_conv.weight", "block_1.c3_r.eval_conv.bias", "block_2.c1_r.sk.weight", "block_2.c1_r.sk.bias", "block_2.c1_r.conv.0.weight", "block_2.c1_r.conv.0.bias", "block_2.c1_r.conv.1.weight", "block_2.c1_r.conv.1.bias", "block_2.c1_r.conv.2.weight", "block_2.c1_r.conv.2.bias", "block_2.c1_r.eval_conv.weight", "block_2.c1_r.eval_conv.bias", "block_2.c2_r.sk.weight", "block_2.c2_r.sk.bias", "block_2.c2_r.conv.0.weight", "block_2.c2_r.conv.0.bias", "block_2.c2_r.conv.1.weight", "block_2.c2_r.conv.1.bias", "block_2.c2_r.conv.2.weight", "block_2.c2_r.conv.2.bias", "block_2.c2_r.eval_conv.weight", 
"block_2.c2_r.eval_conv.bias", "block_2.c3_r.sk.weight", "block_2.c3_r.sk.bias", "block_2.c3_r.conv.0.weight", "block_2.c3_r.conv.0.bias", "block_2.c3_r.conv.1.weight", "block_2.c3_r.conv.1.bias", "block_2.c3_r.conv.2.weight", "block_2.c3_r.conv.2.bias", "block_2.c3_r.eval_conv.weight", "block_2.c3_r.eval_conv.bias", "block_3.c1_r.sk.weight", "block_3.c1_r.sk.bias", "block_3.c1_r.conv.0.weight", "block_3.c1_r.conv.0.bias", "block_3.c1_r.conv.1.weight", "block_3.c1_r.conv.1.bias", "block_3.c1_r.conv.2.weight", "block_3.c1_r.conv.2.bias", "block_3.c1_r.eval_conv.weight", "block_3.c1_r.eval_conv.bias", "block_3.c2_r.sk.weight", "block_3.c2_r.sk.bias", "block_3.c2_r.conv.0.weight", "block_3.c2_r.conv.0.bias", "block_3.c2_r.conv.1.weight", "block_3.c2_r.conv.1.bias", "block_3.c2_r.conv.2.weight", "block_3.c2_r.conv.2.bias", "block_3.c2_r.eval_conv.weight", "block_3.c2_r.eval_conv.bias", "block_3.c3_r.sk.weight", "block_3.c3_r.sk.bias", "block_3.c3_r.conv.0.weight", "block_3.c3_r.conv.0.bias", "block_3.c3_r.conv.1.weight", "block_3.c3_r.conv.1.bias", "block_3.c3_r.conv.2.weight", "block_3.c3_r.conv.2.bias", "block_3.c3_r.eval_conv.weight", "block_3.c3_r.eval_conv.bias", "block_4.c1_r.sk.weight", "block_4.c1_r.sk.bias", "block_4.c1_r.conv.0.weight", "block_4.c1_r.conv.0.bias", "block_4.c1_r.conv.1.weight", "block_4.c1_r.conv.1.bias", "block_4.c1_r.conv.2.weight", "block_4.c1_r.conv.2.bias", "block_4.c1_r.eval_conv.weight", "block_4.c1_r.eval_conv.bias", "block_4.c2_r.sk.weight", "block_4.c2_r.sk.bias", "block_4.c2_r.conv.0.weight", "block_4.c2_r.conv.0.bias", "block_4.c2_r.conv.1.weight", "block_4.c2_r.conv.1.bias", "block_4.c2_r.conv.2.weight", "block_4.c2_r.conv.2.bias", "block_4.c2_r.eval_conv.weight", "block_4.c2_r.eval_conv.bias", "block_4.c3_r.sk.weight", "block_4.c3_r.sk.bias", "block_4.c3_r.conv.0.weight", "block_4.c3_r.conv.0.bias", "block_4.c3_r.conv.1.weight", "block_4.c3_r.conv.1.bias", "block_4.c3_r.conv.2.weight", "block_4.c3_r.conv.2.bias", 
"block_4.c3_r.eval_conv.weight", "block_4.c3_r.eval_conv.bias", "block_5.c1_r.sk.weight", "block_5.c1_r.sk.bias", "block_5.c1_r.conv.0.weight", "block_5.c1_r.conv.0.bias", "block_5.c1_r.conv.1.weight", "block_5.c1_r.conv.1.bias", "block_5.c1_r.conv.2.weight", "block_5.c1_r.conv.2.bias", "block_5.c1_r.eval_conv.weight", "block_5.c1_r.eval_conv.bias", "block_5.c2_r.sk.weight", "block_5.c2_r.sk.bias", "block_5.c2_r.conv.0.weight", "block_5.c2_r.conv.0.bias", "block_5.c2_r.conv.1.weight", "block_5.c2_r.conv.1.bias", "block_5.c2_r.conv.2.weight", "block_5.c2_r.conv.2.bias", "block_5.c2_r.eval_conv.weight", "block_5.c2_r.eval_conv.bias", "block_5.c3_r.sk.weight", "block_5.c3_r.sk.bias", "block_5.c3_r.conv.0.weight", "block_5.c3_r.conv.0.bias", "block_5.c3_r.conv.1.weight", "block_5.c3_r.conv.1.bias", "block_5.c3_r.conv.2.weight", "block_5.c3_r.conv.2.bias", "block_5.c3_r.eval_conv.weight", "block_5.c3_r.eval_conv.bias", "block_6.c1_r.sk.weight", "block_6.c1_r.sk.bias", "block_6.c1_r.conv.0.weight", "block_6.c1_r.conv.0.bias", "block_6.c1_r.conv.1.weight", "block_6.c1_r.conv.1.bias", "block_6.c1_r.conv.2.weight", "block_6.c1_r.conv.2.bias", "block_6.c1_r.eval_conv.weight", "block_6.c1_r.eval_conv.bias", "block_6.c2_r.sk.weight", "block_6.c2_r.sk.bias", "block_6.c2_r.conv.0.weight", "block_6.c2_r.conv.0.bias", "block_6.c2_r.conv.1.weight", "block_6.c2_r.conv.1.bias", "block_6.c2_r.conv.2.weight", "block_6.c2_r.conv.2.bias", "block_6.c2_r.eval_conv.weight", "block_6.c2_r.eval_conv.bias", "block_6.c3_r.sk.weight", "block_6.c3_r.sk.bias", "block_6.c3_r.conv.0.weight", "block_6.c3_r.conv.0.bias", "block_6.c3_r.conv.1.weight", "block_6.c3_r.conv.1.bias", "block_6.c3_r.conv.2.weight", "block_6.c3_r.conv.2.bias", "block_6.c3_r.eval_conv.weight", "block_6.c3_r.eval_conv.bias", "conv_cat.weight", "conv_cat.bias", "conv_2.sk.weight", "conv_2.sk.bias", "conv_2.conv.0.weight", "conv_2.conv.0.bias", "conv_2.conv.1.weight", "conv_2.conv.1.bias", "conv_2.conv.2.weight", 
"conv_2.conv.2.bias", "conv_2.eval_conv.weight", "conv_2.eval_conv.bias", "upsampler.0.weight", "upsampler.0.bias". Unexpected key(s) in state_dict: "params", "params_ema".

Any ideas?

tomcan2015 commented 2 months ago

Using:

    loaded = torch.load(args.model_file)['params']

instead of:

    loaded = torch.load(args.model_file, map_location=device)

fixes the problem. The input/output tensors have a data range of 0~1.
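The error message shows why this works: the checkpoint's top-level keys are "params" and "params_ema" (an EMA copy of the weights), so the weight dict has to be unwrapped before calling load_state_dict. A minimal sketch of that unwrapping logic, independent of torch (the helper name unwrap_state_dict is hypothetical, not from the SPAN repo):

```python
def unwrap_state_dict(checkpoint):
    """Return the actual weight dict from a possibly nested checkpoint.

    Checkpoints like spanx2_ch48.pth wrap the weights under 'params'
    (with an EMA copy under 'params_ema'); a plain checkpoint is
    already the weight dict itself, so it is returned unchanged.
    """
    for key in ("params", "params_ema"):
        if key in checkpoint and isinstance(checkpoint[key], dict):
            return checkpoint[key]
    return checkpoint

# Toy stand-ins for torch.load(...) results:
nested = {
    "params": {"conv_1.sk.weight": [0.1]},
    "params_ema": {"conv_1.sk.weight": [0.2]},
}
flat = {"conv_1.sk.weight": [0.3]}

print(list(unwrap_state_dict(nested)))  # the weight keys, not the wrapper keys
print(list(unwrap_state_dict(flat)))
```

With a real checkpoint the call would be net.load_state_dict(unwrap_state_dict(torch.load(args.model_file, map_location=device)), strict=True). Printing torch.load(...).keys() first is a quick way to tell whether a given .pth file is wrapped this way.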

ReBenDish commented 1 month ago

Hello! How can I get the pre-trained model? I find the Google Drive link is broken.