Hap-Zhang opened this issue 3 years ago
@Hap-Zhang, could you try model optimization like:
...
from onnxruntime import InferenceSession, SessionOptions
from onnxruntime.transformers.optimizer import optimize_model

# Optimize the exported model (fuses transformer subgraphs such as Attention)
onnx_model = optimize_model(onnx_model_path)
new_onnx_path = "ner_opt_model.onnx"
onnx_model.save_model_to_file(new_onnx_path)

options = SessionOptions()
session = InferenceSession(new_onnx_path, options, providers=['CPUExecutionProvider'])
...
Check whether there are "Attention" nodes in the new ONNX file. Performance may be impacted if "Attention" is not fused.
For more information, please refer to the notebook: https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/python/tools/transformers/notebooks/PyTorch_Bert-Squad_OnnxRuntime_CPU.ipynb
@tianleiwu Thank you for your quick reply. I tried the model optimization as you suggested, and the inference time dropped from 505ms to 359ms. However, the inference time of the original PyTorch model drops to 337ms with export MKL_CBWR=AVX2, while nothing changes in the ONNX model whether AVX2 is set or not. Is the backend of ONNX Runtime MKL or something else?
The CPU execution provider does not use MKL; it uses MLAS. As far as I know, MLAS can leverage AVX2 in some situations (e.g. Windows, x64, and quantized models): https://github.com/microsoft/onnxruntime/blob/0cc29095733cecc55efe0b8e0d8ff7cd2a9e427a/cmake/onnxruntime_mlas.cmake#L113-L114
I'm using ResNet with the CPU execution provider in C++. ONNX Runtime's default settings were 6x slower than PyTorch. Setting
options.AddConfigEntry(kOrtSessionOptionsConfigSetDenormalAsZero, "1");
solved that.
Hi all,
I want to accelerate inference of an NER model whose backbone is BERT. When I export the ONNX model with torch.onnx.export, I find that inference with the ONNX model (505ms) is slower than the original PyTorch model (337ms). I'm not sure what went wrong.
The PyTorch version is 1.9.0. The Transformers version is 4.6.1.
The code is listed below: