onnx / optimizer

Actively maintained ONNX Optimizer
Apache License 2.0
647 stars 90 forks source link

Support fuse bn into ConvTranspose. #106

Open YuchiWen opened 2 years ago

daquexian commented 1 year ago

Sorry for the late response. Could you please add some tests for the fusion? You can follow the conv-bn fusion https://github.com/onnx/optimizer/blob/master/onnxoptimizer/test/optimizer_test.py#L3024

YuchiWen commented 1 year ago

Sorry for the late response. Could you please add some tests for the fusion? You can follow the conv-bn fusion https://github.com/onnx/optimizer/blob/master/onnxoptimizer/test/optimizer_test.py#L3024

@daquexian Done, please review.

huangzhicong3 commented 5 months ago

Hello, I tried to use this commit to fuse the BN layer and ConvTranspose layer in my model and found some bugs. The error message is: passes/fuse_bn_into_conv.h:71: modify_conv: Assertion conv_W.sizes().size() > 2 && conv_W.sizes()[0] == C failed.

From the ONNX operator documentation (https://onnx.ai/onnx/operators/onnx__ConvTranspose.html), the shape of the weight tensor of ConvTranspose is (Cin, Cout, K, K), which is different from a normal Conv layer's (Cout, Cin, K, K).

huangzhicong3 commented 5 months ago

Hi, I would like to share my code for fusing ConvTranspose and BatchNormalization. It has been tested on my own model. I hope it helps others who run into the same issue.

import numpy as np
import onnx
import sclblonnx as so

# Load the source model and index its graph contents.
model = onnx.load('../onnx/models/backbone_clean.onnx')

all_initializer = model.graph.initializer
all_node = model.graph.node

# Gather the fusion candidates by operator type.
ConvTranspose_list = [n for n in all_node if n.op_type == "ConvTranspose"]
BatchNormalization_list = [n for n in all_node if n.op_type == "BatchNormalization"]

# Pair every ConvTranspose with the BatchNormalization node that consumes
# its output. Each entry is {"conv": NodeProto, "bn": NodeProto}.
valid_ConvTranspose_list = []
for conv_node in ConvTranspose_list:
    conv_out = conv_node.output[0]
    for bn_node in BatchNormalization_list:
        if conv_out in bn_node.input:
            valid_ConvTranspose_list.append({"conv": conv_node, "bn": bn_node})
            # BUG FIX: this was `continue`, which as the last statement of the
            # loop body is a no-op. `break` stops at the first consumer, so a
            # conv that (unusually) feeds several BN nodes is not appended —
            # and later removed from the graph — more than once.
            break

# Map each initializer referenced by a matched (conv, bn) pair to its
# numpy array, keyed by initializer name.
param_dict = {}
for pair in valid_ConvTranspose_list:
    wanted = set(pair["conv"].input) | set(pair["bn"].input)
    for init in all_initializer:
        if init.name in wanted:
            param_dict[init.name] = onnx.numpy_helper.to_array(init)
# print(param_dict)
# Fold each BatchNormalization into its preceding ConvTranspose.
#
# BN(y) = gamma * (y - mean) / sqrt(var + eps) + beta. With
# scale = gamma / sqrt(var + eps), the fused layer is
#   W' = W * scale (per output channel),  b' = scale * (b - mean) + beta.
for node in valid_ConvTranspose_list:
    conv = node["conv"]
    bn = node["bn"]

    # Look attributes up by name — ONNX does not guarantee attribute order,
    # and optional attributes (e.g. momentum) may be absent entirely.
    bn_attrs = {a.name: a for a in bn.attribute}
    bn_eps = bn_attrs["epsilon"].f if "epsilon" in bn_attrs else 1e-5  # ONNX default

    bn_w = param_dict[bn.input[1]]     # gamma (scale), [Cout]
    bn_b = param_dict[bn.input[2]]     # beta (bias),   [Cout]
    bn_mean = param_dict[bn.input[3]]  # running mean,  [Cout]
    bn_var = param_dict[bn.input[4]]   # running var,   [Cout]

    # ConvTranspose weight layout is [Cin, Cout/group, kH, kW]; the
    # broadcast below assumes group == 1 so axis 1 is the full
    # output-channel axis — TODO confirm for grouped models.
    conv_w = param_dict[conv.input[1]]
    if len(conv.input) > 2:
        conv_b = param_dict[conv.input[2]]
    else:
        conv_b = np.zeros_like(bn_b)  # ConvTranspose bias is optional

    scale = bn_w / np.sqrt(bn_var + bn_eps)  # [Cout]
    # Scale each output channel of the weights; no transpose/reshape dance
    # needed — broadcasting over axis 1 does the same per-channel scaling.
    new_conv_w = (conv_w * scale.reshape(1, -1, 1, 1)).astype(np.float32)
    # BUG FIX: the original computed np.matmul(bn_w, conv_b), a 1-D dot
    # product (single scalar broadcast over all channels), and also dropped
    # the 1/sqrt(var + eps) factor. The correct per-channel fused bias is:
    new_conv_b = (scale * conv_b + bn_b - scale * bn_mean).astype(np.float32)

    new_node = onnx.helper.make_node(
        name=conv.name + '_bn',
        op_type="ConvTranspose",
        inputs=[conv.input[0], conv.name + '_bn.weights', conv.name + '_bn.bias'],
        outputs=[bn.output[0]],
    )
    # Copy the original attributes verbatim instead of assuming a fixed
    # set/order — otherwise e.g. output_padding or auto_pad would be lost
    # (or mis-assigned) when present.
    new_node.attribute.extend(conv.attribute)

    initializer_w = onnx.helper.make_tensor(
        name=conv.name + '_bn.weights',
        data_type=onnx.TensorProto.FLOAT,
        dims=new_conv_w.shape,
        vals=new_conv_w.tobytes(),
        raw=True,
    )
    initializer_b = onnx.helper.make_tensor(
        name=conv.name + '_bn.bias',
        data_type=onnx.TensorProto.FLOAT,
        dims=new_conv_b.shape,
        vals=new_conv_b.tobytes(),
        raw=True,
    )

    model.graph.initializer.append(initializer_w)
    model.graph.initializer.append(initializer_b)

    # Insert the fused node at the position of the original conv so
    # topological order is preserved, then drop the originals.
    for i, graph_node in enumerate(all_node):
        if conv.name == graph_node.name:
            model.graph.node.insert(i, new_node)
            break
    model.graph.node.remove(conv)
    model.graph.node.remove(bn)

# Validate the fused model, write it out, then run sclblonnx's cleanup
# pass over the saved file and overwrite it in place.
onnx.checker.check_model(model)
fused_path = '../onnx/models/backbone_fuse.onnx'
onnx.save(model, fused_path)

graph = so.clean(so.graph_from_file(fused_path))
so.check(graph)
so.graph_to_file(graph, fused_path)