Open ausk opened 3 years ago
融合 Conv 和 BatchNorm 的核心 PyTorch 代码实现:
def fuse_conv_and_bn(conv, bn):
    """Fold a BatchNorm2d layer into the Conv2d layer that precedes it.

    In inference mode BatchNorm applies a per-channel affine transform
    ``y = gamma * (x - mean) / sqrt(var + eps) + beta``, so ``bn(conv(x))``
    can be computed by a single convolution whose weights and bias are
    rescaled accordingly.

    References:
        https://nenadmarkus.com/p/fusing-batchnorm-and-conv/
        https://github.com/ultralytics/yolov5 (utils/torch_utils.py)
        https://zhuanlan.zhihu.com/p/49329030

    Args:
        conv: The ``torch.nn.Conv2d`` whose output feeds ``bn``.
        bn: The ``torch.nn.BatchNorm2d`` applied to the output of ``conv``.
            It must carry running statistics; its affine parameters are
            optional (``affine=False`` is treated as gamma=1, beta=0).

    Returns:
        A new ``torch.nn.Conv2d`` (always with a bias) on the same device and
        with the same dtype as ``conv.weight``. Its parameters have
        ``requires_grad=False`` and are not connected to the autograd graph
        of the input modules.

    Raises:
        ValueError: If ``bn`` has no running mean/variance
            (``track_running_stats=False``), in which case there is nothing
            to fold into the convolution.
    """
    if bn.running_mean is None or bn.running_var is None:
        raise ValueError(
            "cannot fuse a BatchNorm layer without running statistics "
            "(track_running_stats=False)"
        )

    conv_weight = conv.weight
    fused_conv = torch.nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        conv.stride,
        conv.padding,
        conv.dilation,
        conv.groups,
        bias=True,
        padding_mode=conv.padding_mode,
    ).requires_grad_(False).to(device=conv_weight.device, dtype=conv_weight.dtype)

    # no_grad keeps the in-place copies below from attaching the fused
    # parameters to the autograd graph of the original conv/bn parameters.
    with torch.no_grad():
        std = torch.sqrt(bn.running_var + bn.eps)
        # affine=False leaves bn.weight / bn.bias as None: identity affine.
        gamma = torch.ones_like(std) if bn.weight is None else bn.weight
        beta = torch.zeros_like(std) if bn.bias is None else bn.bias
        conv_bias = torch.zeros_like(std) if conv.bias is None else conv.bias

        scale = gamma / std  # one factor per output channel

        # Broadcasting over the output-channel axis is equivalent to
        # multiplying by diag(scale) without building an (out x out) matrix.
        fused_conv.weight.copy_(conv_weight * scale.view(-1, 1, 1, 1))
        fused_conv.bias.copy_(beta + scale * (conv_bias - bn.running_mean))

    return fused_conv
更多参考:
Fuse Conv2d+BN