microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

[Performance] Model converted to mixed precision results in higher latency #15490

Open brevity2021 opened 1 year ago

brevity2021 commented 1 year ago

Describe the issue

Hi,

I've tried to convert a Pegasus model to ONNX with mixed precision, but it results in higher latency than ONNX with fp32 (both run on GPU with IOBinding). The ONNX fp32 model gives a 20-30% latency improvement over the PyTorch (Hugging Face) implementation, but after using convert_float_to_float16 to convert part of the ONNX model to fp16, the latency is slightly higher than the PyTorch implementation.
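Latency comparisons like these depend on warm-up and on timing only completed GPU work. Below is a minimal sketch of a timing helper (measure_latency_ms is a hypothetical name, not a library function) that reports the median over several runs after a warm-up, assuming the callable blocks until its GPU work has finished.

```python
import statistics
import time
from typing import Callable


def measure_latency_ms(run: Callable[[], object], warmup: int = 5, iters: int = 20) -> float:
    """Return the median wall-clock latency of `run` in milliseconds.

    Assumes `run` blocks until its GPU work is done; otherwise the numbers
    mostly measure kernel launch overhead.
    """
    if iters < 1:
        raise ValueError("iters must be >= 1")
    # Warm-up runs absorb one-time costs such as CUDA initialization and allocations.
    for _ in range(warmup):
        run()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        run()
        samples.append((time.perf_counter() - start) * 1000.0)
    # The median is less sensitive than the mean to occasional slow iterations.
    return statistics.median(samples)
```

For ONNX Runtime the callable could wrap session.run_with_iobinding(io_binding); for PyTorch it would need to call torch.cuda.synchronize() before returning so both sides are timed the same way.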

I've checked the ONNX graphs, and the mixed precision graph has thousands of added Cast nodes between fp32 and fp16, so I am wondering whether this is the cause of the latency increase. I've attached the encoder graphs for reference (the decoder looks similar).
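To put a number on "thousands of Cast nodes", one can diff the op-type histograms of the fp32 and mixed precision graphs. A minimal sketch with hypothetical helper names; the functions only need objects with an op_type attribute, so they accept onnx.load(path).graph.node directly. Only the top-level graph is counted, not subgraphs inside If/Loop bodies.

```python
from collections import Counter
from typing import Dict, Iterable


def op_type_counts(nodes: Iterable) -> Counter:
    """Count nodes by op_type (e.g. onnx.ModelProto.graph.node)."""
    return Counter(node.op_type for node in nodes)


def op_count_diff(before: Iterable, after: Iterable) -> Dict[str, int]:
    """Per-op-type change in node count (after minus before), omitting unchanged types."""
    b, a = op_type_counts(before), op_type_counts(after)
    # Counter returns 0 for missing keys, so op types present in only one graph are handled.
    return {op: a[op] - b[op] for op in sorted(set(a) | set(b)) if a[op] != b[op]}
```

For example, op_count_diff(onnx.load("encoder_fp32.onnx").graph.node, onnx.load("encoder_mixed.onnx").graph.node) would report how many Cast nodes the conversion added (the file names here are placeholders).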

Should I try to avoid adding fp32<->fp16 Cast nodes as much as possible, or is there a tool that avoids the casts? I would really appreciate any advice you could share, @tianleiwu.

To reproduce

Conversion: for the original ONNX fp32 model, I export the PyTorch model with torch.onnx and simplify it with onnxsim.

Conversion to mixed precision: I convert the ONNX model to mixed precision with convert_float_to_float16, run topological_sort() from OnnxModel, and then simplify it with onnxsim as above.

Encoder: I call convert_float_to_float16 with the following op_block_list (all other parameters are defaults): op_block_list = [every op_type in the model except "Identity", "Constant", "Gather", "Shape"]
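The encoder conversion described above can be sketched as follows, assuming the onnx and onnxruntime packages are installed. block_all_but and convert_with_block_list are hypothetical helper names and the paths are placeholders. With this list, nearly every compute op type lands in op_block_list, so almost the whole graph is kept in fp32.

```python
from typing import Iterable, List

# The four op types the original post treats as "safe" to run in fp16.
SAFE_OPS = ("Identity", "Constant", "Gather", "Shape")


def block_all_but(nodes: Iterable, allowed: Iterable[str] = SAFE_OPS) -> List[str]:
    """Return every op_type present in the graph except the allowed ones, sorted."""
    return sorted({node.op_type for node in nodes} - set(allowed))


def convert_with_block_list(model_path: str, out_path: str) -> None:
    """Sketch of the conversion step above; imports are lazy so block_all_but works standalone."""
    import onnx
    from onnxruntime.transformers.float16 import convert_float_to_float16

    model = onnx.load(model_path)
    op_block_list = block_all_but(model.graph.node)
    # All other convert_float_to_float16 parameters are left at their defaults, as in the post.
    fp16_model = convert_float_to_float16(model, op_block_list=op_block_list)
    onnx.save(fp16_model, out_path)
```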

This converts the logits output and the token/position embedding weights in the initializers to fp16, but it also adds many fp32<->fp16 Cast nodes inside the graph.

[Attached images: Encoder graph (fp32); Encoder graph (mixed precision)]

Only the token/position embedding weights in the initializers are converted to float16, but the procedure adds thousands of Casts. For example, the ReduceMean output in the first layer's self-attention layer norm is cast to fp16 (to = 10 is FLOAT16) and then straight back to fp32 (to = 1 is FLOAT) as an input to Sub:

%/layers.0/self_attn_layer_norm/ReduceMean_output_0 = Cast [ to = 10 ] (%/layers.0/self_attn_layer_norm/ReduceMean_output_cast_0)

%/layers.0/self_attn_layer_norm/Sub_input_cast_1 = Cast [ to = 1 ] (%/layers.0/self_attn_layer_norm/ReduceMean_output_0)
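The pattern above (one Cast feeding directly into another) can be found mechanically. A minimal sketch with hypothetical helper names; it only relies on the fields of onnx NodeProto (name, op_type, input, output, attribute) and on the integer `i` field of the Cast "to" attribute, so it accepts onnx.load(path).graph.node. DTYPE_NAMES is a small subset of the onnx TensorProto data type codes.

```python
from typing import Dict, Iterable, List, Tuple

# Subset of onnx.TensorProto.DataType codes relevant to fp16 conversion.
DTYPE_NAMES = {1: "FLOAT", 10: "FLOAT16", 11: "DOUBLE", 16: "BFLOAT16"}


def cast_target(node) -> int:
    """Return the integer 'to' attribute of a Cast node."""
    for attr in node.attribute:
        if attr.name == "to":
            return attr.i
    raise ValueError(f"Cast node {node.name!r} has no 'to' attribute")


def back_to_back_casts(nodes: Iterable) -> List[Tuple[str, str, str]]:
    """Find Cast nodes whose input is produced directly by another Cast node.

    Returns (tensor_between_casts, first_target, second_target) triples. A
    ("FLOAT16", "FLOAT") pair is a tensor cast to fp16 and immediately back to fp32.
    """
    nodes = list(nodes)
    # Map each tensor name to the node that produces it (graph inputs/initializers are absent).
    producer: Dict[str, object] = {out: n for n in nodes for out in n.output}
    pairs = []
    for node in nodes:
        if node.op_type != "Cast":
            continue
        src = producer.get(node.input[0])
        if src is not None and src.op_type == "Cast":
            first = DTYPE_NAMES.get(cast_target(src), str(cast_target(src)))
            second = DTYPE_NAMES.get(cast_target(node), str(cast_target(node)))
            pairs.append((node.input[0], first, second))
    return pairs
```

Counting the returned pairs gives a rough idea of how many of the added Casts are pure round trips.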

Urgency

No response

Platform

Linux

OS Version

Ubuntu 20.04

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

14.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

No response

Model File

No response

Is this a quantized model?

Yes

wschin commented 1 year ago

You're right. That many Casts slow down the entire computation, and it's not expected behavior. I'm not a quantization expert, so I'm tagging the team for better ideas.

brevity2021 commented 1 year ago

Thank you! That would be really helpful.

tianleiwu commented 1 year ago

@brevity2021,

The recommended way to do mixed precision conversion is:
(1) Export the fp32 model to ONNX using torch.onnx.export.
(2) Optimize the ONNX model with the transformer optimization tool (https://onnxruntime.ai/docs/performance/transformers-optimization.html). For transformer models, make sure that Attention or MultiHeadAttention is fused.
(3) Convert the optimized fp32 model to mixed precision by calling convert_float_to_float16.
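Steps (2) and (3) of this recipe can be sketched with the Python API of onnxruntime.transformers. This is a sketch under assumptions, not the tool's official usage: model_type="bart" for Pegasus, num_heads/hidden_size supplied by the caller, and an op_block_list limited to layer normalization in the spirit of the note in this comment. attention_is_fused and optimize_then_convert are hypothetical names.

```python
from typing import Dict, Sequence


def attention_is_fused(fused_op_counts: Dict[str, int]) -> bool:
    """True if the optimizer produced at least one fused attention node."""
    return any(fused_op_counts.get(op, 0) > 0 for op in ("Attention", "MultiHeadAttention"))


def optimize_then_convert(
    fp32_path: str,
    out_path: str,
    num_heads: int,
    hidden_size: int,
    op_block_list: Sequence[str] = ("LayerNormalization", "SkipLayerNormalization"),
) -> None:
    """Steps (2) and (3); step (1) is a plain torch.onnx.export to fp32_path."""
    # Imported lazily so attention_is_fused can be used without onnxruntime installed.
    from onnxruntime.transformers import optimizer

    # Step 2: graph fusion. model_type="bart" for Pegasus is an assumption.
    opt_model = optimizer.optimize_model(
        fp32_path, model_type="bart", num_heads=num_heads, hidden_size=hidden_size
    )
    stats = opt_model.get_fused_operator_statistics()
    if not attention_is_fused(stats):
        print("warning: no Attention/MultiHeadAttention node was fused:", stats)
    # Step 3: mixed precision; nodes of the listed op types are kept in fp32.
    opt_model.convert_float_to_float16(keep_io_types=True, op_block_list=list(op_block_list))
    opt_model.save_model_to_file(out_path)
```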

Note that op_block_list lists the operators that are computed in FP32. Normally, you can add LayerNormalization/SkipLayerNormalization and activations like Gelu/FastGelu to the list.
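Since these op types may or may not be present depending on how the model was exported and fused, a block list can be built from the suggestion by keeping only the types that actually occur in the graph. A minimal sketch with a hypothetical helper name; in an unfused export, layer norm and GELU may appear only as their primitive ops, in which case the result is empty.

```python
from typing import Iterable, List

# Op types suggested above for fp32 execution.
SUGGESTED_FP32_OPS = ("LayerNormalization", "SkipLayerNormalization", "Gelu", "FastGelu")


def suggested_block_list(nodes: Iterable) -> List[str]:
    """Keep only the suggested fp32 op types that actually occur in the graph."""
    present = {node.op_type for node in nodes}
    return [op for op in SUGGESTED_FP32_OPS if op in present]
```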

brevity2021 commented 1 year ago

Thank you @tianleiwu for your reply!

However, the optimization tool you suggested does not seem to support any encoder-decoder model according to this README page. Are you sure we can directly use it for an encoder-decoder model like Pegasus?

For op_block_list, I was using [every op_type in the model except "Identity", "Constant", "Gather", "Shape"], since I consider those four to be "safe" operators.

In my model, the layer norms are built mostly from "Pow"/"Mul"/"Div" operators, so I'm trying to exclude operators that don't support float16 according to https://github.com/microsoft/onnxruntime/blob/main/docs/OperatorKernels.md#cudaexecutionprovider. Unfortunately, the operators you suggested (LayerNormalization/SkipLayerNormalization/Gelu/FastGelu) don't exist in my model.

Maybe, instead of op_block_list, I should use the more explicit node_block_list to avoid the Cast nodes?
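If you want to try node_block_list, one way to build it is by node name. A minimal sketch with a hypothetical helper; the substring "layer_norm" is an assumption based on the tensor names quoted earlier, and whether this actually reduces the Cast count is something to measure by recounting Cast nodes after conversion.

```python
from typing import Iterable, List


def nodes_matching(nodes: Iterable, substrings: Iterable[str]) -> List[str]:
    """Names of nodes whose name contains any of the given substrings.

    The result can be passed as node_block_list to convert_float_to_float16,
    which keeps the named nodes in fp32 regardless of their op_type.
    """
    subs = tuple(substrings)
    # Unnamed nodes are skipped because node_block_list refers to nodes by name.
    return [node.name for node in nodes if node.name and any(s in node.name for s in subs)]
```

Usage would look like convert_float_to_float16(model, node_block_list=nodes_matching(model.graph.node, ["layer_norm"])).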

tianleiwu commented 1 year ago

@brevity2021, the supported model list is not up to date. We support encoder-decoder models like BART and T5 (T5 needs the nightly package). You can try using --model_type bart to see whether it helps.

python -m onnxruntime.transformers.optimizer --help
python -m onnxruntime.transformers.optimizer --input pegasus.onnx --output pegasus_opt.onnx --model_type bart

Contributions of graph fusions for the Pegasus model are welcome if you find that some parts, like self/cross attention, are not fused!
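A rough way to tell whether some attention blocks were left unfused is to compare the fused attention node count with the number of attention blocks the architecture should have: one self-attention per layer, plus one cross-attention per decoder layer. A minimal sketch with hypothetical helper names, assuming each fused block becomes exactly one Attention or MultiHeadAttention node, which may not hold for every fusion.

```python
def expected_attention_blocks(num_layers: int, is_decoder: bool) -> int:
    """Self-attention in every layer, plus cross-attention in every decoder layer."""
    return num_layers * (2 if is_decoder else 1)


def unfused_attention_blocks(fused_count: int, num_layers: int, is_decoder: bool) -> int:
    """Estimated number of attention blocks that were not fused (0 means none missing)."""
    return max(expected_attention_blocks(num_layers, is_decoder) - fused_count, 0)
```

Here fused_count could be taken from the optimizer's get_fused_operator_statistics() (Attention plus MultiHeadAttention counts), and num_layers from the Hugging Face config (encoder_layers or decoder_layers).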

Keep op_block_list as small as possible, since those operators will be computed in fp32.
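To see how much of a graph a given op_block_list pins to fp32, one can compute the share of nodes it covers. A minimal sketch with a hypothetical helper; it counts nodes rather than FLOPs, so it is only a rough indicator. With a list containing every op type except Identity/Constant/Gather/Shape, the share is close to 1.

```python
from typing import Iterable


def blocked_node_fraction(nodes: Iterable, op_block_list: Iterable[str]) -> float:
    """Share of graph nodes whose op_type is in op_block_list (kept in fp32)."""
    blocked = set(op_block_list)
    op_types = [node.op_type for node in nodes]
    if not op_types:
        return 0.0
    return sum(op in blocked for op in op_types) / len(op_types)
```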