microsoft / onnxruntime


AssertionError on num_heads > 0 for bert with specific optimization config #17254


trajepl commented 1 year ago

Describe the issue

For the following config, optimization fails with an assertion error on `num_heads > 0`:

```json
"transformer_optimization": {
    "model_type": "bert",
    "opt_level": 0,
    "use_gpu": false,
    "only_onnxruntime": false,
    "float16": false
}
```

But `num_heads`/`hidden_size` should be inferred automatically.
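For context, the automatic inference reads the constant target shape of the Reshape node that splits the Q projection into heads. The helper below is a hypothetical, simplified sketch of that logic (the real implementation is `FusionAttention.get_num_heads_and_hidden_size` in `fusion_attention.py`); when the shape is not a concrete 4-D constant it falls back to `(0, 0)`, which is what later trips `assert num_heads > 0`:

```python
def infer_num_heads_and_hidden_size(reshape_target_shape):
    """Infer (num_heads, hidden_size) from the Reshape target shape that
    splits the Q projection into heads, e.g. [0, 0, 12, 64].

    Returns (0, 0) when the shape is not a concrete 4-D list of positive
    integers, mirroring the fallback that triggers the assertion.
    """
    if len(reshape_target_shape) != 4:
        return 0, 0
    num_heads, head_size = reshape_target_shape[2], reshape_target_shape[3]
    if isinstance(num_heads, int) and isinstance(head_size, int) \
            and num_heads > 0 and head_size > 0:
        return num_heads, num_heads * head_size
    return 0, 0


# A bert-base-style shape resolves cleanly...
print(infer_num_heads_and_hidden_size([0, 0, 12, 64]))   # → (12, 768)
# ...but a symbolic/-1 dimension cannot be inferred.
print(infer_num_heads_and_hidden_size([0, 0, -1, 64]))   # → (0, 0)
```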

Log trace:

```
/home/jiapli/venv/lib/python3.8/site-packages/onnxruntime/transformers/optimizer.py:294 in optimize_model

    291   if only_onnxruntime:
    292       optimizer = optimizer_class(model, num_heads, hidden_size)
    293   else:
  ❱ 294       optimizer = optimize_by_fusion(model, model_type, num_heads, hidden_size, optimi
    295
    296   # Remove the temporary model.
    297   if temp_model_path:

/home/jiapli/venv/lib/python3.8/site-packages/onnxruntime/transformers/optimizer.py:180 in optimize_by_fusion

    177
    178   optimizer = optimizer_class(model, num_heads, hidden_size)
    179
  ❱ 180   optimizer.optimize(optimization_options)
    181
    182   optimizer.topological_sort()
    183

/home/jiapli/venv/lib/python3.8/site-packages/onnxruntime/transformers/models/gpt2/../../onnx_model_bert.py:397 in optimize

    394           )
    395
    396       if (options is None) or options.enable_attention:
  ❱ 397           self.fuse_attention()
    398
    399       # Perform the MatMul fusion after the Attention fusion as we do not
    400       # want to fuse the MatMuls inside the Attention subgraphs

/home/jiapli/venv/lib/python3.8/site-packages/onnxruntime/transformers/models/gpt2/../../onnx_model_bert.py:65 in fuse_attention

    62        self.utils = FusionUtils(self)
    63
    64    def fuse_attention(self):
  ❱ 65        self.attention_fusion.apply()
    66        # Only relevant in models with Q-DQ nodes
    67        self.qordered_attention_fusion.apply()
    68

/home/jiapli/venv/lib/python3.8/site-packages/onnxruntime/transformers/models/gpt2/../../fusion_base.py:49 in apply

    46                if graph is None:
    47                    raise Exception("Can not find node in any graphs")
    48                self.this_graph_name = graph.name
  ❱ 49                self.fuse(node, input_name_to_nodes, output_name_to_node)
    50
    51        op_list = [node.op_type for node in self.nodes_to_add]
    52        if self.fused_count:

/home/jiapli/venv/lib/python3.8/site-packages/onnxruntime/transformers/models/gpt2/../../fusion_attention.py:1120 in fuse

    1117          q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
    1118          # number of heads are same for all the paths, hence to create attention node
    1119          # the input_hidden_size represents the input hidden size, this is used as ne
  ❱ 1120          new_node = self.create_attention_node(
    1121              mask_index,
    1122              matmul_q,
    1123              matmul_k,

/home/jiapli/venv/lib/python3.8/site-packages/onnxruntime/transformers/models/gpt2/../../fusion_attention.py:695 in create_attention_node

    692       Returns:
    693           Union[NodeProto, None]: the node created or None if failed.
    694       """
  ❱ 695       assert num_heads > 0
    696
    697       if hidden_size > 0 and (hidden_size % num_heads) != 0:
    698           logger.debug(f"input hidden size {hidden_size} is not a multiple of num of h
```

To reproduce

Model: converted from Intel/bert-base-uncased-mrpc. Run transformer optimization with the following arguments:

```json
"transformer_optimization": {
    "model_type": "bert",
    "opt_level": 0,
    "use_gpu": false,
    "only_onnxruntime": false,
    "float16": false
}
```

Urgency

Not urgent: it does not block our project, since we can always pass `num_heads` and `hidden_size` explicitly to avoid the error above.
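For anyone hitting the same assertion, a sketch of that workaround config, assuming the optimization pass forwards `num_heads`/`hidden_size` down to `optimizer.optimize_model` (the values below are for bert-base-uncased, which has 12 heads and hidden size 768):

```python
# Workaround sketch: supply num_heads / hidden_size explicitly so attention
# fusion never has to infer them from the graph. Values assume
# bert-base-uncased (12 attention heads, hidden size 768 = 12 * 64).
transformer_optimization = {
    "model_type": "bert",
    "opt_level": 0,
    "use_gpu": False,
    "only_onnxruntime": False,
    "float16": False,
    "num_heads": 12,
    "hidden_size": 768,
}

# Sanity check mirroring the fusion's own constraints: num_heads must be
# positive and hidden_size must be a multiple of num_heads.
assert transformer_optimization["num_heads"] > 0
assert transformer_optimization["hidden_size"] % transformer_optimization["num_heads"] == 0
```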

Platform

Linux

OS Version

Ubuntu 20.04.5 LTS

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.15.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

tianleiwu commented 1 year ago

@trajepl, thanks for raising the issue. Let me send a fix soon.