daquexian / onnx-simplifier

Simplify your onnx model
Apache License 2.0

[BUG] After simplifying the swin_tiny_patch4_window7_224 model created by timm, onnxruntime fails with errors #307

Status: Open. lichun-wang opened this issue 10 months ago.

lichun-wang commented 10 months ago

Describe the bug: I use timm to create 'swin_tiny_patch4_window7_224' and then export it with torch.onnx.export to swin_tiny_patch4_window7_224.onnx. Running that exported model with onnxruntime works fine.

However, after I simplify the exported model with onnx-simplifier, onnxruntime can no longer load it and fails with this error:

onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from swin_tiny_patch4_window7_224.onnx failed:Node (/layers/layers.0/blocks/blocks.0/Reshape_6) Op (Reshape) [ShapeInferenceError] Dimension could not be inferred: incompatible shapes
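For context on the message itself: ONNX shape inference raises "Dimension could not be inferred: incompatible shapes" for a Reshape node whose target shape contains -1 while the remaining known dimensions do not evenly divide the input's element count. The sketch below is a simplified, hypothetical re-implementation of that rule in plain Python. infer_reshape_shape is an illustrative name, not an ONNX or onnxruntime API; allowzero=0 is assumed, and this is not the code onnxruntime actually runs. The Swin-T-like shapes in the usage lines (a 56x56 token grid with 96 channels split into 7x7 windows) are assumptions chosen for illustration, not values read from the failing Reshape_6 node.

```python
from math import prod
from typing import List, Sequence


def infer_reshape_shape(input_shape: Sequence[int], target: Sequence[int]) -> List[int]:
    """Simplified ONNX Reshape output-shape rule (hypothetical helper, allowzero=0 assumed).

    A 0 in `target` copies the input dimension at the same index, and a single -1
    is inferred so that the total element count is preserved.
    """
    if list(target).count(-1) > 1:
        raise ValueError("at most one target dimension may be -1")
    out: List[int] = []
    for i, dim in enumerate(target):
        if dim == 0:
            if i >= len(input_shape):
                raise ValueError("0 refers to a dimension the input does not have")
            out.append(input_shape[i])  # copy the matching input dimension
        else:
            out.append(dim)
    total = prod(input_shape)
    if -1 in out:
        known = prod(d for d in out if d != -1)
        if known == 0 or total % known != 0:
            # Same wording as the ShapeInferenceError in the report
            raise ValueError("Dimension could not be inferred: incompatible shapes")
        out[out.index(-1)] = total // known
    elif prod(out) != total:
        raise ValueError("target shape does not preserve the element count")
    return out


# Window-partition-style reshape (assumed shapes): 1*56*56*96 elements into 7x7x96 windows
print(infer_reshape_shape([1, 56, 56, 96], [-1, 7, 7, 96]))  # [64, 7, 7, 96]

# Known target dims that do not divide the element count trigger the error
try:
    infer_reshape_shape([1, 56, 56, 96], [-1, 7, 7, 95])
except ValueError as err:
    print(err)  # Dimension could not be inferred: incompatible shapes
```

Since the exported model loads and only the simplified one fails, one hedged reading is that either the target-shape constant or the statically inferred input shape of Reshape_6 changed during simplification in a way that breaks this divisibility check; the report does not say which.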

My code:

import onnx
import onnxruntime as ort
import timm
import torch
from onnxsim import simplify

# Build the pretrained Swin-T classifier from timm in inference mode
model = timm.create_model("swin_tiny_patch4_window7_224", pretrained=True, num_classes=1000).eval()

# Dummy NCHW input used for tracing during export
dummy_input = torch.randn(1, 3, 224, 224, device="cpu")

onnx_path = "swin_tiny_patch4_window7_224.onnx"

# Export with opset 17 and a dynamic batch dimension on input and output
torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    verbose=False,
    opset_version=17,
    do_constant_folding=True,
    keep_initializers_as_inputs=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)

# Validate the exported model, then simplify it; check_n=0 skips
# onnxsim's own output comparison on random inputs
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
model_simp, check = simplify(onnx_model, check_n=0)

# Overwrite the exported file with the simplified model
onnx.save(model_simp, onnx_path)
print(f"simplify over : {onnx_path}")

# Loading the simplified model here raises the ShapeInferenceError above
ort_sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
outputs = ort_sess.run(None, {"input": dummy_input.numpy().astype("float32")})
Collonville commented 2 months ago

I have the same error 😢