
`onnx.checker.check_model` raises `Bad node spec` for custom nodes created from ORT `optimize_model` #14768

Open. fxmarty opened this issue 1 year ago.

fxmarty commented 1 year ago

Describe the issue

After using `onnxruntime.transformers.optimizer.optimize_model`, the ONNX model validity checker `onnx.checker.check_model` raises `Context: Bad node spec for node. Name: LayerNorm_0 OpType: LayerNormalization`.

To reproduce

Given the ONNX model decoder_model.onnx (here):

import onnxruntime as ort
from transformers import AutoConfig
from onnxruntime.transformers.fusion_options import FusionOptions
from onnxruntime.transformers.optimizer import optimize_model

# Read the attention-head count and hidden size from the model config.
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-gpt2")
num_attention_heads = config.n_head
hidden_size = config.hidden_size

session_options = ort.SessionOptions()

path = "/tmp/gpt2_onnx_default/decoder_model.onnx"
optimized_path = "/tmp/gpt2_onnx_default/decoder_model_optimized.onnx"
optimization_options = FusionOptions("gpt2")

# Apply the transformers-specific fusions (opt_level=0 disables ORT's own
# graph optimizations, so only the Python-side fusions run).
m = optimize_model(
    path,
    model_type="gpt2",
    num_heads=num_attention_heads,
    hidden_size=hidden_size,
    opt_level=0,
    optimization_options=optimization_options,
    use_gpu=False,
)

m.save_model_to_file(optimized_path, use_external_data_format=False)

And then with onnx==1.13.0:

import onnx
optimized_path = "/tmp/gpt2_onnx_default/decoder_model_optimized.onnx"
model = onnx.load(optimized_path)
onnx.checker.check_model(model)

that raises:

Traceback (most recent call last):
  File "/home/fxmarty/test_merge.py", line 64, in <module>
    onnx.checker.check_model(model)
  File "/home/fxmarty/anaconda3/envs/hf-inf/lib/python3.9/site-packages/onnx/checker.py", line 119, in check_model
    C.check_model(protobuf_string, full_check)
onnx.onnx_cpp2py_export.checker.ValidationError: No Op registered for LayerNormalization with domain_version of 13

==> Context: Bad node spec for node. Name: LayerNorm_0 OpType: LayerNormalization

This node was indeed created by the optimizer (screenshot of the graph omitted).

I suspect an opset import is missing? Or is this expected?

Urgency

medium

Platform

Linux

OS Version

Linux 5.15.0-56-generic #62-Ubuntu SMP Tue Nov 22 19:54:14 UTC 2022 x86_64 x86_64 x86_64 GNU/Linux

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.14.0

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

fxmarty commented 1 year ago

Actually, I think this is a bug in ONNX Runtime. The original model's opset is 13, while the optimization introduces a LayerNormalization node that requires opset >= 17 (https://github.com/onnx/onnx/blob/main/docs/Operators.md#layernormalization). Despite this, the optimized ONNX model still only imports ai.onnx v13.

The ONNX Runtime optimizer should probably raise an error stating that the produced ONNX model will be invalid due to the opset mismatch.

Edit: or is the LayerNormalization in the graph expected to come from com.microsoft?

If so, could it be a name collision due to https://github.com/microsoft/onnxruntime/blob/df45b12fdf8fbeb68c314b2475eabbc7c2ea3af8/onnxruntime/contrib_ops/cpu/layer_norm.cc#L13? onnx.checker.check_model should be able to deal with imported domains: https://github.com/onnx/onnx/issues/3544#issuecomment-867845074
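
For reference, a minimal sketch (reusing the optimized-model path from the repro above) to print which domains and versions the optimized model actually imports:

import onnx

optimized_path = "/tmp/gpt2_onnx_default/decoder_model_optimized.onnx"
model = onnx.load(optimized_path)

# An empty domain string stands for the default ai.onnx domain.
for opset in model.opset_import:
    print(opset.domain or "ai.onnx", opset.version)

If com.microsoft does not show up here while the fused nodes live in the default domain, the checker has no choice but to validate LayerNormalization against ai.onnx v13.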

fxmarty commented 1 year ago

I'm also getting `No Op registered for MemcpyFromHost with domain_version of 17` when using the convert_float_to_float16 optimization with opt_level=2, and this op is not in https://github.com/onnx/onnx/blob/main/docs/Operators.md either.

tianleiwu commented 1 year ago

LayerNormalization was supported in ONNX Runtime before it was added to the ONNX specification. However, it was registered by mistake in the onnx domain (it should have been added to the com.microsoft domain at the time) before it became an official onnx operator. It was later kept in the onnx domain for backward compatibility.

Even though the onnx checker raises an error, the model can still run in ONNX Runtime.

@jcwchen, is it possible to add an option to ignore some operators in onnx.checker?

Here is a list of operators with such problems: https://github.com/microsoft/onnxruntime/blob/b3b9be19b140ba073f0b1c44bfc3ffc011768cbc/tools/ci_build/op_registration_validator.py#L18-L30
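
A minimal sketch (again reusing the path from the repro above) to confirm where the fused node ended up:

import onnx

model = onnx.load("/tmp/gpt2_onnx_default/decoder_model_optimized.onnx")

# node.domain == "" means the node is declared in the default ai.onnx domain,
# which is why the checker validates it against the official ONNX spec
# instead of treating it as a com.microsoft contrib op.
for node in model.graph.node:
    if node.op_type == "LayerNormalization":
        print(node.name, repr(node.domain))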

jcwchen commented 1 year ago

@jcwchen, is it possible to add an option to ignore some operators in onnx.checker?

From the ONNX perspective, this model contains some operators which are not included in the official spec. That is to say, it cannot be processed normally by ONNX utilities such as shape_inference, version_converter, or other ONNX-related tools (except ONNX Runtime). The ONNX checker's job is to validate whether an ONNX model follows the ONNX spec, so it seems fine to report an error here. I don't see a strong use case for enabling the ONNX checker to ignore some specified operators.

cc @gramalingam @liqunfu in case you have other thoughts on this topic.
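
As an illustration of that point, a minimal sketch against the optimized model from the repro: shape inference works from the same official operator registry, so a node it has no schema for at the model's imported opset is simply skipped rather than inferred.

import onnx
from onnx import shape_inference

model = onnx.load("/tmp/gpt2_onnx_default/decoder_model_optimized.onnx")

# Nodes without a registered schema at ai.onnx v13 (such as the fused
# LayerNormalization) get no inferred output shapes here.
inferred = shape_inference.infer_shapes(model)
print(len(inferred.graph.value_info))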

tianleiwu commented 1 year ago

@jcwchen, it is fine that we keep current onnx checker behavior.

I would suggest that users run the ONNX checker on the raw onnx model only. For an optimized model, it is better to create an inference session to test whether the model works with a given execution provider (since some execution providers only support a limited set of operators).
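
Something along these lines (a minimal sketch, reusing the optimized-model path from the repro above) is usually enough:

import onnxruntime as ort

# Creating the session makes ORT resolve every node against the kernels that the
# selected execution provider registers (including contrib ops such as the fused
# LayerNormalization), so an unusable model fails at load time.
sess = ort.InferenceSession(
    "/tmp/gpt2_onnx_default/decoder_model_optimized.onnx",
    providers=["CPUExecutionProvider"],
)
print([i.name for i in sess.get_inputs()], [o.name for o in sess.get_outputs()])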

fxmarty commented 1 year ago

Thank you for your answers! So I understand that onnx.checker.check_model can only validate operators in the onnx domain, and not ones from imported domains (which I guess LayerNormalization and MemcpyFromHost are?), contrary to what I understood from https://github.com/onnx/onnx/issues/3544#issuecomment-867845074. I'll use InferenceSession loading to check validity for now.

Feel free to close!

barathbheeman commented 1 year ago

This is a related question: I'm getting the same error, which I found out while trying to convert the onnx model to a TensorRT engine using the trtexec command.

I get the following output:

[03/15/2023-15:53:28] [I] [TRT] No importer registered for op: LayerNormalization. Attempting to import as plugin.
[03/15/2023-15:53:28] [I] [TRT] Searching for plugin: LayerNormalization, plugin_version: 1, plugin_namespace: 
[03/15/2023-15:53:28] [E] [TRT] ModelImporter.cpp:726: While parsing node number 9 [LayerNormalization -> "/distilbert/embeddings/LayerNorm/Add_1_output_0"]:
[03/15/2023-15:53:28] [E] [TRT] ModelImporter.cpp:727: --- Begin node ---
[03/15/2023-15:53:28] [E] [TRT] ModelImporter.cpp:728: input: "/distilbert/embeddings/Add_output_0"
input: "distilbert.embeddings.LayerNorm.weight"
input: "distilbert.embeddings.LayerNorm.bias"
output: "/distilbert/embeddings/LayerNorm/Add_1_output_0"
name: "LayerNormalization"
op_type: "LayerNormalization"
attribute {
  name: "stash_type"
  i: 1
  type: INT
}
attribute {
  name: "axis"
  i: -1
  type: INT
}
attribute {
  name: "epsilon"
  f: 1e-12
  type: FLOAT
}
doc_string: "fused LayerNorm subgraphs "

[03/15/2023-15:53:28] [E] [TRT] ModelImporter.cpp:729: --- End node ---
[03/15/2023-15:53:28] [E] [TRT] ModelImporter.cpp:732: ERROR: builtin_op_importers.cpp:5428 In function importFallbackPluginImporter:
[8] Assertion failed: creator && "Plugin not found, are the plugin name, version, and namespace correct?"
[03/15/2023-15:53:28] [E] Failed to parse onnx file
[03/15/2023-15:53:28] [I] Finish parsing network model
[03/15/2023-15:53:28] [E] Parsing model failed
[03/15/2023-15:53:28] [E] Failed to create engine from model or file.
[03/15/2023-15:53:28] [E] Engine set up failed

Is there a workaround for this?

tianleiwu commented 1 year ago

@barathbheeman, for TRT you need to use the original (not optimized) onnx model exported from PyTorch to build the TRT engine. There is no LayerNormalization plugin in TRT.
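
i.e. something along these lines, with trtexec pointed at the raw export (file names here are just placeholders):

# Build the engine from the unoptimized PyTorch export, so the parser only sees
# standard ONNX ops it has importers for.
trtexec --onnx=decoder_model.onnx --saveEngine=decoder_model.engine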