cszn / SCUNet

Practical Blind Denoising via Swin-Conv-UNet and Data Synthesis (Machine Intelligence Research 2023)
https://link.springer.com/article/10.1007/s11633-023-1466-0
Apache License 2.0

Error in converting to ONNX model #25

Open ayazhassan opened 8 months ago

ayazhassan commented 8 months ago

I am getting the following error while trying to convert the pre-trained model to ONNX. Can you please look into it and let me know whether the pre-trained weights were generated with the current, updated model definition? The conversion code is provided after the error.

Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Traceback (most recent call last):
  File "/home/ayaz_khan/SCUNet/onnx.py", line 2, in <module>
    import torch.onnx
  File "/home/ayaz_khan/.local/lib/python3.9/site-packages/torch/onnx/__init__.py", line 57, in <module>
    from ._internal.onnxruntime import (
  File "/home/ayaz_khan/.local/lib/python3.9/site-packages/torch/onnx/_internal/onnxruntime.py", line 34, in <module>
    import onnx
  File "/home/ayaz_khan/SCUNet/onnx.py", line 25, in <module>
    convert_to_onnx(model_path, onnx_path)
  File "/home/ayaz_khan/SCUNet/onnx.py", line 8, in convert_to_onnx
    model.load_state_dict(torch.load(model_path))
  File "/home/ayaz_khan/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for SCUNet:
    Missing key(s) in state_dict: "m_down1.2.weight", "m_down2.2.weight", "m_down3.2.weight".
    Unexpected key(s) in state_dict: "m_down1.3.trans_block.ln1.weight", "m_down1.3.trans_block.ln1.bias", "m_down1.3.trans_block.msa.relative_position_params", "m_down1.3.trans_block.msa.embedding_layer.weight", "m_down1.3.trans_block.msa.embedding_layer.bias", "m_down1.3.trans_block.msa.linear.weight", "m_down1.3.trans_block.msa.linear.bias", "m_down1.3.trans_block.ln2.weight", "m_down1.3.trans_block.ln2.bias", "m_down1.3.trans_block.mlp.0.weight", "m_down1.3.trans_block.mlp.0.bias", "m_down1.3.trans_block.mlp.2.weight", "m_down1.3.trans_block.mlp.2.bias", "m_down1.3.conv1_1.weight", "m_down1.3.conv1_1.bias", "m_down1.3.conv1_2.weight", "m_down1.3.conv1_2.bias", "m_down1.3.conv_block.0.weight", "m_down1.3.conv_block.2.weight", "m_down1.4.weight", "m_down1.2.trans_block.ln1.weight", "m_down1.2.trans_block.ln1.bias", "m_down1.2.trans_block.msa.relative_position_params", "m_down1.2.trans_block.msa.embedding_layer.weight", "m_down1.2.trans_block.msa.embedding_layer.bias", "m_down1.2.trans_block.msa.linear.weight", "m_down1.2.trans_block.msa.linear.bias", "m_down1.2.trans_block.ln2.weight", "m_down1.2.trans_block.ln2.bias", "m_down1.2.trans_block.mlp.0.weight", "m_down1.2.trans_block.mlp.0.bias", "m_down1.2.trans_block.mlp.2.weight", "m_down1.2.trans_block.mlp.2.bias", "m_down1.2.conv1_1.weight", "m_down1.2.conv1_1.bias", "m_down1.2.conv1_2.weight", "m_down1.2.conv1_2.bias", "m_down1.2.conv_block.0.weight", "m_down1.2.conv_block.2.weight", "m_down2.3.trans_block.ln1.weight", "m_down2.3.trans_block.ln1.bias", "m_down2.3.trans_block.msa.relative_position_params", "m_down2.3.trans_block.msa.embedding_layer.weight", "m_down2.3.trans_block.msa.embedding_layer.bias", "m_down2.3.trans_block.msa.linear.weight", "m_down2.3.trans_block.msa.linear.bias", "m_down2.3.trans_block.ln2.weight", "m_down2.3.trans_block.ln2.bias", "m_down2.3.trans_block.mlp.0.weight", "m_down2.3.trans_block.mlp.0.bias", "m_down2.3.trans_block.mlp.2.weight", "m_down2.3.trans_block.mlp.2.bias",
"m_down2.3.conv1_1.weight", "m_down2.3.conv1_1.bias", "m_down2.3.conv1_2.weight", "m_down2.3.conv1_2.bias", "m_down2.3.conv_block.0.weight", "m_down2.3.conv_block.2.weight", "m_down2.4.weight", "m_down2.2.trans_block.ln1.weight", "m_down2.2.trans_block.ln1.bias", "m_down2.2.trans_block.msa.relative_position_params", "m_down2.2.trans_block.msa.embedding_layer.weight", "m_down2.2.trans_block.msa.embedding_layer.bias", "m_down2.2.trans_block.msa.linear.weight", "m_down2.2.trans_block.msa.linear.bias", "m_down2.2.trans_block.ln2.weight", "m_down2.2.trans_block.ln2.bias", "m_down2.2.trans_block.mlp.0.weight", "m_down2.2.trans_block.mlp.0.bias", "m_down2.2.trans_block.mlp.2.weight", "m_down2.2.trans_block.mlp.2.bias", "m_down2.2.conv1_1.weight", "m_down2.2.conv1_1.bias", "m_down2.2.conv1_2.weight", "m_down2.2.conv1_2.bias", "m_down2.2.conv_block.0.weight", "m_down2.2.conv_block.2.weight", "m_down3.3.trans_block.ln1.weight", "m_down3.3.trans_block.ln1.bias", "m_down3.3.trans_block.msa.relative_position_params", "m_down3.3.trans_block.msa.embedding_layer.weight", "m_down3.3.trans_block.msa.embedding_layer.bias", "m_down3.3.trans_block.msa.linear.weight", "m_down3.3.trans_block.msa.linear.bias", "m_down3.3.trans_block.ln2.weight", "m_down3.3.trans_block.ln2.bias", "m_down3.3.trans_block.mlp.0.weight", "m_down3.3.trans_block.mlp.0.bias", "m_down3.3.trans_block.mlp.2.weight", "m_down3.3.trans_block.mlp.2.bias", "m_down3.3.conv1_1.weight", "m_down3.3.conv1_1.bias", "m_down3.3.conv1_2.weight", "m_down3.3.conv1_2.bias", "m_down3.3.conv_block.0.weight", "m_down3.3.conv_block.2.weight", "m_down3.4.weight", "m_down3.2.trans_block.ln1.weight", "m_down3.2.trans_block.ln1.bias", "m_down3.2.trans_block.msa.relative_position_params", "m_down3.2.trans_block.msa.embedding_layer.weight", "m_down3.2.trans_block.msa.embedding_layer.bias", "m_down3.2.trans_block.msa.linear.weight", "m_down3.2.trans_block.msa.linear.bias", "m_down3.2.trans_block.ln2.weight", "m_down3.2.trans_block.ln2.bias", "m_down3.2.trans_block.mlp.0.weight", "m_down3.2.trans_block.mlp.0.bias", "m_down3.2.trans_block.mlp.2.weight", "m_down3.2.trans_block.mlp.2.bias", "m_down3.2.conv1_1.weight", "m_down3.2.conv1_1.bias", "m_down3.2.conv1_2.weight", "m_down3.2.conv1_2.bias", "m_down3.2.conv_block.0.weight", "m_down3.2.conv_block.2.weight", "m_body.2.trans_block.ln1.weight", "m_body.2.trans_block.ln1.bias", "m_body.2.trans_block.msa.relative_position_params", "m_body.2.trans_block.msa.embedding_layer.weight", "m_body.2.trans_block.msa.embedding_layer.bias", "m_body.2.trans_block.msa.linear.weight", "m_body.2.trans_block.msa.linear.bias", "m_body.2.trans_block.ln2.weight", "m_body.2.trans_block.ln2.bias", "m_body.2.trans_block.mlp.0.weight", "m_body.2.trans_block.mlp.0.bias", "m_body.2.trans_block.mlp.2.weight", "m_body.2.trans_block.mlp.2.bias", "m_body.2.conv1_1.weight", "m_body.2.conv1_1.bias", "m_body.2.conv1_2.weight", "m_body.2.conv1_2.bias", "m_body.2.conv_block.0.weight", "m_body.2.conv_block.2.weight", "m_body.3.trans_block.ln1.weight", "m_body.3.trans_block.ln1.bias", "m_body.3.trans_block.msa.relative_position_params", "m_body.3.trans_block.msa.embedding_layer.weight", "m_body.3.trans_block.msa.embedding_layer.bias", "m_body.3.trans_block.msa.linear.weight", "m_body.3.trans_block.msa.linear.bias", "m_body.3.trans_block.ln2.weight", "m_body.3.trans_block.ln2.bias", "m_body.3.trans_block.mlp.0.weight", "m_body.3.trans_block.mlp.0.bias", "m_body.3.trans_block.mlp.2.weight", "m_body.3.trans_block.mlp.2.bias", "m_body.3.conv1_1.weight", 
"m_body.3.conv1_1.bias", "m_body.3.conv1_2.weight", "m_body.3.conv1_2.bias", "m_body.3.conv_block.0.weight", "m_body.3.conv_block.2.weight", "m_up3.3.trans_block.ln1.weight", "m_up3.3.trans_block.ln1.bias", "m_up3.3.trans_block.msa.relative_position_params", "m_up3.3.trans_block.msa.embedding_layer.weight", "m_up3.3.trans_block.msa.embedding_layer.bias", "m_up3.3.trans_block.msa.linear.weight", "m_up3.3.trans_block.msa.linear.bias", "m_up3.3.trans_block.ln2.weight", "m_up3.3.trans_block.ln2.bias", "m_up3.3.trans_block.mlp.0.weight", "m_up3.3.trans_block.mlp.0.bias", "m_up3.3.trans_block.mlp.2.weight", "m_up3.3.trans_block.mlp.2.bias", "m_up3.3.conv1_1.weight", "m_up3.3.conv1_1.bias", "m_up3.3.conv1_2.weight", "m_up3.3.conv1_2.bias", "m_up3.3.conv_block.0.weight", "m_up3.3.conv_block.2.weight", "m_up3.4.trans_block.ln1.weight", "m_up3.4.trans_block.ln1.bias", "m_up3.4.trans_block.msa.relative_position_params", "m_up3.4.trans_block.msa.embedding_layer.weight", "m_up3.4.trans_block.msa.embedding_layer.bias", "m_up3.4.trans_block.msa.linear.weight", "m_up3.4.trans_block.msa.linear.bias", "m_up3.4.trans_block.ln2.weight", "m_up3.4.trans_block.ln2.bias", "m_up3.4.trans_block.mlp.0.weight", "m_up3.4.trans_block.mlp.0.bias", "m_up3.4.trans_block.mlp.2.weight", "m_up3.4.trans_block.mlp.2.bias", "m_up3.4.conv1_1.weight", "m_up3.4.conv1_1.bias", "m_up3.4.conv1_2.weight", "m_up3.4.conv1_2.bias", "m_up3.4.conv_block.0.weight", "m_up3.4.conv_block.2.weight", "m_up2.3.trans_block.ln1.weight", "m_up2.3.trans_block.ln1.bias", "m_up2.3.trans_block.msa.relative_position_params", "m_up2.3.trans_block.msa.embedding_layer.weight", "m_up2.3.trans_block.msa.embedding_layer.bias", "m_up2.3.trans_block.msa.linear.weight", "m_up2.3.trans_block.msa.linear.bias", "m_up2.3.trans_block.ln2.weight", "m_up2.3.trans_block.ln2.bias", "m_up2.3.trans_block.mlp.0.weight", "m_up2.3.trans_block.mlp.0.bias", "m_up2.3.trans_block.mlp.2.weight", "m_up2.3.trans_block.mlp.2.bias", "m_up2.3.conv1_1.weight", "m_up2.3.conv1_1.bias", "m_up2.3.conv1_2.weight", "m_up2.3.conv1_2.bias", "m_up2.3.conv_block.0.weight", "m_up2.3.conv_block.2.weight", "m_up2.4.trans_block.ln1.weight", "m_up2.4.trans_block.ln1.bias", "m_up2.4.trans_block.msa.relative_position_params", "m_up2.4.trans_block.msa.embedding_layer.weight", "m_up2.4.trans_block.msa.embedding_layer.bias", "m_up2.4.trans_block.msa.linear.weight", "m_up2.4.trans_block.msa.linear.bias", "m_up2.4.trans_block.ln2.weight", "m_up2.4.trans_block.ln2.bias", "m_up2.4.trans_block.mlp.0.weight", "m_up2.4.trans_block.mlp.0.bias", "m_up2.4.trans_block.mlp.2.weight", "m_up2.4.trans_block.mlp.2.bias", "m_up2.4.conv1_1.weight", "m_up2.4.conv1_1.bias", "m_up2.4.conv1_2.weight", "m_up2.4.conv1_2.bias", "m_up2.4.conv_block.0.weight", "m_up2.4.conv_block.2.weight", "m_up1.3.trans_block.ln1.weight", "m_up1.3.trans_block.ln1.bias", "m_up1.3.trans_block.msa.relative_position_params", "m_up1.3.trans_block.msa.embedding_layer.weight", "m_up1.3.trans_block.msa.embedding_layer.bias", "m_up1.3.trans_block.msa.linear.weight", "m_up1.3.trans_block.msa.linear.bias", "m_up1.3.trans_block.ln2.weight", "m_up1.3.trans_block.ln2.bias", "m_up1.3.trans_block.mlp.0.weight", "m_up1.3.trans_block.mlp.0.bias", "m_up1.3.trans_block.mlp.2.weight", "m_up1.3.trans_block.mlp.2.bias", "m_up1.3.conv1_1.weight", "m_up1.3.conv1_1.bias", "m_up1.3.conv1_2.weight", "m_up1.3.conv1_2.bias", "m_up1.3.conv_block.0.weight", "m_up1.3.conv_block.2.weight", "m_up1.4.trans_block.ln1.weight", "m_up1.4.trans_block.ln1.bias", 
"m_up1.4.trans_block.msa.relative_position_params", "m_up1.4.trans_block.msa.embedding_layer.weight", "m_up1.4.trans_block.msa.embedding_layer.bias", "m_up1.4.trans_block.msa.linear.weight", "m_up1.4.trans_block.msa.linear.bias", "m_up1.4.trans_block.ln2.weight", "m_up1.4.trans_block.ln2.bias", "m_up1.4.trans_block.mlp.0.weight", "m_up1.4.trans_block.mlp.0.bias", "m_up1.4.trans_block.mlp.2.weight", "m_up1.4.trans_block.mlp.2.bias", "m_up1.4.conv1_1.weight", "m_up1.4.conv1_1.bias", "m_up1.4.conv1_2.weight", "m_up1.4.conv1_2.bias", "m_up1.4.conv_block.0.weight", "m_up1.4.conv_block.2.weight".

import torch
import torch.onnx
from models.network_scunet import SCUNet  # Assuming this is the SCUNet model definition

def convert_to_onnx(model_path, onnx_path, input_shape=(1, 3, 256, 256)):
    # Load the pre-trained PyTorch model
    model = SCUNet()
    model.load_state_dict(torch.load(model_path))

    # Set the model to evaluation mode
    model.eval()

    # Define dummy input data
    dummy_input = torch.randn(input_shape)

    # Convert the model to ONNX format
    torch.onnx.export(model, dummy_input, onnx_path, verbose=True)

    print(f"Model converted to ONNX format and saved as {onnx_path}")

# Paths
model_path = "model_zoo/scunet_color_real_psnr.pth"
onnx_path = "./scunet_color_real_gan.onnx"

convert_to_onnx(model_path, onnx_path)
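
The key pattern in the traceback points to a configuration mismatch rather than a problem with the weights themselves: the freshly constructed SCUNet() places a downsampling conv at m_down1.2 (two Swin-Conv blocks per stage, which also matches the 14 "Block Initial Type" lines printed before the traceback), while the checkpoint stores four blocks per stage (m_down1.2.* and m_down1.3.* transformer keys plus m_down1.4.weight). Below is a minimal loading sketch, assuming the released checkpoints were trained with the config=[4,4,4,4,4,4,4], dim=64 arguments used in the repo's main_test_scunet.py. Note also that naming the script onnx.py shadows the onnx package, which is why the traceback shows onnx.py being re-entered during import torch.onnx.

# Rebuild SCUNet with the architecture the checkpoint expects before loading.
# config=[4,4,4,4,4,4,4] and dim=64 are assumptions taken from main_test_scunet.py.
import torch
from models.network_scunet import SCUNet

model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
state_dict = torch.load("model_zoo/scunet_color_real_psnr.pth", map_location="cpu")
model.load_state_dict(state_dict, strict=True)  # strict load: fails loudly on any leftover mismatch
model.eval()

With the config matched, the rest of the script above should run; renaming it from onnx.py (e.g. to export_onnx.py) avoids the circular import.
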
instant-high commented 2 months ago

I've converted the models to ONNX (with dynamic axes for the input size). The models work, but they only accept square input images with sides divisible by 128.
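
For reference, a minimal export sketch with dynamic spatial axes; the input/output names, the opset, and the model constructor arguments are my own choices here, not something the repo prescribes:

import torch
from models.network_scunet import SCUNet

model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)  # assumed to match the released weights
model.eval()

# Mark batch, height, and width as dynamic so one ONNX file serves multiple input sizes.
dummy = torch.randn(1, 3, 256, 256)
torch.onnx.export(
    model, dummy, "scunet_dynamic.onnx",
    input_names=["input"], output_names=["output"],
    dynamic_axes={"input": {0: "batch", 2: "height", 3: "width"},
                  "output": {0: "batch", 2: "height", 3: "width"}},
    opset_version=17,
)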

jumpxiu commented 1 week ago

I managed to export the model in ONNX form; here is my code:


import torch
from models.network_scunet import SCUNet

net = SCUNet()  # note: default config; pass config=[4,4,4,4,4,4,4] to match the released weights
onnx_path = "onnx_model_name.onnx"
torch.onnx.export(net, torch.randn((2, 3, 64, 64)), onnx_path)

But remember that torch.onnx.export traces the model, so make sure all of the model's classes and methods run through cleanly in a forward pass before exporting.
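
To sanity-check an exported file with ONNX Runtime, here is a short run that pads the input to a square with a side divisible by 128 (per the earlier comment) and crops the result back. The "input" name assumes the export sketch further up, and the random array is a stand-in for a real image:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("scunet_dynamic.onnx", providers=["CPUExecutionProvider"])

img = np.random.rand(1, 3, 300, 500).astype(np.float32)  # stand-in for a real image
h, w = img.shape[2:]
side = ((max(h, w) + 127) // 128) * 128  # next square side divisible by 128
padded = np.pad(img, ((0, 0), (0, 0), (0, side - h), (0, side - w)), mode="reflect")
out = sess.run(None, {"input": padded})[0][:, :, :h, :w]  # crop back to the original size
print(out.shape)  # (1, 3, 300, 500)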