microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

[Performance] Get nan value when I block all the node in fp16 conversion #21345

Closed: jinhonglu closed this issue 1 month ago

jinhonglu commented 1 month ago

Describe the issue

Since the mixed precision conversion is not working well, I tried to figure out which nodes should be converted to fp16 to get the best performance. As a first step, I blocked all the nodes from conversion. However, I got nan output from the resulting fp16 model. Ideally, this fp16 model should behave exactly like the fp32 model.

To reproduce

model = onnx.load("my_fp32_onnxmodel")
list = []
includelist = []
for node in model.graph.node:
    list.append(node.name)
model_fp16 = float16.convert_float_to_float16(
    model,
    min_positive_val=1e-7,
    max_finite_val=1e4,
    keep_io_types=True,
    disable_shape_infer=False,
    node_blocklist=list,
)
onnx.save(model_fp16, "fp16_model.onnx")
ort_session = onnxruntime.InferenceSession('fp16_model.onnx', providers=["CUDAExecutionProvider"])

batch_output_mask_fp16 = torch.tensor(ort_session.run(None, ort_inputs)[0])
print(sum(batch_output_mask_fp16.cpu().numpy()))
print(sum(batch_output_mask.cpu().numpy()))
print(np.abs(batch_output_mask_fp16.cpu().numpy() - batch_output_mask.cpu().numpy()).mean())

Urgency

No response

Platform

Linux

OS Version

Ubuntu 22.04.4 LTS

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

onnxruntime-gpu 1.18.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

12.5

Model File

No response

Is this a quantized model?

No

tianleiwu commented 1 month ago

The solution is to keep some nodes in fp32 instead of fp16. It is easy to find out which nodes should be kept in fp32 by looking at the output statistics: if a value is not in the range [-65504, 65504], the node needs fp32.

You can build from source with --cmake_extra_defines onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS=1 added to the build command. Then set some environment variables before running the fp32 onnx model with onnxruntime:

export ORT_DEBUG_NODE_IO_DUMP_INPUT_DATA=1
export ORT_DEBUG_NODE_IO_DUMP_OUTPUT_DATA=1
export ORT_DEBUG_NODE_IO_DUMP_STATISTICS_DATA=1

The console output will then show the statistics for each node.
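
If building from source is not convenient, a rough alternative is to expose the fp32 model's intermediate tensors as extra graph outputs and check which values leave the fp16 range. Below is a minimal sketch of that idea; it assumes the same ort_inputs feed dict used in the reproduction above, and it only promotes tensors for which shape inference produced type information.

import numpy as np
import onnx
import onnxruntime

FP16_MAX = 65504.0

model = onnx.load("my_fp32_onnxmodel")
# shape inference provides value_info so intermediates can be promoted to graph outputs
inferred = onnx.shape_inference.infer_shapes(model)
known = {vi.name: vi for vi in inferred.graph.value_info}
existing = {o.name for o in inferred.graph.output}
for node in inferred.graph.node:
    for out in node.output:
        if out in known and out not in existing:
            inferred.graph.output.append(known[out])
onnx.save(inferred, "fp32_all_outputs.onnx")

sess = onnxruntime.InferenceSession("fp32_all_outputs.onnx", providers=["CUDAExecutionProvider"])
names = [o.name for o in sess.get_outputs()]
for name, value in zip(names, sess.run(None, ort_inputs)):
    if np.issubdtype(value.dtype, np.floating) and value.size > 0 and np.abs(value).max() > FP16_MAX:
        print("needs fp32:", name)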

jinhonglu commented 1 month ago


Hi, thanks for your reply.

I have followed your approach and identified some of the nodes whose outputs exceeded the range [-65504, 65504], and I blocked them from being converted from fp32 to fp16 during the conversion.

However, the output of the fp16 model is still far from that of the fp32 model.

I looked into the statistics again and observed that some of the values in the output nodes turned into Inf. I realised this is a problem with the Cast layers, because those values exceed the range of fp16. Do you have any ideas on how to resolve this?

tianleiwu commented 1 month ago

@jinhonglu, the keep_io_types parameter can be a list of input and output names that should stay in fp32 instead of being converted to fp16. You can add the output name to that list.
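
For illustration, a minimal sketch of passing explicit names (the names below are placeholders, not this model's real input/output names), using the same float16 module as in the snippets above:

model_fp16 = float16.convert_float_to_float16(
    model,
    min_positive_val=5.96e-08,
    max_finite_val=65504.0,
    # keep these graph inputs/outputs in fp32 instead of converting them
    keep_io_types=["spectrogram_input", "mask_output"],
    disable_shape_infer=False,
)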

jinhonglu commented 1 month ago


Thanks, the current fp16 model now shows only a small difference from the fp32 model. (The difference is about [0.0004, 0.001] now, down from 0.29.)

However, at inference time with actual data, the result of the fp16 model is still far from the fp32 model.

By the way, I am working on an audio task: I apply a masking method to the model's output and then apply an iSTFT.

I suspect the difference is still too large for an audio task. Is there any way for me to adjust the conversion to get a smaller difference?

The current conversion code is below

model_fp16 = float16.convert_float_to_float16(
    model,
    min_positive_val=5.96e-08,
    max_finite_val=65504.0,
    keep_io_types=io_list_,
    disable_shape_infer=False,
)

jinhonglu commented 1 month ago

Later I found out that since the current fp16 model was converted on CPU, running the fp16 model on CPU for inference gives a result that is mostly the same as the fp32 model.

But when I run the fp16 model on GPU, the result is totally different.

I then rebuilt onnxruntime with CUDAExecutionProvider support, reran the fp32 model to find the node names (put into keep_io_types), converted to fp16 and ran it with onnxruntime-gpu.

The fp16 statistics again contain Inf values.

What is wrong here?

tianleiwu commented 1 month ago

@jinhonglu, could you use the same input tensors, dump the CPU inference stdout to one text file, redirect the GPU inference console output to another text file, and share both text files? We can compare the results to find out which node / operator causes the difference.

jinhonglu commented 1 month ago

fp16_model_cpu_stat.txt fp16_model_cuda_stat.txt

Both files are uploaded.

For your convenience, I have listed the nodes/operators whose outputs exceeded the range in each run.

Under CPU, ['/band_split/to_features.25/to_features.25.0/ReduceSum_output_0', '/layers.0.1/layers.0.0/norm/ReduceSum_output_0', '/layers.0.1/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.1.0/layers.0.0/norm/Pow_output_0', '/layers.1.0/layers.0.0/norm/ReduceSum_output_0', '/layers.1.0/layers.0.1/net/net.0/Pow_output_0', '/layers.1.0/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.1.1/layers.0.0/norm/Pow_output_0', '/layers.1.1/layers.0.0/norm/ReduceSum_output_0', '/layers.1.1/layers.0.1/net/net.0/Pow_output_0', '/layers.1.1/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.2.0/layers.0.0/norm/Pow_output_0', '/layers.2.0/layers.0.0/norm/ReduceSum_output_0', '/layers.2.0/layers.0.1/net/net.0/Pow_output_0', '/layers.2.0/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.2.1/layers.0.0/norm/Pow_output_0', '/layers.2.1/layers.0.0/norm/ReduceSum_output_0', '/layers.2.1/layers.0.1/net/net.0/Pow_output_0', '/layers.2.1/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.3.0/layers.0.0/norm/Pow_output_0', '/layers.3.0/layers.0.0/norm/ReduceSum_output_0', '/layers.3.0/layers.0.1/net/net.0/Pow_output_0', '/layers.3.0/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.3.1/layers.0.0/norm/Pow_output_0', '/layers.3.1/layers.0.0/norm/ReduceSum_output_0', '/layers.3.1/layers.0.1/net/net.0/Pow_output_0', '/layers.3.1/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.4.0/layers.0.0/norm/Pow_output_0', '/layers.4.0/layers.0.0/norm/ReduceSum_output_0', '/layers.4.0/layers.0.1/net/net.0/Pow_output_0', '/layers.4.0/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.4.1/layers.0.0/norm/Pow_output_0', '/layers.4.1/layers.0.0/norm/ReduceSum_output_0', '/layers.4.1/layers.0.1/net/net.0/Pow_output_0', '/layers.4.1/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.5.0/layers.0.0/norm/Pow_output_0', '/layers.5.0/layers.0.0/norm/ReduceSum_output_0', '/layers.5.0/layers.0.1/net/net.0/Pow_output_0', '/layers.5.0/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.5.1/layers.0.0/norm/Pow_output_0', '/layers.5.1/layers.0.0/norm/ReduceSum_output_0', '/layers.5.1/layers.0.1/net/net.0/Pow_output_0', '/layers.5.1/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.6.0/layers.0.0/norm/Pow_output_0', '/layers.6.0/layers.0.0/norm/ReduceSum_output_0', '/layers.6.0/layers.0.1/net/net.0/Pow_output_0', '/layers.6.0/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.6.1/layers.0.0/norm/Pow_output_0', '/layers.6.1/layers.0.0/norm/ReduceSum_output_0', '/layers.6.1/layers.0.1/net/net.0/Pow_output_0', '/layers.6.1/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.7.0/layers.0.0/norm/Pow_output_0', '/layers.7.0/layers.0.0/norm/ReduceSum_output_0', '/layers.7.0/layers.0.1/net/net.0/Pow_output_0', '/layers.7.0/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.7.1/layers.0.0/norm/Pow_output_0', '/layers.7.1/layers.0.0/norm/ReduceSum_output_0', '/layers.7.1/layers.0.1/net/net.0/Pow_output_0', '/layers.7.1/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.8.0/layers.0.0/norm/Pow_output_0', '/layers.8.0/layers.0.0/norm/ReduceSum_output_0', '/layers.8.0/layers.0.1/net/net.0/Pow_output_0', '/layers.8.0/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.8.1/layers.0.0/norm/Pow_output_0', '/layers.8.1/layers.0.0/norm/ReduceSum_output_0', '/layers.8.1/layers.0.1/net/net.0/Pow_output_0', '/layers.8.1/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.9.0/layers.0.0/norm/Pow_output_0', '/layers.9.0/layers.0.0/norm/ReduceSum_output_0', '/layers.9.0/layers.0.1/net/net.0/Pow_output_0', '/layers.9.0/layers.0.1/net/net.0/ReduceSum_output_0', 
'/layers.9.1/layers.0.0/norm/Pow_output_0', '/layers.9.1/layers.0.0/norm/ReduceSum_output_0', '/layers.9.1/layers.0.1/net/net.0/Pow_output_0', '/layers.9.1/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.10.0/layers.0.0/norm/Pow_output_0', '/layers.10.0/layers.0.0/norm/ReduceSum_output_0', '/layers.10.0/layers.0.1/net/net.0/Pow_output_0', '/layers.10.0/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.10.1/layers.0.0/norm/Pow_output_0', '/layers.10.1/layers.0.0/norm/ReduceSum_output_0', '/layers.10.1/layers.0.1/net/net.0/Pow_output_0', '/layers.10.1/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.11.0/layers.0.0/norm/Pow_output_0', '/layers.11.0/layers.0.0/norm/ReduceSum_output_0', '/layers.11.0/layers.0.1/net/net.0/Pow_output_0', '/layers.11.0/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.11.1/layers.0.0/norm/Pow_output_0', '/layers.11.1/layers.0.0/norm/ReduceSum_output_0', '/layers.11.1/layers.0.1/net/net.0/Pow_output_0', '/layers.11.1/layers.0.1/net/net.0/ReduceSum_output_0', '/final_norm/Pow_output_0', '/final_norm/ReduceSum_output_0']

Under CUDA, ['/band_split/to_features.25/to_features.25.0/ReduceSum_output_0', '/band_split/to_features.25/to_features.25.0/Pow_1_output_0', '/band_split/to_features.25/to_features.25.0/Clip_output_0', '/band_split/to_features.25/to_features.25.0/Expand_output_0', '/layers.0.0/layers.0.0/norm/Pow_output_0', '/layers.0.0/layers.0.0/norm/ReduceSum_output_0', '/layers.0.0/layers.0.0/norm/Pow_1_output_0', '/layers.0.0/layers.0.0/norm/Clip_output_0', '/layers.0.0/layers.0.0/norm/Expand_output_0', '/layers.0.0/layers.0.1/net/net.0/Pow_output_0', '/layers.0.0/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.0.0/layers.0.1/net/net.0/Pow_1_output_0', '/layers.0.0/layers.0.1/net/net.0/Clip_output_0', '/layers.0.0/layers.0.1/net/net.0/Expand_output_0', '/layers.0.1/layers.0.0/norm/Pow_output_0', '/layers.0.1/layers.0.0/norm/ReduceSum_output_0', '/layers.0.1/layers.0.0/norm/Pow_1_output_0', '/layers.0.1/layers.0.0/norm/Clip_output_0', '/layers.0.1/layers.0.0/norm/Expand_output_0', '/layers.0.1/layers.0.1/net/net.0/Pow_output_0', '/layers.0.1/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.0.1/layers.0.1/net/net.0/Pow_1_output_0', '/layers.0.1/layers.0.1/net/net.0/Clip_output_0', '/layers.0.1/layers.0.1/net/net.0/Expand_output_0', '/layers.1.0/layers.0.0/norm/Pow_output_0', '/layers.1.0/layers.0.0/norm/ReduceSum_output_0', '/layers.1.0/layers.0.0/norm/Pow_1_output_0', '/layers.1.0/layers.0.0/norm/Clip_output_0', '/layers.1.0/layers.0.0/norm/Expand_output_0', '/layers.1.0/layers.0.1/net/net.0/Pow_output_0', '/layers.1.0/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.1.0/layers.0.1/net/net.0/Pow_1_output_0', '/layers.1.0/layers.0.1/net/net.0/Clip_output_0', '/layers.1.0/layers.0.1/net/net.0/Expand_output_0', '/layers.1.1/layers.0.0/norm/Pow_output_0', '/layers.1.1/layers.0.0/norm/ReduceSum_output_0', '/layers.1.1/layers.0.0/norm/Pow_1_output_0', '/layers.1.1/layers.0.0/norm/Clip_output_0', '/layers.1.1/layers.0.0/norm/Expand_output_0', '/layers.1.1/layers.0.1/net/net.0/Pow_output_0', '/layers.1.1/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.1.1/layers.0.1/net/net.0/Pow_1_output_0', '/layers.1.1/layers.0.1/net/net.0/Clip_output_0', '/layers.1.1/layers.0.1/net/net.0/Expand_output_0', '/layers.2.0/layers.0.0/norm/Pow_output_0', '/layers.2.0/layers.0.0/norm/ReduceSum_output_0', '/layers.2.0/layers.0.0/norm/Pow_1_output_0', '/layers.2.0/layers.0.0/norm/Clip_output_0', '/layers.2.0/layers.0.0/norm/Expand_output_0', '/layers.2.0/layers.0.1/net/net.0/Pow_output_0', '/layers.2.0/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.2.0/layers.0.1/net/net.0/Pow_1_output_0', '/layers.2.0/layers.0.1/net/net.0/Clip_output_0', '/layers.2.0/layers.0.1/net/net.0/Expand_output_0', '/layers.2.1/layers.0.0/norm/Pow_output_0', '/layers.2.1/layers.0.0/norm/ReduceSum_output_0', '/layers.2.1/layers.0.0/norm/Pow_1_output_0', '/layers.2.1/layers.0.0/norm/Clip_output_0', '/layers.2.1/layers.0.0/norm/Expand_output_0', '/layers.2.1/layers.0.1/net/net.0/Pow_output_0', '/layers.2.1/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.2.1/layers.0.1/net/net.0/Pow_1_output_0', '/layers.2.1/layers.0.1/net/net.0/Clip_output_0', '/layers.2.1/layers.0.1/net/net.0/Expand_output_0', '/layers.3.0/layers.0.0/norm/Pow_output_0', '/layers.3.0/layers.0.0/norm/ReduceSum_output_0', '/layers.3.0/layers.0.0/norm/Pow_1_output_0', '/layers.3.0/layers.0.0/norm/Clip_output_0', '/layers.3.0/layers.0.0/norm/Expand_output_0', '/layers.3.0/layers.0.1/net/net.0/Pow_output_0', '/layers.3.0/layers.0.1/net/net.0/ReduceSum_output_0', 
'/layers.3.0/layers.0.1/net/net.0/Pow_1_output_0', '/layers.3.0/layers.0.1/net/net.0/Clip_output_0', '/layers.3.0/layers.0.1/net/net.0/Expand_output_0', '/layers.3.1/layers.0.0/norm/Pow_output_0', '/layers.3.1/layers.0.0/norm/ReduceSum_output_0', '/layers.3.1/layers.0.0/norm/Pow_1_output_0', '/layers.3.1/layers.0.0/norm/Clip_output_0', '/layers.3.1/layers.0.0/norm/Expand_output_0', '/layers.3.1/layers.0.1/net/net.0/Pow_output_0', '/layers.3.1/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.3.1/layers.0.1/net/net.0/Pow_1_output_0', '/layers.3.1/layers.0.1/net/net.0/Clip_output_0', '/layers.3.1/layers.0.1/net/net.0/Expand_output_0', '/layers.4.0/layers.0.0/norm/Pow_output_0', '/layers.4.0/layers.0.0/norm/ReduceSum_output_0', '/layers.4.0/layers.0.0/norm/Pow_1_output_0', '/layers.4.0/layers.0.0/norm/Clip_output_0', '/layers.4.0/layers.0.0/norm/Expand_output_0', '/layers.4.0/layers.0.1/net/net.0/Pow_output_0', '/layers.4.0/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.4.0/layers.0.1/net/net.0/Pow_1_output_0', '/layers.4.0/layers.0.1/net/net.0/Clip_output_0', '/layers.4.0/layers.0.1/net/net.0/Expand_output_0', '/layers.4.1/layers.0.0/norm/Pow_output_0', '/layers.4.1/layers.0.0/norm/ReduceSum_output_0', '/layers.4.1/layers.0.0/norm/Pow_1_output_0', '/layers.4.1/layers.0.0/norm/Clip_output_0', '/layers.4.1/layers.0.0/norm/Expand_output_0', '/layers.4.1/layers.0.1/net/net.0/Pow_output_0', '/layers.4.1/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.4.1/layers.0.1/net/net.0/Pow_1_output_0', '/layers.4.1/layers.0.1/net/net.0/Clip_output_0', '/layers.4.1/layers.0.1/net/net.0/Expand_output_0', '/layers.5.0/layers.0.0/norm/Pow_output_0', '/layers.5.0/layers.0.0/norm/ReduceSum_output_0', '/layers.5.0/layers.0.0/norm/Pow_1_output_0', '/layers.5.0/layers.0.0/norm/Clip_output_0', '/layers.5.0/layers.0.0/norm/Expand_output_0', '/layers.5.0/layers.0.1/net/net.0/Pow_output_0', '/layers.5.0/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.5.0/layers.0.1/net/net.0/Pow_1_output_0', '/layers.5.0/layers.0.1/net/net.0/Clip_output_0', '/layers.5.0/layers.0.1/net/net.0/Expand_output_0', '/layers.5.1/layers.0.0/norm/Pow_output_0', '/layers.5.1/layers.0.0/norm/ReduceSum_output_0', '/layers.5.1/layers.0.0/norm/Pow_1_output_0', '/layers.5.1/layers.0.0/norm/Clip_output_0', '/layers.5.1/layers.0.0/norm/Expand_output_0', '/layers.5.1/layers.0.1/net/net.0/Pow_output_0', '/layers.5.1/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.5.1/layers.0.1/net/net.0/Pow_1_output_0', '/layers.5.1/layers.0.1/net/net.0/Clip_output_0', '/layers.5.1/layers.0.1/net/net.0/Expand_output_0', '/layers.6.0/layers.0.0/norm/Pow_output_0', '/layers.6.0/layers.0.0/norm/ReduceSum_output_0', '/layers.6.0/layers.0.0/norm/Pow_1_output_0', '/layers.6.0/layers.0.0/norm/Clip_output_0', '/layers.6.0/layers.0.0/norm/Expand_output_0', '/layers.6.0/layers.0.1/net/net.0/Pow_output_0', '/layers.6.0/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.6.0/layers.0.1/net/net.0/Pow_1_output_0', '/layers.6.0/layers.0.1/net/net.0/Clip_output_0', '/layers.6.0/layers.0.1/net/net.0/Expand_output_0', '/layers.6.1/layers.0.0/norm/Pow_output_0', '/layers.6.1/layers.0.0/norm/ReduceSum_output_0', '/layers.6.1/layers.0.0/norm/Pow_1_output_0', '/layers.6.1/layers.0.0/norm/Clip_output_0', '/layers.6.1/layers.0.0/norm/Expand_output_0', '/layers.6.1/layers.0.1/net/net.0/Pow_output_0', '/layers.6.1/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.6.1/layers.0.1/net/net.0/Pow_1_output_0', '/layers.6.1/layers.0.1/net/net.0/Clip_output_0', 
'/layers.6.1/layers.0.1/net/net.0/Expand_output_0', '/layers.7.0/layers.0.0/norm/Pow_output_0', '/layers.7.0/layers.0.0/norm/ReduceSum_output_0', '/layers.7.0/layers.0.0/norm/Pow_1_output_0', '/layers.7.0/layers.0.0/norm/Clip_output_0', '/layers.7.0/layers.0.0/norm/Expand_output_0', '/layers.7.0/layers.0.1/net/net.0/Pow_output_0', '/layers.7.0/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.7.0/layers.0.1/net/net.0/Pow_1_output_0', '/layers.7.0/layers.0.1/net/net.0/Clip_output_0', '/layers.7.0/layers.0.1/net/net.0/Expand_output_0', '/layers.7.1/layers.0.0/norm/Pow_output_0', '/layers.7.1/layers.0.0/norm/ReduceSum_output_0', '/layers.7.1/layers.0.0/norm/Pow_1_output_0', '/layers.7.1/layers.0.0/norm/Clip_output_0', '/layers.7.1/layers.0.0/norm/Expand_output_0', '/layers.7.1/layers.0.1/net/net.0/Pow_output_0', '/layers.7.1/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.7.1/layers.0.1/net/net.0/Pow_1_output_0', '/layers.7.1/layers.0.1/net/net.0/Clip_output_0', '/layers.7.1/layers.0.1/net/net.0/Expand_output_0', '/layers.8.0/layers.0.0/norm/Pow_output_0', '/layers.8.0/layers.0.0/norm/ReduceSum_output_0', '/layers.8.0/layers.0.0/norm/Pow_1_output_0', '/layers.8.0/layers.0.0/norm/Clip_output_0', '/layers.8.0/layers.0.0/norm/Expand_output_0', '/layers.8.0/layers.0.1/net/net.0/Pow_output_0', '/layers.8.0/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.8.0/layers.0.1/net/net.0/Pow_1_output_0', '/layers.8.0/layers.0.1/net/net.0/Clip_output_0', '/layers.8.0/layers.0.1/net/net.0/Expand_output_0', '/layers.8.1/layers.0.0/norm/Pow_output_0', '/layers.8.1/layers.0.0/norm/ReduceSum_output_0', '/layers.8.1/layers.0.0/norm/Pow_1_output_0', '/layers.8.1/layers.0.0/norm/Clip_output_0', '/layers.8.1/layers.0.0/norm/Expand_output_0', '/layers.8.1/layers.0.1/net/net.0/Pow_output_0', '/layers.8.1/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.8.1/layers.0.1/net/net.0/Pow_1_output_0', '/layers.8.1/layers.0.1/net/net.0/Clip_output_0', '/layers.8.1/layers.0.1/net/net.0/Expand_output_0', '/layers.9.0/layers.0.0/norm/Pow_output_0', '/layers.9.0/layers.0.0/norm/ReduceSum_output_0', '/layers.9.0/layers.0.0/norm/Pow_1_output_0', '/layers.9.0/layers.0.0/norm/Clip_output_0', '/layers.9.0/layers.0.0/norm/Expand_output_0', '/layers.9.0/layers.0.1/net/net.0/Pow_output_0', '/layers.9.0/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.9.0/layers.0.1/net/net.0/Pow_1_output_0', '/layers.9.0/layers.0.1/net/net.0/Clip_output_0', '/layers.9.0/layers.0.1/net/net.0/Expand_output_0', '/layers.9.1/layers.0.0/norm/Pow_output_0', '/layers.9.1/layers.0.0/norm/ReduceSum_output_0', '/layers.9.1/layers.0.0/norm/Pow_1_output_0', '/layers.9.1/layers.0.0/norm/Clip_output_0', '/layers.9.1/layers.0.0/norm/Expand_output_0', '/layers.9.1/layers.0.1/net/net.0/Pow_output_0', '/layers.9.1/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.9.1/layers.0.1/net/net.0/Pow_1_output_0', '/layers.9.1/layers.0.1/net/net.0/Clip_output_0', '/layers.9.1/layers.0.1/net/net.0/Expand_output_0', '/layers.10.0/layers.0.0/norm/Pow_output_0', '/layers.10.0/layers.0.0/norm/ReduceSum_output_0', '/layers.10.0/layers.0.0/norm/Pow_1_output_0', '/layers.10.0/layers.0.0/norm/Clip_output_0', '/layers.10.0/layers.0.0/norm/Expand_output_0', '/layers.10.0/layers.0.1/net/net.0/Pow_output_0', '/layers.10.0/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.10.0/layers.0.1/net/net.0/Pow_1_output_0', '/layers.10.0/layers.0.1/net/net.0/Clip_output_0', '/layers.10.0/layers.0.1/net/net.0/Expand_output_0', '/layers.10.1/layers.0.0/norm/Pow_output_0', 
'/layers.10.1/layers.0.0/norm/ReduceSum_output_0', '/layers.10.1/layers.0.0/norm/Pow_1_output_0', '/layers.10.1/layers.0.0/norm/Clip_output_0', '/layers.10.1/layers.0.0/norm/Expand_output_0', '/layers.10.1/layers.0.1/net/net.0/Pow_output_0', '/layers.10.1/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.10.1/layers.0.1/net/net.0/Pow_1_output_0', '/layers.10.1/layers.0.1/net/net.0/Clip_output_0', '/layers.10.1/layers.0.1/net/net.0/Expand_output_0', '/layers.11.0/layers.0.0/norm/Pow_output_0', '/layers.11.0/layers.0.0/norm/ReduceSum_output_0', '/layers.11.0/layers.0.0/norm/Pow_1_output_0', '/layers.11.0/layers.0.0/norm/Clip_output_0', '/layers.11.0/layers.0.0/norm/Expand_output_0', '/layers.11.0/layers.0.1/net/net.0/Pow_output_0', '/layers.11.0/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.11.0/layers.0.1/net/net.0/Pow_1_output_0', '/layers.11.0/layers.0.1/net/net.0/Clip_output_0', '/layers.11.0/layers.0.1/net/net.0/Expand_output_0', '/layers.11.1/layers.0.0/norm/Pow_output_0', '/layers.11.1/layers.0.0/norm/ReduceSum_output_0', '/layers.11.1/layers.0.0/norm/Pow_1_output_0', '/layers.11.1/layers.0.0/norm/Clip_output_0', '/layers.11.1/layers.0.0/norm/Expand_output_0', '/layers.11.1/layers.0.1/net/net.0/Pow_output_0', '/layers.11.1/layers.0.1/net/net.0/ReduceSum_output_0', '/layers.11.1/layers.0.1/net/net.0/Pow_1_output_0', '/layers.11.1/layers.0.1/net/net.0/Clip_output_0', '/layers.11.1/layers.0.1/net/net.0/Expand_output_0', '/final_norm/Pow_output_0', '/final_norm/ReduceSum_output_0', '/final_norm/Pow_1_output_0', '/final_norm/Clip_output_0', '/final_norm/Expand_output_0']

jinhonglu commented 1 month ago

When I look at the CUDA txt file, I notice that the first Inf occurs at the output named '/band_split/to_features.25/to_features.25.0/ReduceSum_output_0'.

But this node should not have been converted to fp16, as it is in the 'keep_io_types' list.

Is this caused by 'max_finite_val'?

tianleiwu commented 1 month ago

@jinhonglu, the keep_io_types list is only for graph inputs and outputs. You can use the other two parameters, op_block_list (a list of operator types like ["ReduceSum"]) or node_block_list (a list of node names): https://github.com/microsoft/onnxruntime/blob/281ed8c12d2d2a3f5b683e6267aa0fca4d4add50/onnxruntime/python/tools/transformers/float16.py#L190-L192
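
For illustration, a minimal sketch combining both parameters; the node names are hypothetical examples derived from the output names listed above:

model_fp16 = float16.convert_float_to_float16(
    model,
    keep_io_types=True,
    # keep every node of these operator types in fp32
    op_block_list=float16.DEFAULT_OP_BLOCK_LIST + ["ReduceSum", "Pow"],
    # additionally pin individual nodes to fp32 by node name
    node_block_list=["/final_norm/ReduceSum", "/final_norm/Pow"],
    disable_shape_infer=False,
)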

jinhonglu commented 1 month ago

@tianleiwu

Currently the fp16 model suffers from back-to-back Cast operations (ReduceSumOp(fp32)->Output(fp32)->Cast(fp16)->Cast(fp32)->PowOp(fp32)).

What I want is (ReduceSumOp(fp32)->Output(fp32)->PowOp(fp32)).

I have noticed #8787 and #17953, and the cast remover does not take effect when all optimizers are disabled.

model_fp16 = float16.convert_float_to_float16(
    model,
    min_positive_val=5.96e-08,
    max_finite_val=65504.0,
    keep_io_types=iolist,
    op_block_list=float16.DEFAULT_OP_BLOCK_LIST,
    node_block_list=['/band_split/to_features.25/to_features.25.0/ReduceSum', '/band_split/to_features.25/to_features.25.0/Pow_1'],
    disable_shape_infer=False,
)

opts = onnxruntime.SessionOptions()
opts.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
ort_session = onnxruntime.InferenceSession('fp16.onnx', sess_options=opts, providers=["CUDAExecutionProvider"])

Is there any way to disable these back-to-back Casts? I have tried adding the cast node names to the node_block_list, but it has no effect.

This seems to be resolved by https://github.com/microsoft/onnxconverter-common/pull/286

tianleiwu commented 1 month ago

@jinhonglu, thanks for identifying the root cause. You can also follow that PR to work around it.

Let me add the same post-processing to float16.convert_float_to_float16 for the next release, 1.19.
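
For reference, a minimal sketch of that kind of post-processing (an illustrative work-around, not the exact logic of the onnxconverter-common PR): it removes fp32->Cast(fp16)->Cast(fp32) pairs and rewires consumers of the second Cast back to the original fp32 tensor. It does not handle the case where the removed Cast output is itself a graph output.

import onnx
from onnx import TensorProto

def remove_back_to_back_casts(model: onnx.ModelProto) -> onnx.ModelProto:
    graph = model.graph
    producers = {out: n for n in graph.node for out in n.output}
    removable = []
    for node in graph.node:
        if node.op_type != "Cast":
            continue
        to_fp32 = next(a for a in node.attribute if a.name == "to").i == TensorProto.FLOAT
        parent = producers.get(node.input[0])
        if not to_fp32 or parent is None or parent.op_type != "Cast":
            continue
        if next(a for a in parent.attribute if a.name == "to").i != TensorProto.FLOAT16:
            continue
        # fp32 tensor -> Cast(fp16) -> Cast(fp32): point consumers of the second
        # Cast back to the tensor feeding the first Cast, then drop the second Cast.
        # The first Cast stays; if nothing else reads it, it is dead but harmless.
        original = parent.input[0]
        for consumer in graph.node:
            for i, name in enumerate(consumer.input):
                if name == node.output[0]:
                    consumer.input[i] = original
        removable.append(node)
    for node in removable:
        graph.node.remove(node)
    return model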

jinhonglu commented 1 month ago

@tianleiwu thanks for your help.

Furthermore, do you have any experience converting a mixed-precision onnx model to a TensorRT engine?

I tried the TensorRT provider in onnxruntime with 'trt_fp16_enable' enabled to run the model above; it seems that the engine builder forces all the nodes to fp16 and is incompatible with a mixed-precision model.

tianleiwu commented 1 month ago

@jinhonglu, you can use the fp32 onnx model to run the TRT EP; you only need to set the trt_fp16_enable flag in the TRT provider options. https://github.com/microsoft/onnxruntime/blob/a6c5e2cd20dd890f416806e0afbb3b5968030f4d/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py#L64

For optimization, the onnx model only needs constant folding and shape inference, since most optimizations are done during engine building inside TRT. Do not overdo it. Example code in the demo: https://github.com/microsoft/onnxruntime/blob/a6c5e2cd20dd890f416806e0afbb3b5968030f4d/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py#L451-L455 https://github.com/microsoft/onnxruntime/blob/a6c5e2cd20dd890f416806e0afbb3b5968030f4d/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py#L45
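
For illustration, one light-weight way to do that offline (a minimal sketch; the stable-diffusion demo linked above may use a different recipe): let onnxruntime's basic optimization level do the constant folding and save the result, then run ONNX shape inference. File names are placeholders.

import onnx
import onnxruntime

so = onnxruntime.SessionOptions()
# basic (Level 1) optimizations include constant folding and keep standard ONNX ops only
so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
so.optimized_model_filepath = "fp32_folded.onnx"
# creating the session writes the optimized graph to optimized_model_filepath
onnxruntime.InferenceSession("fp32.onnx", so, providers=["CPUExecutionProvider"])

onnx.save(onnx.shape_inference.infer_shapes(onnx.load("fp32_folded.onnx")), "fp32_optimized.onnx")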

jinhonglu commented 1 month ago

@tianleiwu

I followed your instructions, but the result from the engine is not correct.

The following is my setup:

        optimized_path = model_path.split('.')[0] + '_optimized.onnx'
        if not os.path.exists(optimized_path):
            self.optimize_trt(model_path, optimized_path)
        session_options = onnxruntime.SessionOptions()
        session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
        providers = [
            ('TensorrtExecutionProvider', {
                'device_id': 0,  # Select GPU to execute
                'trt_fp16_enable': True,  # Enable FP16 precision for faster inference
                'trt_engine_cache_enable': True,
                'trt_dump_ep_context_model': True,
                'trt_ep_context_file_path': './model/checkpoint/onnx/',
                # 'trt_ep_context_embed_mode': False,
                'trt_layer_norm_fp32_fallback': True
            })
        ]
        ort_session = onnxruntime.InferenceSession(optimized_path, session_options, providers=providers)

tianleiwu commented 1 month ago

@jinhonglu, please look at https://github.com/NVIDIA/TensorRT/issues/2691 for other options, like a plugin or adding constraints using the Polygraphy CLI, to resolve the TensorRT fp16 precision issue.

For example, when you create a custom TRT engine using the second approach, you can embed the engine into an ONNX file: https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#tensorrt-ep-caches

Here is a python script to embed an externally (trtexec) compiled engine file: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py#L156-L187

jinhonglu commented 1 month ago

@tianleiwu

I have followed https://github.com/NVIDIA/TensorRT/tree/main/tools/Polygraphy/examples/cli/run/08_adding_precision_constraints

image

And I have passed the comparison test between the onnx fp32 output and the fp16 engine with my own constraints using Polygraphy.

image

Below is the postprocess constraint I created.

import tensorrt as trt

def postprocess(network):
    """Traverse the parsed TensorRT network and constrain specific layers to FP32.

    See examples/cli/run/04_defining_a_tensorrt_network_or_config_manually
    for more examples using network scripts in Polygraphy.
    """
    for i in range(network.num_layers):
        layer = network.get_layer(i)
        # Force Pow and ReduceSum layers (and the layer that follows each of them) to FP32
        if 'Pow' in layer.name or 'ReduceSum' in layer.name:
            next_layer = network.get_layer(i + 1)

            layer.precision = trt.float32
            next_layer.precision = trt.float32
            layer.set_output_type(0, trt.float32)

The ops I keep in float32 are based on the onnx analysis above (or should I do another analysis on the TensorRT side instead?).

After that, I converted and got the engine using polygraphy

polygraphy convert fp32.onnx --precision-constraints obey --fp16 --trt-network-postprocess-script=build_constraint.py -o fp16.engine

However, the inference result is still far from the onnx model's.