Did you ever find a solution to this issue? I'm encountering the same problem.
Here's some code that I've used in my project; hope it helps:
import onnx
import onnx.helper
import onnx.numpy_helper
import numpy as np

def get_model_input(model, name):
    # find a graph input (value info) by exact name
    for node in model.graph.input:
        if node.name == name:
            return node
    return None

def get_model_initializer(model, name, to_tensor=True):
    # find an initializer by exact name, optionally converted to a numpy array
    for node in model.graph.initializer:
        if node.name == name:
            return onnx.numpy_helper.to_array(node) if to_tensor else node
    return None

def convert_matmul_to_gemm(model):
    # iterate over nodes in the graph
    for index, node in enumerate(model.graph.node):
        if node.op_type == "MatMul":
            # transpose the weights so the node can use transB=1
            init = get_model_initializer(model, node.input[1], to_tensor=False)
            if init is None:
                # weights are not a stored initializer, so skip this node
                continue
            init_index = list(model.graph.initializer).index(init)
            weights = onnx.numpy_helper.to_array(init)
            weights = np.swapaxes(weights, 0, 1)
            new_init = onnx.helper.make_tensor(
                name=node.input[1],
                data_type=init.data_type,
                dims=weights.shape,
                vals=weights.flatten().tolist())
            # swap in the transposed initializer at the same position
            model.graph.initializer.remove(init)
            model.graph.initializer.insert(init_index, new_init)
            # update the weight's value info, if it is also listed as a graph input
            init_value_info = get_model_input(model, node.input[1])
            if init_value_info is not None:
                init_value_info_index = list(model.graph.input).index(init_value_info)
                new_init_value_info = onnx.helper.make_tensor_value_info(
                    node.input[1],
                    onnx.TensorProto.FLOAT,
                    weights.shape)
                model.graph.input.remove(init_value_info)
                model.graph.input.insert(init_value_info_index, new_init_value_info)
            # add an all-zero bias term for the Gemm C input
            new_bias = onnx.helper.make_tensor(
                name="_".join([node.input[1], "bias"]),
                data_type=init.data_type,
                dims=(weights.shape[0],),
                vals=np.zeros(weights.shape[0]).flatten().tolist())
            new_bias_value_info = onnx.helper.make_tensor_value_info(
                new_bias.name,
                onnx.TensorProto.FLOAT,
                [weights.shape[0]])
            # add the bias to the graph
            model.graph.initializer.append(new_bias)
            model.graph.input.append(new_bias_value_info)
            # create the replacement Gemm node
            new_node = onnx.helper.make_node(
                "Gemm",
                name=node.name,
                inputs=[*node.input, new_bias.name],
                outputs=node.output,
                alpha=1.0,
                beta=1.0,
                transA=0,
                transB=1)
            # remove the old node and insert the new one in its place
            # (the list length stays the same, so iteration continues safely)
            model.graph.node.remove(node)
            model.graph.node.insert(index, new_node)
    # return the modified model
    return model
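A minimal way to apply it would be something like the following (the file names are just placeholders, not from the original post):

import onnx

# assumes convert_matmul_to_gemm from the snippet above is in scope
model = onnx.load("model.onnx")         # placeholder input path
model = convert_matmul_to_gemm(model)
onnx.checker.check_model(model)         # sanity-check the rewritten graph
onnx.save(model, "model_gemm.onnx")     # placeholder output path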
The pattern of fuse_matmul_add_bias_into_gemm is:
// Z = MatMul(X, Y)
// A = Z + Bias
// After:
// A = Gemm(X, Y, Bias)
The shape of X is (M, K), the shape of Y is (K, N), and the rank of Bias can be 1 or 2.
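As an illustration of that pattern (the graph, tensor names, and shapes below are made up for this example, not taken from the issue), a MatMul followed by an Add with an initializer bias is what the pass rewrites into a single Gemm:

import numpy as np
import onnx
import onnx.helper as helper
import onnx.numpy_helper
import onnxoptimizer as optimizer

M, K, N = 4, 8, 16
X = helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [M, K])
A = helper.make_tensor_value_info("A", onnx.TensorProto.FLOAT, [M, N])
Y = onnx.numpy_helper.from_array(np.random.rand(K, N).astype(np.float32), name="Y")
bias = onnx.numpy_helper.from_array(np.zeros(N, dtype=np.float32), name="Bias")

# MatMul output feeds an Add with a bias initializer
graph = helper.make_graph(
    [helper.make_node("MatMul", ["X", "Y"], ["Z"]),
     helper.make_node("Add", ["Z", "Bias"], ["A"])],
    "matmul_add", [X], [A], initializer=[Y, bias])
model = helper.make_model(graph)

fused = optimizer.optimize(model, ["fuse_matmul_add_bias_into_gemm"])
print([n.op_type for n in fused.graph.node])   # expect: ['Gemm']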
I'm coming across an issue with the fuse_matmul_add_bias_into_gemm pass when there is just a MatMul node and no Add or bias node after it. I've written a quick example of what the issue is below. What I am expecting (correct me if I'm wrong) is for the MatMul node to be converted to a Gemm node instead.
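A stripped-down sketch of that situation (again with made-up names and shapes, not the original reproduction) is a graph containing only a MatMul; running the pass over it leaves the MatMul untouched, since there is no Add to fuse:

import numpy as np
import onnx
import onnx.helper as helper
import onnx.numpy_helper
import onnxoptimizer as optimizer

X = helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [4, 8])
Z = helper.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, [4, 16])
Y = onnx.numpy_helper.from_array(np.random.rand(8, 16).astype(np.float32), name="Y")

# a single MatMul with no following Add
graph = helper.make_graph(
    [helper.make_node("MatMul", ["X", "Y"], ["Z"])],
    "matmul_only", [X], [Z], initializer=[Y])
model = helper.make_model(graph)

out = optimizer.optimize(model, ["fuse_matmul_add_bias_into_gemm"])
print([n.op_type for n in out.graph.node])   # still ['MatMul'], not ['Gemm']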