microsoft / onnxruntime

Cast f32 -> bf16 -> f32 does not work as expected for graph inputs #9915

Open max-ku opened 2 years ago

max-ku commented 2 years ago

We have a sequence of back-to-back Cast operators casting from float to bfloat16 and then back to float; we expect the values to be truncated or rounded to bfloat16 precision. However, ORT does that only for graph initializers, not for graph inputs, which pass through unchanged (neither truncated nor rounded) after the f32 -> bf16 -> f32 cast sequence.
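For reference: bfloat16 keeps only the top 16 bits of the float32 bit pattern, so the round trip should visibly change a value like 1/3. A minimal numpy sketch of a truncation-based reference round trip (ORT's Cast may round to nearest even instead, but either way the low mantissa bits cannot survive):

import numpy as np

def f32_to_bf16_to_f32(x: np.ndarray) -> np.ndarray:
    # bfloat16 is the high 16 bits of the float32 bit pattern,
    # so a truncating round trip just clears the low 16 bits.
    bits = x.astype(np.float32).view(np.uint32)
    return (bits & 0xFFFF0000).view(np.float32)

x = np.array([0.333333343267440796], dtype=np.float32)
print(x)                      # 0.33333334
print(f32_to_bf16_to_f32(x))  # 0.33203125 -- precision visibly lost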

System information

Reproduction instructions

import sys
import pathlib

import onnxruntime
import numpy as np

# Disable all graph optimizations so the Cast nodes should survive untouched.
sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL

# Model path is passed as the first command-line argument.
session = onnxruntime.InferenceSession(sys.argv[1], sess_options)

session.get_modelmeta()

# A float32 value that is not exactly representable in bfloat16.
model_inputs = {}
model_inputs['cast0_input'] = np.array([0.333333343267440796], dtype=np.float32)

results = session.run([], model_inputs)

# Dump each output tensor's raw bytes for inspection.
for i, output in enumerate(session.get_outputs()):
    pathlib.Path("output.bin." + str(i)).write_bytes(results[i].tobytes())

repro.zip
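The model in repro.zip is not reproduced here; a graph of the same shape could be built along these lines with onnx.helper (the input name cast0_input comes from the script above, all other names are assumptions):

import onnx
from onnx import TensorProto, helper

# Cast(f32 -> bf16) feeding Cast(bf16 -> f32), with a single graph input.
cast0 = helper.make_node("Cast", ["cast0_input"], ["bf16_val"], to=TensorProto.BFLOAT16)
cast1 = helper.make_node("Cast", ["bf16_val"], ["cast1_output"], to=TensorProto.FLOAT)

graph = helper.make_graph(
    [cast0, cast1],
    "cast_roundtrip",
    inputs=[helper.make_tensor_value_info("cast0_input", TensorProto.FLOAT, [1])],
    outputs=[helper.make_tensor_value_info("cast1_output", TensorProto.FLOAT, [1])],
)
# bfloat16 support for Cast requires opset 13 or later.
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
onnx.checker.check_model(model)
onnx.save(model, "cast_roundtrip.onnx")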

Expected behavior

We expect graph input values to be truncated or rounded to bfloat16 precision; however, this does not happen. It works only for graph initializers.

Workaround

If an Identity node is inserted between the Cast nodes, the Cast ops work as expected.
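For example, the Identity node can be spliced in with a small graph edit (a sketch assuming the tensor names from the model sketch above):

import onnx
from onnx import helper

model = onnx.load("cast_roundtrip.onnx")
graph = model.graph

# Rename the first Cast's output and bridge it to the second Cast with an
# Identity, so there is no longer a back-to-back Cast -> Cast pattern.
for node in graph.node:
    if node.op_type == "Cast" and node.output[0] == "bf16_val":
        node.output[0] = "bf16_val_raw"
        break
identity = helper.make_node("Identity", ["bf16_val_raw"], ["bf16_val"])
graph.node.insert(1, identity)  # keep the node list topologically sorted

onnx.save(model, "cast_roundtrip_identity.onnx")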

askhade commented 2 years ago

The cast transformer is optimizing these casts out of your model. ORT has a required cast transformer that runs for every model, and

sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL

does not affect this transformer. You will need to use the workaround for this scenario. Are you seeing a perf degradation from adding the Identity node?

max-ku commented 2 years ago

@askhade Thank you for the clarification.

askhade commented 2 years ago

You can find the documentation here: https://onnxruntime.ai/docs/performance/graph-optimizations.html#graph-optimizations-in-onnx-runtime

The initializers don't get affected because they get constant folded. The cast transformer is used to insert Cast nodes when a node/op cannot be assigned to any execution provider because the execution provider does not support the data type. The most typical use case is inserting a cast from fp16 to fp32 because fp16 is not supported for a certain operator. The cast transformer then runs a second iteration through the graph to remove redundant casts. Your Cast nodes are getting removed because they form a redundant fp32 -> bf16 -> fp32 round trip. I will see whether we can support your case, but in the meantime the workaround would be best.
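One way to observe this is to dump the graph ORT actually runs via SessionOptions.optimized_model_filepath and check which nodes survive, for example:

import onnx
import onnxruntime

sess_options = onnxruntime.SessionOptions()
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
# The dumped model reflects every transformer that ran, including the
# required cast transformer that ORT_DISABLE_ALL does not turn off.
sess_options.optimized_model_filepath = "after_ort.onnx"
onnxruntime.InferenceSession("cast_roundtrip.onnx", sess_options)

optimized = onnx.load("after_ort.onnx")
print([node.op_type for node in optimized.graph.node])  # expect the Cast pair to be gone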

tianleiwu commented 2 years ago

It is related to https://github.com/microsoft/onnxruntime/issues/8787, which seems to be a bug in the optimizer. Also, it cannot be disabled by the graph optimization level.

stale[bot] commented 2 years ago

This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details.