ThanatosShinji / onnx-tool

A parser, editor and profiler tool for ONNX models.
https://pypi.org/project/onnx-tool/
MIT License

Tensor dtype error #72

Open sreedattaSanjay opened 8 months ago

sreedattaSanjay commented 8 months ago

Hello, I'm trying to use the bert_mha_layernorm_fuse() function in examples.py.

The problem is that when I try to run inference on the fused model, I get a dtype mismatch error. After a little debugging, I found that in the Tensor() class, if the tensor instance is a str, dtype is automatically set to numpy.float32:

```python
from .node import Node
if isinstance(t, str):
    self.name = t
    self.proto = None
    self.shape = []
    self.numpy = None
    self.type = DYNAMIC_TENSOR if t != '' else STATIC_TENSOR
    self.dtype = numpy.float32
```
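One possible direction, sketched here under assumptions: leave the dtype of a dynamic, string-constructed tensor unknown (None) until type inference fills it in, and fail loudly if it is still unknown when it is needed, instead of silently assuming float32. SketchTensor, require_dtype and the two constants are hypothetical stand-ins, not onnx-tool's actual Tensor class.

```python
from typing import List, Optional

# Hypothetical constants mirroring the ones used in the snippet above.
DYNAMIC_TENSOR = 0
STATIC_TENSOR = 1


class SketchTensor:
    """Hypothetical stand-in for the str-constructor path of onnx-tool's Tensor."""

    def __init__(self, name: str):
        self.name = name
        self.proto = None
        self.shape: List[int] = []
        self.numpy = None
        self.type = DYNAMIC_TENSOR if name != '' else STATIC_TENSOR
        # Unknown until shape/type inference sets it, rather than float32.
        self.dtype: Optional[object] = None

    def require_dtype(self) -> object:
        """Return the dtype, raising if it was never inferred instead of guessing."""
        if self.dtype is None:
            raise ValueError(f"dtype of tensor {self.name!r} was never inferred")
        return self.dtype
```

With this default, a tensor such as bert/encoder/Shape:0 would stay untyped until something assigns it int64, and any attempt to serialize it before then raises instead of writing a wrong type.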

Because of this, dynamic output tensors that have no value are loaded as float32 tensors.

In my case, the output tensor's dtype should be an integer type, but it is float32.

```
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from bertsquad_mha_layernorm.onnx failed:Type Error: Type (tensor(int32)) of output arg (bert/encoder/Shape:0) of node (bert/encoder/Shape) does not match expected type (tensor(int64)).
```
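For context, the ONNX operator spec fixes the output type of Shape to int64 regardless of its input type, which is why ONNX Runtime rejects a declared int32 output for bert/encoder/Shape. A minimal sketch of per-op dtype propagation follows, assuming a hypothetical SketchNode type and a small hand-picked rule table (real ONNX covers far more ops, and Cast's `to` attribute is an integer enum there, not a string). Ops outside the table are left unknown rather than defaulted to float32.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional

# Output element types fixed by the ONNX operator spec, whatever the input type.
FIXED_OUTPUT_DTYPE = {"Shape": "int64", "Size": "int64", "NonZero": "int64",
                      "ArgMax": "int64", "ArgMin": "int64"}
# Single-output ops whose output has the same element type as their first input.
SAME_AS_FIRST_INPUT = {"Add", "Sub", "Mul", "Div", "Relu", "Gather", "Reshape",
                       "Transpose", "Squeeze", "Unsqueeze", "Concat", "Slice"}


@dataclass
class SketchNode:
    """Hypothetical minimal node: op type, input/output tensor names, attributes."""
    op_type: str
    inputs: List[str]
    outputs: List[str]
    attrs: Dict[str, str] = field(default_factory=dict)


def propagate_dtypes(nodes: List[SketchNode],
                     known: Dict[str, str]) -> Dict[str, Optional[str]]:
    """Assign output dtypes node by node; assumes nodes are topologically sorted."""
    dtypes: Dict[str, Optional[str]] = dict(known)
    for node in nodes:
        if node.op_type in FIXED_OUTPUT_DTYPE:
            out = FIXED_OUTPUT_DTYPE[node.op_type]
        elif node.op_type == "Cast":
            out = node.attrs["to"]  # simplified: real ONNX stores an int enum
        elif node.op_type in SAME_AS_FIRST_INPUT:
            out = dtypes.get(node.inputs[0])
        else:
            out = None  # unknown rule: leave unset instead of guessing float32
        for name in node.outputs:
            dtypes[name] = out
    return dtypes
```

Usage, with a hypothetical input tensor name:

```python
nodes = [SketchNode("Shape", ["input_ids:0"], ["bert/encoder/Shape:0"]),
         SketchNode("Cast", ["bert/encoder/Shape:0"], ["shape_i32"], {"to": "int32"})]
print(propagate_dtypes(nodes, {"input_ids:0": "int64"})["bert/encoder/Shape:0"])  # int64
```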

Here is the original bertsquad-12 model: [image: orig_model_bert]

And here is the fused model: [image: bert_MHA_Layernorm_model]

Can you please let me know how to resolve this issue? Is there a function that lets us save the model with the proper data types?
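On saving with proper data types, a hedged sketch: turn a name-to-dtype map (for example, the result of a propagation pass) into value_info entries using the element-type codes from ONNX's TensorProto.DataType, and skip tensors whose dtype is still unknown so that no guessed type gets written. value_info_plan is a hypothetical helper. value_info for intermediate tensors is optional in ONNX, whereas graph inputs and outputs still need explicit types.

```python
from typing import Dict, List, Optional, Tuple

# Element-type codes from onnx.TensorProto.DataType (onnx.proto).
ONNX_ELEM_TYPE = {"float32": 1, "uint8": 2, "int8": 3, "uint16": 4, "int16": 5,
                  "int32": 6, "int64": 7, "bool": 9, "float16": 10, "float64": 11}


def value_info_plan(dtypes: Dict[str, Optional[str]]) -> List[Tuple[str, int]]:
    """Return sorted (tensor name, ONNX elem_type) pairs, skipping unknown dtypes."""
    plan: List[Tuple[str, int]] = []
    for name, dtype in sorted(dtypes.items()):
        if dtype is None:
            continue  # let the runtime infer it rather than writing a guess
        plan.append((name, ONNX_ELEM_TYPE[dtype]))
    return plan
```

Each pair could then be passed to onnx.helper.make_tensor_value_info(name, elem_type, shape) when rebuilding the graph's value_info.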

sreedattaSanjay commented 8 months ago

Is there a script from your local testing that runs inference on the models created in examples.py? If so, could you please share it? Thank you.
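This is not the maintainer's script, but a minimal ONNX Runtime smoke test along these lines can load a model and run it once on zero-filled inputs. make_dummy_feeds and run_model are hypothetical helpers; zero inputs only check that the model loads and executes, not that its outputs are meaningful.

```python
from typing import Dict, Sequence

import numpy as np

# ONNX Runtime reports input types as strings such as "tensor(int64)".
ORT_TYPE_TO_NUMPY = {
    "tensor(float)": np.float32,
    "tensor(float16)": np.float16,
    "tensor(double)": np.float64,
    "tensor(int32)": np.int32,
    "tensor(int64)": np.int64,
    "tensor(bool)": np.bool_,
}


def make_dummy_feeds(inputs: Sequence, dynamic_dim: int = 1) -> Dict[str, np.ndarray]:
    """Build zero-filled arrays from objects exposing .name, .shape and .type."""
    feeds = {}
    for arg in inputs:
        # Symbolic (str) or missing (None) dimensions are replaced by dynamic_dim.
        shape = [d if isinstance(d, int) else dynamic_dim for d in arg.shape]
        feeds[arg.name] = np.zeros(shape, dtype=ORT_TYPE_TO_NUMPY[arg.type])
    return feeds


def run_model(path: str):
    """Load a model with ONNX Runtime and run it once on dummy inputs."""
    import onnxruntime as ort  # lazy import: the helper above works without it

    sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    feeds = make_dummy_feeds(sess.get_inputs())
    return sess.run(None, feeds)  # None = return every graph output
```

For example, run_model("bertsquad_mha_layernorm.onnx") should fail while creating the InferenceSession with the Type Error quoted above if the saved types are still wrong, which makes it a quick check after each fix.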

ThanatosShinji commented 8 months ago

The link in the README is out of date; you can use this link to download the BERT model.