microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

TensorRT Execution Provider fails to build the engine while TensorRT (trtexec) successfully creates the engine file #17312

Open namtr92 opened 1 year ago

namtr92 commented 1 year ago

Describe the issue

Hi, I am using ONNX Runtime with the TensorRT Execution Provider for a quantized model (YOLO-NAS). While the TensorRT CLI (trtexec.exe) successfully builds the engine from the ONNX model, ONNX Runtime with the TensorRT Execution Provider fails to build the engine. Here is the output of the TensorRT Execution Provider:

To reproduce

import onnxruntime

ort_session = onnxruntime.InferenceSession(
    "model.onnx",
    providers=['TensorrtExecutionProvider'],
    provider_options=[{'device_id': '0', 'trt_int8_enable': True}],
)

Link to download the model: https://drive.google.com/file/d/1ZxT2wCU0bIjYrKmhDuFgKRew1kR9hbRd/view?usp=sharing
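A minimal sketch for setting up the reproduction, assuming the same model.onnx and the TRT EP option names already used in this thread (device_id, trt_int8_enable, trt_engine_cache_enable). The helper names are hypothetical. The first builds the providers/provider_options pair for InferenceSession (one options dict per provider, in the same order). The second checks that TensorRT actually ended up as the session's highest-priority provider. The optional CUDA/CPU fallback entries are an illustrative addition, not part of the original report.

```python
TRT_EP = "TensorrtExecutionProvider"


def trt_session_kwargs(device_id=0, int8=True, engine_cache=False, fallback=False):
    """Build matching providers/provider_options lists for onnxruntime.InferenceSession.

    Hypothetical helper; option names are the TRT EP keys used in this thread.
    """
    providers = [TRT_EP]
    provider_options = [{
        "device_id": str(device_id),
        "trt_int8_enable": int8,
        "trt_engine_cache_enable": engine_cache,
    }]
    if fallback:
        # Illustrative addition: let CUDA/CPU run any nodes TensorRT does not take.
        providers += ["CUDAExecutionProvider", "CPUExecutionProvider"]
        provider_options += [{"device_id": str(device_id)}, {}]
    # InferenceSession expects exactly one options dict per provider, in the same order.
    assert len(providers) == len(provider_options)
    return {"providers": providers, "provider_options": provider_options}


def check_tensorrt_active(active_providers):
    """Raise if TensorRT is not the highest-priority provider the session registered."""
    if not active_providers or active_providers[0] != TRT_EP:
        raise RuntimeError(f"TensorRT EP is not active; session providers: {active_providers}")
```

Possible usage: ort_session = onnxruntime.InferenceSession("model.onnx", **trt_session_kwargs()), then check_tensorrt_active(ort_session.get_providers()). If the TensorRT EP cannot be used, the Python API may retry with other providers and only print a message. The check turns that into a hard error.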

Urgency

Very Urgent

Platform

Windows

OS Version

11

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.15.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

TensorRT

Execution Provider Library Version

CUDA 11.6, cuDNN 8.9.0, TensorRT 8.6.1

jywu-msft commented 1 year ago

@yf711 can you help take a look?

yf711 commented 1 year ago

I can confirm the issue and reproduce it in my local environment.

I also tried trtexec --int8 --onnx=model.onnx --saveEngine=model.trt, and it succeeded. I will check why chooseHigherPrecision was executed when trt_int8_enable was set.
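To repeat the trtexec comparison from a script, a minimal sketch could build the argument list and run it with subprocess, reporting whether trtexec exited successfully. The helper names are hypothetical. The flags are exactly those in the command above. It assumes trtexec (trtexec.exe on Windows) is on PATH.

```python
import shutil
import subprocess


def trtexec_int8_cmd(onnx_path="model.onnx", engine_path="model.trt"):
    """Argument list for the INT8 trtexec check quoted above (hypothetical helper)."""
    return ["trtexec", "--int8", f"--onnx={onnx_path}", f"--saveEngine={engine_path}"]


def run_trtexec(cmd):
    """Run the command and return True if it exits with status 0."""
    # Fail early with a clear error if the executable is not on PATH.
    if shutil.which(cmd[0]) is None:
        raise FileNotFoundError(f"{cmd[0]} not found on PATH")
    return subprocess.run(cmd).returncode == 0
```

Running run_trtexec(trtexec_int8_cmd()) next to the InferenceSession call makes it easy to confirm that the same model builds under trtexec but not under the TRT EP.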

yf711 commented 1 year ago

Hi @namtr92, could you try this workaround, which disables ORT graph optimization when initializing the session?

import onnxruntime

sess_options = onnxruntime.SessionOptions()
# Turn off all ORT graph optimizations so ORT does not rewrite the QDQ graph before TensorRT sees it
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
ort_session = onnxruntime.InferenceSession(
    "model.onnx",
    sess_options,
    providers=['TensorrtExecutionProvider'],
    provider_options=[{'device_id': '0',
                       'trt_int8_enable': True,
                       'trt_engine_cache_enable': True  # reuse the built engine on later runs
                      }])

According to key context shared by NVIDIA, ORT's graph optimizations may overlap with TensorRT's own QDQ optimizations. I will check whether ORT's QDQ graph optimizations can be skipped when the TRT EP is selected.
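One way to look for the overlap described above is to compare the QuantizeLinear/DequantizeLinear nodes in the original model with those left after ORT's optimizations. This is a diagnostic sketch, not part of the fix, and the helper names are hypothetical. Reading files needs the onnx package. dump_ort_optimized uses SessionOptions.optimized_model_filepath with the CPU EP only, because ORT cannot serialize a graph that contains compiled TensorRT nodes. The dumped graph is therefore only an approximation of what the TRT EP receives.

```python
from collections import Counter

QDQ_OPS = ("QuantizeLinear", "DequantizeLinear")


def qdq_counts(op_types):
    """Count QuantizeLinear/DequantizeLinear entries in a sequence of op_type strings."""
    counts = Counter(op_types)
    return {op: counts[op] for op in QDQ_OPS}


def qdq_diff(before, after):
    """Per-op change in Q/DQ count; negative values mean nodes were removed or fused."""
    b, a = qdq_counts(before), qdq_counts(after)
    return {op: a[op] - b[op] for op in QDQ_OPS}


def model_op_types(path):
    """Return the op_type of every top-level node in an ONNX file (needs onnx)."""
    import onnx  # imported lazily so the pure helpers above work without onnx

    return [node.op_type for node in onnx.load(path).graph.node]


def dump_ort_optimized(src, dst, level_name="ORT_ENABLE_ALL"):
    """Save the graph ORT produces at the given optimization level (CPU EP only)."""
    import onnxruntime as ort  # imported lazily for the same reason

    so = ort.SessionOptions()
    so.graph_optimization_level = getattr(ort.GraphOptimizationLevel, level_name)
    so.optimized_model_filepath = dst
    # CPU only: ORT refuses to serialize graphs containing compiled TensorRT nodes.
    ort.InferenceSession(src, so, providers=["CPUExecutionProvider"])
```

For example, after dump_ort_optimized("model.onnx", "model.ort_opt.onnx"), call qdq_diff(model_op_types("model.onnx"), model_op_types("model.ort_opt.onnx")). Negative values show that ORT removed or fused Q/DQ nodes, which would be consistent with the overlap suggested above.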

namtr92 commented 1 year ago

Thanks for your help. I can now use the TensorRT EP normally!