Open — YuchiWen opened this issue 2 years ago.
Sorry for the late response. Could you please add some tests for the fusion? You can follow the existing Conv-BN fusion tests: https://github.com/onnx/optimizer/blob/master/onnxoptimizer/test/optimizer_test.py#L3024
@daquexian Done, please review.
Hello, I tried to use this commit to fuse the BatchNormalization and ConvTranspose layers in my model and found a bug:
The error message is:
passes/fuse_bn_into_conv.h:71: modify_conv: Assertion `conv_W.sizes().size() > 2 && conv_W.sizes()[0] == C` failed.
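According to the ONNX documentation (https://onnx.ai/onnx/operators/onnx__ConvTranspose.html), the ConvTranspose weight has shape (C_in, C_out / group, kH, kW). This differs from a regular Conv weight, whose shape is (C_out, C_in / group, kH, kW). As a result, the output channels are on axis 1 instead of axis 0, and the assertion above fails.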
Hi, I would like to share my code for fusing ConvTranspose and BatchNormalization. It has been tested on my own model. I hope it helps others who run into the same issue.
import numpy as np
import onnx
import sclblonnx as so
# Paths of the model to read and of the fused model to write.
INPUT_MODEL_PATH = '../onnx/models/backbone_clean.onnx'
OUTPUT_MODEL_PATH = '../onnx/models/backbone_fuse.onnx'

# Default of the BatchNormalization `epsilon` attribute in the ONNX spec.
_BN_DEFAULT_EPSILON = 1e-5


def fuse_bn_params(conv_w, conv_b, bn_scale, bn_bias, bn_mean, bn_var,
                   eps=_BN_DEFAULT_EPSILON, group=1):
    """Fold BatchNormalization parameters into ConvTranspose weight and bias.

    BN computes y = scale * x + shift, where scale = gamma / sqrt(var + eps)
    and shift = beta - scale * mean. Applied after a ConvTranspose, this is
    equivalent to scaling each output channel's weights by `scale` and
    replacing the bias with `scale * conv_b + shift`.

    Args:
        conv_w: ConvTranspose weight of shape (C_in, C_out / group, k1, ...).
            This is the ONNX layout; axis 1 holds the per-group output channels.
        conv_b: ConvTranspose bias of shape (C_out,), or None if the node has
            no bias.
        bn_scale, bn_bias, bn_mean, bn_var: BN gamma, beta, running mean and
            running variance, each of shape (C_out,).
        eps: BN epsilon.
        group: the ConvTranspose `group` attribute.

    Returns:
        (new_w, new_b) with the same shapes as the conv weight and (C_out,),
        cast to the dtype of `conv_w`.

    Raises:
        ValueError: if C_in is not divisible by `group`, or if the BN channel
            count does not match the conv output channels.
    """
    conv_w = np.asarray(conv_w)
    c_in, c_out_per_group = conv_w.shape[:2]
    if c_in % group:
        raise ValueError(
            f"input channels ({c_in}) not divisible by group ({group})")
    c_out = c_out_per_group * group
    kernel_dims = conv_w.shape[2:]

    scale = np.asarray(bn_scale) / np.sqrt(np.asarray(bn_var) + eps)
    shift = np.asarray(bn_bias) - scale * np.asarray(bn_mean)
    if conv_b is None:
        conv_b = np.zeros(c_out, dtype=conv_w.dtype)

    # Output channel o = g * (C_out / group) + j, where g is the group of the
    # input channel and j the index on weight axis 1. Splitting both the
    # weight and the scale by group makes the broadcast pick the right scale
    # for each (input, output) channel pair. It also handles any kernel rank.
    grouped_w = conv_w.reshape(
        (group, c_in // group, c_out_per_group) + kernel_dims)
    grouped_scale = scale.reshape(
        (group, 1, c_out_per_group) + (1,) * len(kernel_dims))
    new_w = (grouped_w * grouped_scale).reshape(conv_w.shape)

    # Per-channel product (the former np.matmul of two vectors was a dot
    # product that collapsed the bias to one scalar).
    new_b = scale * np.asarray(conv_b) + shift
    return new_w.astype(conv_w.dtype), new_b.astype(conv_w.dtype)


def _get_attr(node, name, default):
    """Return the value of attribute `name` on `node`, or `default` if absent."""
    for attr in node.attribute:
        if attr.name == name:
            return onnx.helper.get_attribute_value(attr)
    return default


def _find_conv_bn_pairs(graph):
    """Return (ConvTranspose, BatchNormalization) node pairs that are safe to fuse.

    A pair qualifies only when the ConvTranspose output is not a graph output
    and its sole consumer is a BatchNormalization that takes it as data input
    (input[0]). Removing a conv whose output is used elsewhere would leave
    dangling references.
    """
    consumers = {}
    for node in graph.node:
        for name in node.input:
            consumers.setdefault(name, []).append(node)
    graph_outputs = {out.name for out in graph.output}

    pairs = []
    for node in graph.node:
        if node.op_type != "ConvTranspose":
            continue
        conv_out = node.output[0]
        users = consumers.get(conv_out, [])
        if conv_out in graph_outputs or len(users) != 1:
            continue
        user = users[0]
        if user.op_type == "BatchNormalization" and user.input[0] == conv_out:
            pairs.append((node, user))
    return pairs


def fuse_convtranspose_bn(model):
    """Replace each ConvTranspose -> BatchNormalization pair in `model` in place.

    Each pair becomes a single ConvTranspose that writes the BN's output.
    All original ConvTranspose attributes (dilations, group, pads, strides,
    output_padding, ...) are copied to the new node. The new weight and bias
    initializers are named '<conv name>_bn.weights' and '<conv name>_bn.bias'.
    The first output name is used when the conv has no name.

    Pairs whose weights or BN statistics are not graph initializers are
    skipped with a message. The old initializers are left in the graph.

    Returns:
        The number of fused pairs.
    """
    graph = model.graph
    initializers = {init.name: init for init in graph.initializer}

    def as_array(name):
        return onnx.numpy_helper.to_array(initializers[name])

    fused = 0
    for conv, bn in _find_conv_bn_pairs(graph):
        has_bias = len(conv.input) > 2 and bool(conv.input[2])
        needed = list(conv.input[1:3 if has_bias else 2]) + list(bn.input[1:5])
        missing = [name for name in needed if name not in initializers]
        if len(bn.input) < 5 or missing:
            print(f"Skipping {conv.name or conv.output[0]}: "
                  f"parameters not found as initializers: {missing}")
            continue

        new_w, new_b = fuse_bn_params(
            as_array(conv.input[1]),
            as_array(conv.input[2]) if has_bias else None,
            as_array(bn.input[1]), as_array(bn.input[2]),
            as_array(bn.input[3]), as_array(bn.input[4]),
            eps=_get_attr(bn, "epsilon", _BN_DEFAULT_EPSILON),
            group=_get_attr(conv, "group", 1),
        )

        # Fall back to the output name so that unnamed convs still get
        # unique node and initializer names.
        base = conv.name or conv.output[0]
        w_name, b_name = base + '_bn.weights', base + '_bn.bias'
        new_node = onnx.helper.make_node(
            "ConvTranspose",
            inputs=[conv.input[0], w_name, b_name],
            outputs=[bn.output[0]],
            name=base + '_bn',
        )
        new_node.attribute.extend(conv.attribute)
        graph.initializer.extend([
            onnx.numpy_helper.from_array(new_w, w_name),
            onnx.numpy_helper.from_array(new_b, b_name),
        ])

        # Insert at the conv's position to keep topological order. Output
        # names are unique per graph, whereas node names may be empty.
        index = next(i for i, node in enumerate(graph.node)
                     if conv.output[0] in node.output)
        graph.node.insert(index, new_node)
        graph.node.remove(conv)
        graph.node.remove(bn)
        fused += 1
    return fused


def main(input_path=INPUT_MODEL_PATH, output_path=OUTPUT_MODEL_PATH):
    """Load a model, fuse ConvTranspose+BN pairs, validate and save it.

    After saving, the file is re-read and passed through sclblonnx `clean`
    and `check`, then written back. Presumably `clean` drops the now-unused
    initializers; verify against the sclblonnx docs.
    """
    model = onnx.load(input_path)
    fuse_convtranspose_bn(model)
    onnx.checker.check_model(model)
    onnx.save(model, output_path)
    graph = so.graph_from_file(output_path)
    graph = so.clean(graph)
    so.check(graph)
    so.graph_to_file(graph, output_path)


if __name__ == "__main__":
    main()
Sorry for the late response. Could you please add some tests for the fusion? You can follow the existing Conv-BN fusion tests: https://github.com/onnx/optimizer/blob/master/onnxoptimizer/test/optimizer_test.py#L3024