onnx / optimizer

Actively maintained ONNX Optimizer
Apache License 2.0

fuse_matmul_add_bias_into_gemm: not working for just MatMul #58

Closed AlexMontgomerie closed 1 year ago

AlexMontgomerie commented 3 years ago

I'm coming across an issue with fuse_matmul_add_bias_into_gemm when there is just a MatMul node, with no Add or bias node after it. I've written a quick example of the issue below. What I'm expecting (correct me if I'm wrong) is for the MatMul node to be converted to a Gemm node instead.

from onnx import helper, TensorProto
import onnxoptimizer

# note: the graph deliberately contains only the MatMul, with no Add node after it
matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
graph = helper.make_graph(
    [matmul],
    "test",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, (32, 10)),
     helper.make_tensor_value_info("Y", TensorProto.FLOAT, (10, 16))],
    [helper.make_tensor_value_info("Z", TensorProto.FLOAT, (32, 16))]
)

model = helper.make_model(graph)
optimized_model = onnxoptimizer.optimize(
    model, passes=["fuse_matmul_add_bias_into_gemm"])

print(optimized_model)

assert len(list(optimized_model.graph.node)) == 1
assert optimized_model.graph.node[0].op_type == "Gemm"
B-Willems commented 2 years ago

Did you ever find a solution to this issue? I'm encountering the same problem.

AlexMontgomerie commented 2 years ago

Here's some code that I've used in my project; hope it helps:

import onnx
import onnx.helper
import onnx.numpy_helper
import numpy as np

def get_model_input(model, name):
    for node in model.graph.input:
        if node.name == name: # exact match
            return node

def get_model_initializer(model, name, to_tensor=True):
    for node in model.graph.initializer:
        if node.name == name: # exact match
            if to_tensor:
                return onnx.numpy_helper.to_array(node)
            else:
                return node

def convert_matmul_to_gemm(model):
    # iterate over a snapshot of the nodes so the graph can be edited in place
    for index, node in enumerate(list(model.graph.node)):
        if node.op_type == "MatMul":
            # transpose the stored weights so that Gemm (with transB=1) matches the MatMul
            init = get_model_initializer(model, node.input[1], to_tensor=False)
            if init is None:
                # the second input is not a stored initializer, so skip this MatMul
                continue
            init_index = list(model.graph.initializer).index(init)
            weights = onnx.numpy_helper.to_array(init)
            weights = np.swapaxes(weights,0,1)
            new_init = onnx.helper.make_tensor(
                name=node.input[1],
                data_type=init.data_type,
                dims=weights.shape,
                vals=weights.flatten().tolist())
            # replace the initializer with the transposed weights
            model.graph.initializer.remove(init)
            model.graph.initializer.insert(init_index, new_init)
            # update the weight's value info, if it is listed among the graph inputs
            init_value_info = get_model_input(model, node.input[1])
            if init_value_info is not None:
                init_value_info_index = list(model.graph.input).index(init_value_info)
                new_init_value_info = onnx.helper.make_tensor_value_info(
                        node.input[1],
                        onnx.TensorProto.FLOAT,
                        weights.shape)
                model.graph.input.remove(init_value_info)
                model.graph.input.insert(init_value_info_index, new_init_value_info)
            # add an empty bias term
            new_bias = onnx.helper.make_tensor(
                name="_".join([node.input[1],"bias"]),
                data_type=init.data_type,
                dims=(weights.shape[0],),
                vals=np.zeros(weights.shape[0]).flatten().tolist())
            new_bias_value_info = onnx.helper.make_tensor_value_info(
                    new_bias.name,
                    onnx.TensorProto.FLOAT,
                    [weights.shape[0]])
            # register the new bias initializer and its value info
            model.graph.initializer.append(new_bias)
            model.graph.input.append(new_bias_value_info)
            # create a Gemm node to replace the MatMul (transB=1 because the weights were transposed above)
            new_node = onnx.helper.make_node(
                "Gemm",
                name=node.name,
                inputs=[*node.input, "_".join([node.input[1],"bias"])],
                outputs=node.output,
                alpha=1.0,
                beta=1.0,
                transA=0,
                transB=1
            )
            # remove old node and add new one
            model.graph.node.remove(node)
            model.graph.node.insert(index, new_node)
    # return the new model
    return model
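
For reference, a rough usage sketch against the model from the original report (this assumes the MatMul weights "Y" are stored as an initializer and also listed among the graph inputs, which is what the helpers above look for):

# build the same MatMul-only graph, but with Y stored as an initializer
weights = np.random.rand(10, 16).astype(np.float32)
graph = onnx.helper.make_graph(
    [onnx.helper.make_node("MatMul", ["X", "Y"], ["Z"])],
    "test",
    [onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, (32, 10)),
     onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, (10, 16))],
    [onnx.helper.make_tensor_value_info("Z", onnx.TensorProto.FLOAT, (32, 16))],
    initializer=[onnx.numpy_helper.from_array(weights, name="Y")])

model = onnx.helper.make_model(graph)
model = convert_matmul_to_gemm(model)
print(model.graph.node[0].op_type)  # expect "Gemm"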
HSQ79815 commented 2 years ago

The pattern that fuse_matmul_add_bias_into_gemm matches is:

// Before:
//   Z = MatMul(X, Y)
//   A = Z + Bias
// After:
//   A = Gemm(X, Y, Bias)

The shape of X is (M, K), the shape of Y is (K, N), and the rank of Bias can be 1 or 2. A lone MatMul with no trailing Add does not match this pattern, so the pass leaves it unchanged.

https://github.com/onnx/optimizer/blob/master/onnxoptimizer/passes/fuse_matmul_add_bias_into_gemm.h#L11
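
For comparison, here's a minimal sketch of a graph that does match the pattern (a MatMul followed by an Add on its output), using the same imports as the example in the first comment; running the pass on it should leave a single Gemm node:

matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
add = helper.make_node("Add", ["Z", "B"], ["A"])
graph = helper.make_graph(
    [matmul, add],
    "test_fuse",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, (32, 10)),
     helper.make_tensor_value_info("Y", TensorProto.FLOAT, (10, 16)),
     helper.make_tensor_value_info("B", TensorProto.FLOAT, (16,))],
    [helper.make_tensor_value_info("A", TensorProto.FLOAT, (32, 16))]
)

model = helper.make_model(graph)
optimized_model = onnxoptimizer.optimize(
    model, passes=["fuse_matmul_add_bias_into_gemm"])
print([n.op_type for n in optimized_model.graph.node])  # expect ["Gemm"]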