ming053l / DRCT

Accepted by New Trends in Image Restoration and Enhancement workshop (NTIRE), in conjunction with CVPR 2024.
MIT License

Real-DRCT-GAN_MSE_Model for inference #20

Open zhaoyong-li opened 4 months ago

zhaoyong-li commented 4 months ago

Hi, thank you for providing such great work! I have a question about using Real-DRCT-GAN_MSE_Model for inference: is it not possible to use the inference file directly with this model's weights? When I try, the following error occurs: "RuntimeError: Error(s) in loading state_dict for DRCT". I don't know whether I am doing something wrong and would appreciate your answer!

Skyninth commented 3 months ago

Same problem! I also tried to run inference.py with Real-DRCT-GAN_MSE_Model, but got this RuntimeError:

Traceback (most recent call last):
  File "inference.py", line 111, in <module>
    main()
  File "inference.py", line 37, in main
    model.load_state_dict(torch.load(args.model_path)['params'], strict=True)
  File "/ProjectRoot/drct/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1604, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for DRCT:
Missing key(s) in state_dict: "layers.6.swin1.norm1.weight", "layers.6.swin1.norm1.bias", "layers.6.swin1.attn.relative_position_bias_table", "layers.6.swin1.attn.relative_position_index", "layers.6.swin1.attn.qkv.weight", "layers.6.swin1.attn.qkv.bias", "layers.6.swin1.attn.proj.weight", "layers.6.swin1.attn.proj.bias", "layers.6.swin1.norm2.weight", "layers.6.swin1.norm2.bias", "layers.6.swin1.mlp.fc1.weight", "layers.6.swin1.mlp.fc1.bias", "layers.6.swin1.mlp.fc2.weight", "layers.6.swin1.mlp.fc2.bias", "layers.6.adjust1.weight", "layers.6.adjust1.bias", "layers.6.swin2.attn_mask", "layers.6.swin2.norm1.weight", "layers.6.swin2.norm1.bias", "layers.6.swin2.attn.relative_position_bias_table", "layers.6.swin2.attn.relative_position_index", "layers.6.swin2.attn.qkv.weight", "layers.6.swin2.attn.qkv.bias", "layers.6.swin2.attn.proj.weight", "layers.6.swin2.attn.proj.bias", "layers.6.swin2.norm2.weight", "layers.6.swin2.norm2.bias", "layers.6.swin2.mlp.fc1.weight", "layers.6.swin2.mlp.fc1.bias", "layers.6.swin2.mlp.fc2.weight", "layers.6.swin2.mlp.fc2.bias", "layers.6.adjust2.weight", "layers.6.adjust2.bias", "layers.6.swin3.norm1.weight", "layers.6.swin3.norm1.bias", "layers.6.swin3.attn.relative_position_bias_table", "layers.6.swin3.attn.relative_position_index", "layers.6.swin3.attn.qkv.weight", "layers.6.swin3.attn.qkv.bias", "layers.6.swin3.attn.proj.weight", "layers.6.swin3.attn.proj.bias", "layers.6.swin3.norm2.weight", "layers.6.swin3.norm2.bias", "layers.6.swin3.mlp.fc1.weight", "layers.6.swin3.mlp.fc1.bias", 
"layers.6.swin3.mlp.fc2.weight", "layers.6.swin3.mlp.fc2.bias", "layers.6.adjust3.weight", "layers.6.adjust3.bias", "layers.6.swin4.attn_mask", "layers.6.swin4.norm1.weight", "layers.6.swin4.norm1.bias", "layers.6.swin4.attn.relative_position_bias_table", "layers.6.swin4.attn.relative_position_index", "layers.6.swin4.attn.qkv.weight", "layers.6.swin4.attn.qkv.bias", "layers.6.swin4.attn.proj.weight", "layers.6.swin4.attn.proj.bias", "layers.6.swin4.norm2.weight", "layers.6.swin4.norm2.bias", "layers.6.swin4.mlp.fc1.weight", "layers.6.swin4.mlp.fc1.bias", "layers.6.swin4.mlp.fc2.weight", "layers.6.swin4.mlp.fc2.bias", "layers.6.adjust4.weight", "layers.6.adjust4.bias", "layers.6.swin5.norm1.weight", "layers.6.swin5.norm1.bias", "layers.6.swin5.attn.relative_position_bias_table", "layers.6.swin5.attn.relative_position_index", "layers.6.swin5.attn.qkv.weight", "layers.6.swin5.attn.qkv.bias", "layers.6.swin5.attn.proj.weight", "layers.6.swin5.attn.proj.bias", "layers.6.swin5.norm2.weight", "layers.6.swin5.norm2.bias", "layers.6.swin5.mlp.fc1.weight", "layers.6.swin5.mlp.fc1.bias", "layers.6.swin5.mlp.fc2.weight", "layers.6.swin5.mlp.fc2.bias", "layers.6.adjust5.weight", "layers.6.adjust5.bias", "layers.7.swin1.norm1.weight", "layers.7.swin1.norm1.bias", "layers.7.swin1.attn.relative_position_bias_table", "layers.7.swin1.attn.relative_position_index", "layers.7.swin1.attn.qkv.weight", "layers.7.swin1.attn.qkv.bias", "layers.7.swin1.attn.proj.weight", "layers.7.swin1.attn.proj.bias", "layers.7.swin1.norm2.weight", "layers.7.swin1.norm2.bias", "layers.7.swin1.mlp.fc1.weight", "layers.7.swin1.mlp.fc1.bias", "layers.7.swin1.mlp.fc2.weight", "layers.7.swin1.mlp.fc2.bias", "layers.7.adjust1.weight", "layers.7.adjust1.bias", "layers.7.swin2.attn_mask", "layers.7.swin2.norm1.weight", "layers.7.swin2.norm1.bias", "layers.7.swin2.attn.relative_position_bias_table", "layers.7.swin2.attn.relative_position_index", "layers.7.swin2.attn.qkv.weight", "layers.7.swin2.attn.qkv.bias", 
"layers.7.swin2.attn.proj.weight", "layers.7.swin2.attn.proj.bias", "layers.7.swin2.norm2.weight", "layers.7.swin2.norm2.bias", "layers.7.swin2.mlp.fc1.weight", "layers.7.swin2.mlp.fc1.bias", "layers.7.swin2.mlp.fc2.weight", "layers.7.swin2.mlp.fc2.bias", "layers.7.adjust2.weight", "layers.7.adjust2.bias", "layers.7.swin3.norm1.weight", "layers.7.swin3.norm1.bias", "layers.7.swin3.attn.relative_position_bias_table", "layers.7.swin3.attn.relative_position_index", "layers.7.swin3.attn.qkv.weight", "layers.7.swin3.attn.qkv.bias", "layers.7.swin3.attn.proj.weight", "layers.7.swin3.attn.proj.bias", "layers.7.swin3.norm2.weight", "layers.7.swin3.norm2.bias", "layers.7.swin3.mlp.fc1.weight", "layers.7.swin3.mlp.fc1.bias", "layers.7.swin3.mlp.fc2.weight", "layers.7.swin3.mlp.fc2.bias", "layers.7.adjust3.weight", "layers.7.adjust3.bias", "layers.7.swin4.attn_mask", "layers.7.swin4.norm1.weight", "layers.7.swin4.norm1.bias", "layers.7.swin4.attn.relative_position_bias_table", "layers.7.swin4.attn.relative_position_index", "layers.7.swin4.attn.qkv.weight", "layers.7.swin4.attn.qkv.bias", "layers.7.swin4.attn.proj.weight", "layers.7.swin4.attn.proj.bias", "layers.7.swin4.norm2.weight", "layers.7.swin4.norm2.bias", "layers.7.swin4.mlp.fc1.weight", "layers.7.swin4.mlp.fc1.bias", "layers.7.swin4.mlp.fc2.weight", "layers.7.swin4.mlp.fc2.bias", "layers.7.adjust4.weight", "layers.7.adjust4.bias", "layers.7.swin5.norm1.weight", "layers.7.swin5.norm1.bias", "layers.7.swin5.attn.relative_position_bias_table", "layers.7.swin5.attn.relative_position_index", "layers.7.swin5.attn.qkv.weight", "layers.7.swin5.attn.qkv.bias", "layers.7.swin5.attn.proj.weight", "layers.7.swin5.attn.proj.bias", "layers.7.swin5.norm2.weight", "layers.7.swin5.norm2.bias", "layers.7.swin5.mlp.fc1.weight", "layers.7.swin5.mlp.fc1.bias", "layers.7.swin5.mlp.fc2.weight", "layers.7.swin5.mlp.fc2.bias", "layers.7.adjust5.weight", "layers.7.adjust5.bias", "layers.8.swin1.norm1.weight", "layers.8.swin1.norm1.bias", 
"layers.8.swin1.attn.relative_position_bias_table", "layers.8.swin1.attn.relative_position_index", "layers.8.swin1.attn.qkv.weight", "layers.8.swin1.attn.qkv.bias", "layers.8.swin1.attn.proj.weight", "layers.8.swin1.attn.proj.bias", "layers.8.swin1.norm2.weight", "layers.8.swin1.norm2.bias", "layers.8.swin1.mlp.fc1.weight", "layers.8.swin1.mlp.fc1.bias", "layers.8.swin1.mlp.fc2.weight", "layers.8.swin1.mlp.fc2.bias", "layers.8.adjust1.weight", "layers.8.adjust1.bias", "layers.8.swin2.attn_mask", "layers.8.swin2.norm1.weight", "layers.8.swin2.norm1.bias", "layers.8.swin2.attn.relative_position_bias_table", "layers.8.swin2.attn.relative_position_index", "layers.8.swin2.attn.qkv.weight", "layers.8.swin2.attn.qkv.bias", "layers.8.swin2.attn.proj.weight", "layers.8.swin2.attn.proj.bias", "layers.8.swin2.norm2.weight", "layers.8.swin2.norm2.bias", "layers.8.swin2.mlp.fc1.weight", "layers.8.swin2.mlp.fc1.bias", "layers.8.swin2.mlp.fc2.weight", "layers.8.swin2.mlp.fc2.bias", "layers.8.adjust2.weight", "layers.8.adjust2.bias", "layers.8.swin3.norm1.weight", "layers.8.swin3.norm1.bias", "layers.8.swin3.attn.relative_position_bias_table", "layers.8.swin3.attn.relative_position_index", "layers.8.swin3.attn.qkv.weight", "layers.8.swin3.attn.qkv.bias", "layers.8.swin3.attn.proj.weight", "layers.8.swin3.attn.proj.bias", "layers.8.swin3.norm2.weight", "layers.8.swin3.norm2.bias", "layers.8.swin3.mlp.fc1.weight", "layers.8.swin3.mlp.fc1.bias", "layers.8.swin3.mlp.fc2.weight", "layers.8.swin3.mlp.fc2.bias", "layers.8.adjust3.weight", "layers.8.adjust3.bias", "layers.8.swin4.attn_mask", "layers.8.swin4.norm1.weight", "layers.8.swin4.norm1.bias", "layers.8.swin4.attn.relative_position_bias_table", "layers.8.swin4.attn.relative_position_index", "layers.8.swin4.attn.qkv.weight", "layers.8.swin4.attn.qkv.bias", "layers.8.swin4.attn.proj.weight", "layers.8.swin4.attn.proj.bias", "layers.8.swin4.norm2.weight", "layers.8.swin4.norm2.bias", "layers.8.swin4.mlp.fc1.weight", 
"layers.8.swin4.mlp.fc1.bias", "layers.8.swin4.mlp.fc2.weight", "layers.8.swin4.mlp.fc2.bias", "layers.8.adjust4.weight", "layers.8.adjust4.bias", "layers.8.swin5.norm1.weight", "layers.8.swin5.norm1.bias", "layers.8.swin5.attn.relative_position_bias_table", "layers.8.swin5.attn.relative_position_index", "layers.8.swin5.attn.qkv.weight", "layers.8.swin5.attn.qkv.bias", "layers.8.swin5.attn.proj.weight", "layers.8.swin5.attn.proj.bias", "layers.8.swin5.norm2.weight", "layers.8.swin5.norm2.bias", "layers.8.swin5.mlp.fc1.weight", "layers.8.swin5.mlp.fc1.bias", "layers.8.swin5.mlp.fc2.weight", "layers.8.swin5.mlp.fc2.bias", "layers.8.adjust5.weight", "layers.8.adjust5.bias", "layers.9.swin1.norm1.weight", "layers.9.swin1.norm1.bias", "layers.9.swin1.attn.relative_position_bias_table", "layers.9.swin1.attn.relative_position_index", "layers.9.swin1.attn.qkv.weight", "layers.9.swin1.attn.qkv.bias", "layers.9.swin1.attn.proj.weight", "layers.9.swin1.attn.proj.bias", "layers.9.swin1.norm2.weight", "layers.9.swin1.norm2.bias", "layers.9.swin1.mlp.fc1.weight", "layers.9.swin1.mlp.fc1.bias", "layers.9.swin1.mlp.fc2.weight", "layers.9.swin1.mlp.fc2.bias", "layers.9.adjust1.weight", "layers.9.adjust1.bias", "layers.9.swin2.attn_mask", "layers.9.swin2.norm1.weight", "layers.9.swin2.norm1.bias", "layers.9.swin2.attn.relative_position_bias_table", "layers.9.swin2.attn.relative_position_index", "layers.9.swin2.attn.qkv.weight", "layers.9.swin2.attn.qkv.bias", "layers.9.swin2.attn.proj.weight", "layers.9.swin2.attn.proj.bias", "layers.9.swin2.norm2.weight", "layers.9.swin2.norm2.bias", "layers.9.swin2.mlp.fc1.weight", "layers.9.swin2.mlp.fc1.bias", "layers.9.swin2.mlp.fc2.weight", "layers.9.swin2.mlp.fc2.bias", "layers.9.adjust2.weight", "layers.9.adjust2.bias", "layers.9.swin3.norm1.weight", "layers.9.swin3.norm1.bias", "layers.9.swin3.attn.relative_position_bias_table", "layers.9.swin3.attn.relative_position_index", "layers.9.swin3.attn.qkv.weight", "layers.9.swin3.attn.qkv.bias", 
"layers.9.swin3.attn.proj.weight", "layers.9.swin3.attn.proj.bias", "layers.9.swin3.norm2.weight", "layers.9.swin3.norm2.bias", "layers.9.swin3.mlp.fc1.weight", "layers.9.swin3.mlp.fc1.bias", "layers.9.swin3.mlp.fc2.weight", "layers.9.swin3.mlp.fc2.bias", "layers.9.adjust3.weight", "layers.9.adjust3.bias", "layers.9.swin4.attn_mask", "layers.9.swin4.norm1.weight", "layers.9.swin4.norm1.bias", "layers.9.swin4.attn.relative_position_bias_table", "layers.9.swin4.attn.relative_position_index", "layers.9.swin4.attn.qkv.weight", "layers.9.swin4.attn.qkv.bias", "layers.9.swin4.attn.proj.weight", "layers.9.swin4.attn.proj.bias", "layers.9.swin4.norm2.weight", "layers.9.swin4.norm2.bias", "layers.9.swin4.mlp.fc1.weight", "layers.9.swin4.mlp.fc1.bias", "layers.9.swin4.mlp.fc2.weight", "layers.9.swin4.mlp.fc2.bias", "layers.9.adjust4.weight", "layers.9.adjust4.bias", "layers.9.swin5.norm1.weight", "layers.9.swin5.norm1.bias", "layers.9.swin5.attn.relative_position_bias_table", "layers.9.swin5.attn.relative_position_index", "layers.9.swin5.attn.qkv.weight", "layers.9.swin5.attn.qkv.bias", "layers.9.swin5.attn.proj.weight", "layers.9.swin5.attn.proj.bias", "layers.9.swin5.norm2.weight", "layers.9.swin5.norm2.bias", "layers.9.swin5.mlp.fc1.weight", "layers.9.swin5.mlp.fc1.bias", "layers.9.swin5.mlp.fc2.weight", "layers.9.swin5.mlp.fc2.bias", "layers.9.adjust5.weight", "layers.9.adjust5.bias", "layers.10.swin1.norm1.weight", "layers.10.swin1.norm1.bias", "layers.10.swin1.attn.relative_position_bias_table", "layers.10.swin1.attn.relative_position_index", "layers.10.swin1.attn.qkv.weight", "layers.10.swin1.attn.qkv.bias", "layers.10.swin1.attn.proj.weight", "layers.10.swin1.attn.proj.bias", "layers.10.swin1.norm2.weight", "layers.10.swin1.norm2.bias", "layers.10.swin1.mlp.fc1.weight", "layers.10.swin1.mlp.fc1.bias", "layers.10.swin1.mlp.fc2.weight", "layers.10.swin1.mlp.fc2.bias", "layers.10.adjust1.weight", "layers.10.adjust1.bias", "layers.10.swin2.attn_mask", 
"layers.10.swin2.norm1.weight", "layers.10.swin2.norm1.bias", "layers.10.swin2.attn.relative_position_bias_table", "layers.10.swin2.attn.relative_position_index", "layers.10.swin2.attn.qkv.weight", "layers.10.swin2.attn.qkv.bias", "layers.10.swin2.attn.proj.weight", "layers.10.swin2.attn.proj.bias", "layers.10.swin2.norm2.weight", "layers.10.swin2.norm2.bias", "layers.10.swin2.mlp.fc1.weight", "layers.10.swin2.mlp.fc1.bias", "layers.10.swin2.mlp.fc2.weight", "layers.10.swin2.mlp.fc2.bias", "layers.10.adjust2.weight", "layers.10.adjust2.bias", "layers.10.swin3.norm1.weight", "layers.10.swin3.norm1.bias", "layers.10.swin3.attn.relative_position_bias_table", "layers.10.swin3.attn.relative_position_index", "layers.10.swin3.attn.qkv.weight", "layers.10.swin3.attn.qkv.bias", "layers.10.swin3.attn.proj.weight", "layers.10.swin3.attn.proj.bias", "layers.10.swin3.norm2.weight", "layers.10.swin3.norm2.bias", "layers.10.swin3.mlp.fc1.weight", "layers.10.swin3.mlp.fc1.bias", "layers.10.swin3.mlp.fc2.weight", "layers.10.swin3.mlp.fc2.bias", "layers.10.adjust3.weight", "layers.10.adjust3.bias", "layers.10.swin4.attn_mask", "layers.10.swin4.norm1.weight", "layers.10.swin4.norm1.bias", "layers.10.swin4.attn.relative_position_bias_table", "layers.10.swin4.attn.relative_position_index", "layers.10.swin4.attn.qkv.weight", "layers.10.swin4.attn.qkv.bias", "layers.10.swin4.attn.proj.weight", "layers.10.swin4.attn.proj.bias", "layers.10.swin4.norm2.weight", "layers.10.swin4.norm2.bias", "layers.10.swin4.mlp.fc1.weight", "layers.10.swin4.mlp.fc1.bias", "layers.10.swin4.mlp.fc2.weight", "layers.10.swin4.mlp.fc2.bias", "layers.10.adjust4.weight", "layers.10.adjust4.bias", "layers.10.swin5.norm1.weight", "layers.10.swin5.norm1.bias", "layers.10.swin5.attn.relative_position_bias_table", "layers.10.swin5.attn.relative_position_index", "layers.10.swin5.attn.qkv.weight", "layers.10.swin5.attn.qkv.bias", "layers.10.swin5.attn.proj.weight", "layers.10.swin5.attn.proj.bias", 
"layers.10.swin5.norm2.weight", "layers.10.swin5.norm2.bias", "layers.10.swin5.mlp.fc1.weight", "layers.10.swin5.mlp.fc1.bias", "layers.10.swin5.mlp.fc2.weight", "layers.10.swin5.mlp.fc2.bias", "layers.10.adjust5.weight", "layers.10.adjust5.bias", "layers.11.swin1.norm1.weight", "layers.11.swin1.norm1.bias", "layers.11.swin1.attn.relative_position_bias_table", "layers.11.swin1.attn.relative_position_index", "layers.11.swin1.attn.qkv.weight", "layers.11.swin1.attn.qkv.bias", "layers.11.swin1.attn.proj.weight", "layers.11.swin1.attn.proj.bias", "layers.11.swin1.norm2.weight", "layers.11.swin1.norm2.bias", "layers.11.swin1.mlp.fc1.weight", "layers.11.swin1.mlp.fc1.bias", "layers.11.swin1.mlp.fc2.weight", "layers.11.swin1.mlp.fc2.bias", "layers.11.adjust1.weight", "layers.11.adjust1.bias", "layers.11.swin2.attn_mask", "layers.11.swin2.norm1.weight", "layers.11.swin2.norm1.bias", "layers.11.swin2.attn.relative_position_bias_table", "layers.11.swin2.attn.relative_position_index", "layers.11.swin2.attn.qkv.weight", "layers.11.swin2.attn.qkv.bias", "layers.11.swin2.attn.proj.weight", "layers.11.swin2.attn.proj.bias", "layers.11.swin2.norm2.weight", "layers.11.swin2.norm2.bias", "layers.11.swin2.mlp.fc1.weight", "layers.11.swin2.mlp.fc1.bias", "layers.11.swin2.mlp.fc2.weight", "layers.11.swin2.mlp.fc2.bias", "layers.11.adjust2.weight", "layers.11.adjust2.bias", "layers.11.swin3.norm1.weight", "layers.11.swin3.norm1.bias", "layers.11.swin3.attn.relative_position_bias_table", "layers.11.swin3.attn.relative_position_index", "layers.11.swin3.attn.qkv.weight", "layers.11.swin3.attn.qkv.bias", "layers.11.swin3.attn.proj.weight", "layers.11.swin3.attn.proj.bias", "layers.11.swin3.norm2.weight", "layers.11.swin3.norm2.bias", "layers.11.swin3.mlp.fc1.weight", "layers.11.swin3.mlp.fc1.bias", "layers.11.swin3.mlp.fc2.weight", "layers.11.swin3.mlp.fc2.bias", "layers.11.adjust3.weight", "layers.11.adjust3.bias", "layers.11.swin4.attn_mask", "layers.11.swin4.norm1.weight", 
"layers.11.swin4.norm1.bias", "layers.11.swin4.attn.relative_position_bias_table", "layers.11.swin4.attn.relative_position_index", "layers.11.swin4.attn.qkv.weight", "layers.11.swin4.attn.qkv.bias", "layers.11.swin4.attn.proj.weight", "layers.11.swin4.attn.proj.bias", "layers.11.swin4.norm2.weight", "layers.11.swin4.norm2.bias", "layers.11.swin4.mlp.fc1.weight", "layers.11.swin4.mlp.fc1.bias", "layers.11.swin4.mlp.fc2.weight", "layers.11.swin4.mlp.fc2.bias", "layers.11.adjust4.weight", "layers.11.adjust4.bias", "layers.11.swin5.norm1.weight", "layers.11.swin5.norm1.bias", "layers.11.swin5.attn.relative_position_bias_table", "layers.11.swin5.attn.relative_position_index", "layers.11.swin5.attn.qkv.weight", "layers.11.swin5.attn.qkv.bias", "layers.11.swin5.attn.proj.weight", "layers.11.swin5.attn.proj.bias", "layers.11.swin5.norm2.weight", "layers.11.swin5.norm2.bias", "layers.11.swin5.mlp.fc1.weight", "layers.11.swin5.mlp.fc1.bias", "layers.11.swin5.mlp.fc2.weight", "layers.11.swin5.mlp.fc2.bias", "layers.11.adjust5.weight", "layers.11.adjust5.bias".

Please tell me how I can solve this issue.
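Every missing key in this traceback lives under `layers.6` to `layers.11`, and no unexpected keys are reported. That pattern suggests the checkpoint holds only residual groups 0 to 5, while the model built in inference.py has twelve. A minimal sketch for confirming this on any checkpoint, assuming only that parameter names follow the `layers.<index>.` pattern visible above; `group_indices` and `compare_groups` are hypothetical helpers, not part of the DRCT repo:

```python
import re
from typing import Iterable, Set

# Matches names such as "layers.6.swin1.norm1.weight" and captures the group index (6).
_GROUP_RE = re.compile(r"^layers\.(\d+)\.")


def group_indices(keys: Iterable[str]) -> Set[int]:
    """Return the set of residual-group indices that appear in some state_dict keys."""
    found = set()
    for key in keys:
        match = _GROUP_RE.match(key)
        if match:
            found.add(int(match.group(1)))
    return found


def compare_groups(model_keys: Iterable[str], ckpt_keys: Iterable[str]) -> dict:
    """Report which groups the model expects but the checkpoint lacks, and vice versa."""
    model_groups = group_indices(model_keys)
    ckpt_groups = group_indices(ckpt_keys)
    return {
        "model_groups": len(model_groups),
        "checkpoint_groups": len(ckpt_groups),
        "missing_in_checkpoint": sorted(model_groups - ckpt_groups),
        "unexpected_in_checkpoint": sorted(ckpt_groups - model_groups),
    }


if __name__ == "__main__":
    # Toy key lists mimicking the traceback: the model has 12 groups, the checkpoint only 6.
    model_keys = [f"layers.{i}.swin1.norm1.weight" for i in range(12)]
    ckpt_keys = [f"layers.{i}.swin1.norm1.weight" for i in range(6)]
    print(compare_groups(model_keys, ckpt_keys))
```

With real files, `model_keys` would be `model.state_dict().keys()` for the model constructed in inference.py, and `ckpt_keys` the keys of `torch.load(args.model_path)['params']` (the expression from line 37 in the traceback). A nonempty `missing_in_checkpoint` list points at an architecture mismatch rather than a corrupted file.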

gtentillier commented 3 months ago

Hello guys, I had the same issue with a similar model and solved it by checking that the weights I was trying to load correspond to the architecture of the model defined in inference.py. This error means that the layers expected by the model initialized in inference.py differ from the layers stored in the weights. Judging from the default model_path argument in inference.py, weights like DRCT-L.pth are expected; which ones are you trying to use? I would try these: https://drive.google.com/file/d/1bVxvA6QFbne2se0CQJ-jyHFy94UOi3h5/view If they don't work, we either need to find other weights that match the architecture, or change the architecture in the model initialization in inference.py. Let me know if it works :)
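For the second option (adapting the model to the weights), the relevant sizes can be read from the checkpoint itself. A minimal sketch, assuming Swin-style attention where each `relative_position_bias_table` has shape `((2*window_size - 1)**2, num_heads)`, and assuming the first convolution is stored as `conv_first.weight` with shape `(embed_dim, in_chans, 3, 3)` as in SwinIR-style models; both layouts are assumptions, and `infer_arch` is a hypothetical helper:

```python
import math
import re
from typing import Dict, Tuple

# Matches the first Swin block's bias table in each group, e.g.
# "layers.3.swin1.attn.relative_position_bias_table", capturing the group index.
_TABLE_RE = re.compile(r"^layers\.(\d+)\.swin1\.attn\.relative_position_bias_table$")


def infer_arch(shapes: Dict[str, Tuple[int, ...]]) -> dict:
    """Infer group count, heads per group, window size and embed dim from parameter shapes.

    `shapes` maps state_dict keys to tensor shapes, e.g.
    {k: tuple(v.shape) for k, v in state_dict.items()}.
    Assumes group indices are contiguous from 0.
    """
    heads = {}
    window_size = None
    for key, shape in shapes.items():
        match = _TABLE_RE.match(key)
        if not match:
            continue
        table_len, num_heads = shape
        heads[int(match.group(1))] = num_heads
        # Assumed Swin layout: table_len == (2 * window_size - 1) ** 2
        side = math.isqrt(table_len)
        window_size = (side + 1) // 2
    num_groups = len(heads)
    return {
        "num_groups": num_groups,
        "num_heads": [heads[i] for i in range(num_groups)],
        "window_size": window_size,
        # Assumed: conv_first maps in_chans -> embed_dim, so dim 0 is embed_dim.
        "embed_dim": shapes["conv_first.weight"][0],
    }
```

Assuming the model constructor in inference.py takes SwinIR-style `depths` and `num_heads` lists, `num_groups` would set the length of both lists (for example `[6] * arch["num_groups"]`). This sketch cannot tell whether other constructor arguments also differ for a given checkpoint, so strict loading should still be used to confirm the match.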

happy-liuzhixuan commented 3 months ago

Hello, how can I solve the above problem? I have the same problem with net_g_latest.pth and DRCT-L_X4.pth, while net_g_latest (MSEModel).pth tests normally.
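When several checkpoint files behave differently, it can also help to check how each one is packaged. inference.py indexes `['params']` directly (line 37 in the traceback above). BasicSR-style training saves generator weights under `params` and, when EMA is enabled, also under `params_ema`; whether every DRCT release file follows that convention is an assumption. A minimal sketch of a defensive selector, with the hypothetical name `select_state_dict`:

```python
from typing import Any, Mapping, Optional

# Wrapper keys a BasicSR-style checkpoint may use for generator weights,
# in an assumed order of preference (EMA weights first).
_PREFERRED_KEYS = ("params_ema", "params")


def select_state_dict(checkpoint: Mapping[str, Any], key: Optional[str] = None) -> Mapping[str, Any]:
    """Return the parameter dict inside a loaded checkpoint.

    If `key` is given it is used as-is; otherwise 'params_ema' is preferred over
    'params', and a checkpoint with neither wrapper is treated as a bare state_dict.
    """
    if key is not None:
        return checkpoint[key]
    for candidate in _PREFERRED_KEYS:
        if candidate in checkpoint:
            return checkpoint[candidate]
    return checkpoint
```

Typical use would be `state = select_state_dict(torch.load(args.model_path, map_location="cpu"))` followed by `model.load_state_dict(state, strict=True)`. This only removes packaging differences: a checkpoint trained with a different number of residual groups will still fail strict loading, and `strict=False` would silently leave those groups randomly initialized.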