onnx / optimizer

Actively maintained ONNX Optimizer
Apache License 2.0

fuse_matmul_add_bias_into_gemm not working with batch size #66

Closed: erelon closed this issue 1 year ago

erelon commented 2 years ago

Hi,

When using fuse_matmul_add_bias_into_gemm, I expect the MatMul and Add nodes to be fused even when the inputs have a batch dimension. Apparently this is not supported, and I can't see the reason for it. If batch sizes greater than one are a problem, the fusion could at least happen when the batch dimension is 1.
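A minimal NumPy sketch of why the batch-1 case is reducible in principle: with a leading dimension of 1, dropping that dimension turns the rank-3 MatMul + Add into an ordinary 2-D matrix product plus bias, which is what Gemm computes. The shapes mirror the reproduction code. This only checks the arithmetic and says nothing about what the optimizer pass actually does.

```python
import numpy as np

# Shapes mirror the reproduction: X is (1, 32, 10), weights Y are (1, 10, 16), bias B is (16,).
rng = np.random.default_rng(0)
X = rng.standard_normal((1, 32, 10)).astype(np.float32)
Y = rng.standard_normal((1, 10, 16)).astype(np.float32)
B = rng.standard_normal(16).astype(np.float32)

# What the unfused graph computes: batched MatMul followed by a broadcasting Add.
unfused = np.matmul(X, Y) + B          # shape (1, 32, 16)

# Drop the size-1 batch dimension, do a plain 2-D product plus bias
# (the computation Gemm performs), then restore the batch dimension.
fused_2d = X[0] @ Y[0] + B             # shape (32, 16)
fused = fused_2d[np.newaxis, ...]      # shape (1, 32, 16)

print(unfused.shape, fused.shape)
print(np.allclose(unfused, fused, atol=1e-5))
```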

Here is example code that reproduces the issue (heavily based on #58):

    import numpy as np
    import onnx
    from onnx import TensorProto, helper
    import onnxoptimizer

    # Z = X @ Y (batched MatMul), then A = Z + B (bias Add)
    matmul = helper.make_node("MatMul", ["X", "Y"], ["Z"])
    add = helper.make_node("Add", ["Z", "B"], ["A"])
    graph = helper.make_graph(
        [matmul, add],
        "test",
        # graph input with a leading batch dimension of 1
        [helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 32, 10))],
        [helper.make_tensor_value_info("A", TensorProto.FLOAT, (1, 32, 16))],
        # initializers: bias B of shape (16,), weights Y of shape (1, 10, 16)
        [helper.make_tensor("B", TensorProto.FLOAT, (16,), np.ones(16)),
         helper.make_tensor("Y", TensorProto.FLOAT, (1, 10, 16), np.ones((1, 10, 16)))],
    )

    model = helper.make_model(graph)
    onnx.save(model, "gg.onnx")
    optimized_model = onnxoptimizer.optimize(
        model, passes=["fuse_matmul_add_bias_into_gemm"])
    onnx.save(optimized_model, "gg1.onnx")

    print(optimized_model)

    # Expected: MatMul + Add fused into a single Gemm node (these asserts currently fail)
    assert len(optimized_model.graph.node) == 1
    assert optimized_model.graph.node[0].op_type == "Gemm"
HSQ79815 commented 2 years ago

@erelon When using fuse_matmul_add_bias_into_gemm, both MatMul input tensors must have rank 2. You can find this check in the source: https://github.com/onnx/optimizer/blob/master/onnxoptimizer/passes/fuse_matmul_add_bias_into_gemm.h#L60
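The rank condition described above can be restated as a small pure-Python check. This is a hypothetical sketch (the function name `matmul_ranks_allow_gemm` is made up for illustration); it mirrors only the rank-2 requirement mentioned in the comment, while the real C++ pass may check further conditions, such as the bias shape.

```python
from typing import Optional, Sequence

Shape = Sequence[Optional[int]]


def matmul_ranks_allow_gemm(a_shape: Shape, b_shape: Shape) -> bool:
    """Hypothetical restatement of the rank precondition: both MatMul
    inputs must be rank-2 matrices before fusion into Gemm is considered."""
    return len(a_shape) == 2 and len(b_shape) == 2


# Shapes from the reproduction: rank-3 inputs, so the pass does not fire.
print(matmul_ranks_allow_gemm((1, 32, 10), (1, 10, 16)))
# The same data without the batch dimension satisfies the rank condition.
print(matmul_ranks_allow_gemm((32, 10), (10, 16)))
```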

HSQ79815 commented 2 years ago

According to the Gemm definition, A must have shape (M, K), or (K, M) when transA is set, and B must have shape (K, N), or (N, K) when transB is set. In other words, Gemm only operates on 2-D matrices.
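A minimal NumPy sketch of the Gemm semantics, Y = alpha * A' @ B' + beta * C, where A' and B' are A and B optionally transposed by transA and transB. The function name `gemm_reference` is hypothetical. Defaults follow the ONNX operator (alpha = beta = 1.0, no transposes, C optional). The explicit rank-2 check enforces the shape requirement described above.

```python
import numpy as np


def gemm_reference(A, B, C=None, alpha=1.0, beta=1.0, transA=0, transB=0):
    """Hypothetical NumPy sketch of ONNX Gemm: Y = alpha * A' @ B' + beta * C."""
    if A.ndim != 2 or B.ndim != 2:
        raise ValueError("Gemm expects rank-2 A and B")
    A_ = A.T if transA else A   # (M, K)
    B_ = B.T if transB else B   # (K, N)
    Y = alpha * (A_ @ B_)       # (M, N)
    if C is not None:
        Y = Y + beta * C        # C must broadcast to (M, N), e.g. a bias of shape (N,)
    return Y


A = np.arange(6, dtype=np.float32).reshape(2, 3)   # (M, K) = (2, 3)
Bm = np.ones((4, 3), dtype=np.float32)             # stored as (N, K), so use transB=1
C = np.array([1, 2, 3, 4], dtype=np.float32)       # bias of shape (N,)
Y = gemm_reference(A, Bm, C, transB=1)             # shape (2, 4)

# A rank-3 input, like the batched X in the reproduction, is rejected.
try:
    gemm_reference(np.ones((1, 2, 3)), np.ones((3, 4)))
    rank3_rejected = False
except ValueError:
    rank3_rejected = True
```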