microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

[Feature Request] TensorRT custom engine Plans #13559

Open · contentis opened 1 year ago

contentis commented 1 year ago

Describe the feature request

It would be great to have the option to provide pre-optimised TensorRT engine plans to ORT.

Describe scenario use case

Using TensorRT standalone, e.g. via trtexec, allows much more precise control over engine compilation.

Currently, only a small subset of TRT features is accessible through the ORT API, and most importantly, there is no way to ship precompiled engines with an application; engine compilation needs to happen on the user's machine, resulting in a bad user experience. Also, how (TRT) caching is handled internally is not very transparent, in my opinion, and might lead to unnecessary engine recompilations.

jywu-msft commented 1 year ago

Thanks for providing this feedback and filing the feature request. It is possible to ship engine files with the application by pre-generating them and then specifying the path to the engine files at runtime. I admit it is a bit ad hoc, and it would be nice to support a cleaner offline engine creation workflow. Can you provide some more details/examples about the finer-grained control you want to see (that is available via trtexec)? As for the TRT cache handling, we weren't trying to make it a black box, and we will strive to make it more transparent. Would some documentation help? Any other suggestions to help make it more transparent are welcome. +@stevenlix
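
For reference, the ad-hoc flow being described might look roughly like this (a minimal sketch; the model path and cache directory are placeholders, and the option names should be checked against the TensorRT EP documentation):

import onnxruntime as ort

# On a build machine: run the model once with the engine cache enabled so the
# engine files are written into ./trt_cache, then ship that directory with the app.
# On the user machine: point trt_engine_cache_path at the shipped directory; if the
# GPU and the TRT/ORT/CUDA versions match, the cached engine is reused instead of rebuilt.
providers = [
    ("TensorrtExecutionProvider", {
        "trt_engine_cache_enable": True,
        "trt_engine_cache_path": "./trt_cache",  # placeholder cache directory
    }),
    "CUDAExecutionProvider",
]
session = ort.InferenceSession("model.onnx", providers=providers)  # placeholder model path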

contentis commented 1 year ago

One feature in particular that trtexec has is the option to disable tactic sources, e.g. cuDNN. This is commonly done to reduce binary size (cuDNN adds more than a gigabyte) and compromises only a little performance. With TRT you also might not want to first export an ONNX model and then compile the engine from it, but instead go directly from Python code to an engine / define your graph manually.
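
For illustration, a rough sketch of what that control looks like with the standalone TensorRT Python API (paths are placeholders; this mirrors disabling the cuDNN tactic source, which trtexec exposes through its tactic-sources flag):

import tensorrt as trt

# Parse an ONNX model and build an engine with the cuDNN tactic source disabled.
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
parser = trt.OnnxParser(network, logger)
with open("model.onnx", "rb") as f:  # placeholder model path
    parser.parse(f.read())

config = builder.create_builder_config()
tactics = config.get_tactic_sources()
tactics &= ~(1 << int(trt.TacticSource.CUDNN))  # clear the cuDNN bit
config.set_tactic_sources(tactics)

with open("model.engine", "wb") as f:  # placeholder engine path
    f.write(builder.build_serialized_network(network, config))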

Generally, I think keeping engine compilation and inference (ORT) separate is a cleaner approach. Currently, there is not really an incentive for people to use TRT + ORT instead of using the TRT runtime API standalone.

fxmarty commented 1 year ago

strong +1 @contentis

With trtexec I can pass, for example, --minShapes=input_ids:1x10,attention_mask:1x10 --optShapes=input_ids:1x12,attention_mask:1x12 --maxShapes=input_ids:1x19,attention_mask:1x19, and the engine builds very quickly. When the engine is built through session.run() instead, it is rebuilt for each new shape and is very slow, so for auto-regressive generation tasks building the engine takes forever. cc @stevenlix @jywu-msft
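
For comparison, more recent ORT releases expose optimization-profile options on the TensorRT EP that play the same role as the trtexec shape flags; a minimal sketch, assuming those options are available in your ORT version (paths are placeholders):

import onnxruntime as ort

# Explicit optimization profiles for the TensorRT EP, covering the whole
# sequence-length range up front so the engine is built once instead of
# being rebuilt whenever the input shape changes.
trt_options = {
    "trt_engine_cache_enable": True,
    "trt_engine_cache_path": "./trt_cache",  # placeholder cache directory
    "trt_profile_min_shapes": "input_ids:1x10,attention_mask:1x10",
    "trt_profile_opt_shapes": "input_ids:1x12,attention_mask:1x12",
    "trt_profile_max_shapes": "input_ids:1x19,attention_mask:1x19",
}
session = ort.InferenceSession(
    "gpt2.onnx",  # placeholder model path
    providers=[("TensorrtExecutionProvider", trt_options), "CUDAExecutionProvider"],
)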

Related https://github.com/huggingface/optimum/issues/606 https://github.com/huggingface/optimum/issues/605

For example, with GPT-2 generation (the "took:" lines give the time taken by session.run() in seconds), without reusing past key/values:

onnx_inputs: {'input_ids': array([[3041, 5372,  502,  416,  597, 2420,  345, 1549,  588,   13]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
input_ids (1, 10)
attention_mask (1, 10)
took: 23.769636154174805
------
onnx_inputs: {'input_ids': array([[3041, 5372,  502,  416,  597, 2420,  345, 1549,  588,   13,  198]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
input_ids (1, 11)
attention_mask (1, 11)
took: 48.19532656669617
------
onnx_inputs: {'input_ids': array([[3041, 5372,  502,  416,  597, 2420,  345, 1549,  588,   13,  198,
         198]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
input_ids (1, 12)
attention_mask (1, 12)
took: 49.167895555496216
------
onnx_inputs: {'input_ids': array([[3041, 5372,  502,  416,  597, 2420,  345, 1549,  588,   13,  198,
         198,   40]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
input_ids (1, 13)
attention_mask (1, 13)
took: 49.58247900009155
------
onnx_inputs: {'input_ids': array([[3041, 5372,  502,  416,  597, 2420,  345, 1549,  588,   13,  198,
         198,   40, 1101]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
input_ids (1, 14)
attention_mask (1, 14)
took: 47.33034324645996
------
onnx_inputs: {'input_ids': array([[3041, 5372,  502,  416,  597, 2420,  345, 1549,  588,   13,  198,
         198,   40, 1101,  407]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
input_ids (1, 15)
attention_mask (1, 15)
took: 46.33792042732239
------
onnx_inputs: {'input_ids': array([[3041, 5372,  502,  416,  597, 2420,  345, 1549,  588,   13,  198,
         198,   40, 1101,  407, 1654]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
input_ids (1, 16)
attention_mask (1, 16)
took: 48.12443137168884
------
onnx_inputs: {'input_ids': array([[3041, 5372,  502,  416,  597, 2420,  345, 1549,  588,   13,  198,
         198,   40, 1101,  407, 1654,  611]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
input_ids (1, 17)
attention_mask (1, 17)
took: 69.13102579116821
------
onnx_inputs: {'input_ids': array([[3041, 5372,  502,  416,  597, 2420,  345, 1549,  588,   13,  198,
         198,   40, 1101,  407, 1654,  611,  345]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
input_ids (1, 18)
attention_mask (1, 18)
took: 70.63710856437683
------
onnx_inputs: {'input_ids': array([[3041, 5372,  502,  416,  597, 2420,  345, 1549,  588,   13,  198,
         198,   40, 1101,  407, 1654,  611,  345,  821]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
input_ids (1, 19)
attention_mask (1, 19)
took: 68.33463191986084
------

And once the engine is built, we get:

onnx_inputs: {'input_ids': array([[3041, 5372,  502,  416,  597, 2420,  345, 1549,  588,   13]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
input_ids (1, 10)
attention_mask (1, 10)
took: 0.003805875778198242
------
onnx_inputs: {'input_ids': array([[3041, 5372,  502,  416,  597, 2420,  345, 1549,  588,   13,  198]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
input_ids (1, 11)
attention_mask (1, 11)
took: 0.004013776779174805
------
onnx_inputs: {'input_ids': array([[3041, 5372,  502,  416,  597, 2420,  345, 1549,  588,   13,  198,
         198]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
input_ids (1, 12)
attention_mask (1, 12)
took: 0.0039560794830322266
------
onnx_inputs: {'input_ids': array([[3041, 5372,  502,  416,  597, 2420,  345, 1549,  588,   13,  198,
         198,   40]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
input_ids (1, 13)
attention_mask (1, 13)
took: 0.003964662551879883
------
onnx_inputs: {'input_ids': array([[3041, 5372,  502,  416,  597, 2420,  345, 1549,  588,   13,  198,
         198,   40, 1101]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
input_ids (1, 14)
attention_mask (1, 14)
took: 0.004000663757324219
------
onnx_inputs: {'input_ids': array([[3041, 5372,  502,  416,  597, 2420,  345, 1549,  588,   13,  198,
         198,   40, 1101,  407]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
input_ids (1, 15)
attention_mask (1, 15)
took: 0.0040628910064697266
------
onnx_inputs: {'input_ids': array([[3041, 5372,  502,  416,  597, 2420,  345, 1549,  588,   13,  198,
         198,   40, 1101,  407, 1654]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
input_ids (1, 16)
attention_mask (1, 16)
took: 0.004026651382446289
------
onnx_inputs: {'input_ids': array([[3041, 5372,  502,  416,  597, 2420,  345, 1549,  588,   13,  198,
         198,   40, 1101,  407, 1654,  611]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
input_ids (1, 17)
attention_mask (1, 17)
took: 0.004216909408569336
------
onnx_inputs: {'input_ids': array([[3041, 5372,  502,  416,  597, 2420,  345, 1549,  588,   13,  198,
         198,   40, 1101,  407, 1654,  611,  345]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
input_ids (1, 18)
attention_mask (1, 18)
took: 0.0042133331298828125
------
onnx_inputs: {'input_ids': array([[3041, 5372,  502,  416,  597, 2420,  345, 1549,  588,   13,  198,
         198,   40, 1101,  407, 1654,  611,  345,  821]]), 'attention_mask': array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
input_ids (1, 19)
attention_mask (1, 19)
took: 0.004229068756103516
------

fxmarty commented 1 year ago

A suggestion: TensorrtExecutionProvider generates a .engine file for each subgraph executed on the provider. The naming is automatic: https://github.com/microsoft/onnxruntime/blob/613920d6c5f53a8e5e647c5f1dcdecb0a8beef31/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc#L744-L750

Moreover, the provider option trt_engine_cache_path expects a directory. Therefore, if I have previously built the engine with trtexec, I cannot pass:

provider_options = {
    "trt_engine_cache_enable": True,
    "trt_engine_cache_path": "/path/to/gpt2_onnx/gpt2_trt.engine"
}

because ONNX Runtime assumes that the graph may be broken into several pieces and won't recognize my .engine file as the right one (a name like TensorrtExecutionProvider_TRTKernel_graph_torch_jit_11146379831831262227_1_0.engine is expected for the caching to work).
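
For contrast, a sketch of what the EP currently expects (the directory path is a placeholder):

# A cache directory instead, with file names generated by ORT from a hash of each
# TRT subgraph (e.g. TensorrtExecutionProvider_TRTKernel_graph_torch_jit_<hash>_<i>_<j>.engine):
provider_options = {
    "trt_engine_cache_enable": True,
    "trt_engine_cache_path": "/path/to/gpt2_onnx",  # a directory, not an .engine file
}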

gedoensmax commented 1 year ago

I second @fxmarty. Furthermore, providing an explicit engine path instead of a directory could enable the following cases (sketched below):

  1. File exists: the engine is loaded without any checks; this also avoids deserializing the ONNX file and running optimization, which can be time-consuming during setup.
  2. File does not exist: usual behavior of creating the engine from an ONNX file, but instead of hashing it to construct the name, the given path is used.
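
In pseudocode, the proposed semantics would be something like this (not an existing ONNX Runtime option; build_fn stands in for the usual ONNX-to-engine build):

import os

# Pseudocode for the proposed behavior of an explicit engine path.
def load_or_build_engine(engine_path: str, onnx_path: str, build_fn) -> bytes:
    if os.path.exists(engine_path):
        # Case 1: the file exists -> load the pre-built engine directly,
        # skipping ONNX deserialization and TensorRT optimization.
        with open(engine_path, "rb") as f:
            return f.read()
    # Case 2: the file does not exist -> build as usual, but save under the
    # user-supplied path instead of a hash-derived name.
    serialized = build_fn(onnx_path)
    with open(engine_path, "wb") as f:
        f.write(serialized)
    return serialized
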
fxmarty commented 1 year ago

Related: https://github.com/microsoft/onnxruntime/issues/13851

baoachun commented 1 year ago

Agreed, I have the same need now.

jywu-msft commented 1 year ago

> I second @fxmarty. Furthermore, providing an explicit engine path instead of a directory could enable the following cases:
>
>   1. File exists: the engine is loaded without any checks; this also avoids deserializing the ONNX file and running optimization, which can be time-consuming during setup.
>   2. File does not exist: usual behavior of creating the engine from an ONNX file, but instead of hashing it to construct the name, the given path is used.

If there are multiple engine files corresponding to different subgraphs, how would the user specify which engine corresponds to which subgraph? Or do we assume that if the user specifies an engine file there is a single graph (meaning native TensorRT supports the entire ONNX model)?

jywu-msft commented 1 year ago

+@chilo-ms FYI, on this thread we are working with @gedoensmax to see how to make trtexec more interoperable with onnxruntime-trt. Others' feedback/suggestions are appreciated as well!

jywu-msft commented 1 year ago

> I second @fxmarty. Furthermore, providing an explicit engine path instead of a directory could enable the following cases:
>
>   1. File exists: the engine is loaded without any checks; this also avoids deserializing the ONNX file and running optimization, which can be time-consuming during setup.
>   2. File does not exist: usual behavior of creating the engine from an ONNX file, but instead of hashing it to construct the name, the given path is used.
>
> If there are multiple engine files corresponding to different subgraphs, how would the user specify which engine corresponds to which subgraph? Or do we assume that if the user specifies an engine file there is a single graph (meaning native TensorRT supports the entire ONNX model)?

I suppose that if trtexec generated the engine file, TensorRT fully supports the ONNX graph. My main concern is whether we can blindly load an engine file. TRT engine files aren't portable, and what if the engine wasn't even generated from an ONNX graph to begin with? The hashing in onnxruntime aims to provide some safety checks to ensure the engine matches an ONNX graph/subgraph and that the versions of TensorRT, ORT, CUDA, etc. are the same.

fxmarty commented 1 year ago

@jywu-msft Could we expect the user to provide a valid TRT engine file (e.g., one generated on the same device)?

gedoensmax commented 3 months ago

This is now possible in 1.18.0; please consult the documentation on how to embed an engine into an ONNX file: https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#tensorrt-ep-caches

Here is a Python script that can wrap an externally compiled (e.g. via trtexec) engine file into an ONNX model: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py#L156-L187
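
Once the wrapper model has been generated, loading it is just a normal session creation; a minimal sketch, where gpt2_ctx.onnx is a placeholder for the wrapper script's output:

import onnxruntime as ort

# Loading the wrapper ("EP context") model: the TensorRT EP deserializes the embedded
# or referenced engine instead of rebuilding it at session creation.
session = ort.InferenceSession(
    "gpt2_ctx.onnx",  # placeholder for the wrapper script's output path
    providers=["TensorrtExecutionProvider", "CUDAExecutionProvider"],
)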