ausk / diary-xiu


Math #6

Open ausk opened 3 years ago

ausk commented 3 years ago

Fuse Conv2d+BN

conv: $Y = W*X + B$

bn: $Y = Gamma * \frac{X - RunMean}{ \sqrt{RunVar + eps} } + Beta$

conv + bn:

$$\begin{aligned}
Y &= Gamma * \frac{ (W*X + B) - RunMean}{ \sqrt{RunVar + eps} } + Beta \\
&= \frac{ Gamma * W}{\sqrt{RunVar + eps}} * X + \left( \frac{ Gamma * (B - RunMean)}{\sqrt{RunVar + eps}} + Beta \right)
\end{aligned}$$

conv_bn:

$$\begin{aligned}
\alpha &= \frac{Gamma}{\sqrt{RunVar + eps}}  \\
W^{'} &= \alpha * W  \\
B^{'} &= \alpha * (B - RunMean) + Beta
\end{aligned}$$
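As a sanity check on the algebra: for a single channel, a 1x1 "convolution" degenerates to a scalar multiply, so the fused parameters can be verified with plain numbers (all values below are made-up example inputs, not from the issue):

```python
import math

# Made-up example values for a single channel.
W, B = 2.0, 0.5            # conv weight and bias
gamma, beta = 1.5, -0.3    # bn.weight, bn.bias
run_mean, run_var, eps = 0.4, 0.9, 1e-5

x = 3.0                    # an arbitrary input value

# Reference: conv followed by batchnorm (inference mode).
y_conv = W * x + B
y_ref = gamma * (y_conv - run_mean) / math.sqrt(run_var + eps) + beta

# Fused parameters per the derivation above.
alpha = gamma / math.sqrt(run_var + eps)
W_fused = alpha * W
B_fused = alpha * (B - run_mean) + beta
y_fused = W_fused * x + B_fused

assert abs(y_ref - y_fused) < 1e-9
```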


v2_to_v1_names = {
    'bn.weight': 'gamma',
    'bn.bias': 'beta',
    'bn.running_mean': 'running_mean',
    'bn.running_var': 'running_var',
    'bn.num_batches_tracked': 'num_batches_tracked',
}
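A minimal sketch of applying this mapping to rename checkpoint keys (the `rename_keys` helper is hypothetical, not from the issue; the dict is copied from above):

```python
v2_to_v1_names = {
    'bn.weight': 'gamma',
    'bn.bias': 'beta',
    'bn.running_mean': 'running_mean',
    'bn.running_var': 'running_var',
    'bn.num_batches_tracked': 'num_batches_tracked',
}

def rename_keys(state_dict):
    # Replace each v2-style key suffix with its v1 counterpart;
    # keys that match no suffix are copied through unchanged.
    out = {}
    for key, value in state_dict.items():
        for v2, v1 in v2_to_v1_names.items():
            if key.endswith(v2):
                key = key[:-len(v2)] + v1
                break
        out[key] = value
    return out

# Example: 'layer1.bn.weight' -> 'layer1.gamma'
```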
ausk commented 3 years ago

PyTorch: fuse convolution and batchnorm layers

Core PyTorch implementation for fusing Conv and BatchNorm.

import torch

def fuse_conv_and_bn(conv, bn):
    fused_conv = torch.nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        conv.stride,
        conv.padding,
        conv.dilation,
        conv.groups,
        bias=True,
        padding_mode=conv.padding_mode
    ).requires_grad_(False).to(conv.weight.device)

    if False:  # set to True to use Method 1 instead of Method 2
        # Method 1:
        # https://zhuanlan.zhihu.com/p/49329030
        # https://github.com/qinjian623/pytorch_toys/blob/master/post_quant/fusion.py
        mean = bn.running_mean
        scale = bn.weight/torch.sqrt(bn.running_var + bn.eps)

        if conv.bias is not None:
            b_conv = conv.bias
        else:
            b_conv = mean.new_zeros(mean.shape)

        w = conv.weight * scale.reshape([conv.out_channels, 1, 1, 1])
        b = bn.bias + (b_conv - mean)*scale

        fused_conv.weight = torch.nn.Parameter(w)
        fused_conv.bias = torch.nn.Parameter(b)
    else:
        # Method 2:
        # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
        # https://nenadmarkus.com/p/fusing-batchnorm-and-conv/
        # https://github.com/ultralytics/yolov5/blob/ffef77124eb011d57597356dec2f6d96af211bed/utils/torch_utils.py#L172-L192

        # prepare filters and spatial bias
        scale_factor = bn.weight.div(torch.sqrt(bn.eps + bn.running_var)) # div(a,b) => a/b
        w_conv = conv.weight.clone().view(conv.out_channels, -1)
        b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
        fused_conv.weight.copy_(torch.mm(torch.diag(scale_factor), w_conv).view(fused_conv.weight.size()))
        fused_conv.bias.copy_( bn.bias + scale_factor.mul(b_conv - bn.running_mean))

    return fused_conv
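A quick numerical check of the fusion: build a conv and a BN with non-trivial running statistics, fuse them with the $W^{'} = \alpha * W$, $B^{'} = \alpha * (B - RunMean) + Beta$ formulas from above, and compare outputs on random input (layer sizes here are arbitrary, chosen only for the check):

```python
import torch

# Arbitrary layer sizes for the check (not from the issue).
conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
bn = torch.nn.BatchNorm2d(8).eval()  # eval(): use running stats, as at inference

# Give the BN non-trivial statistics so the comparison is meaningful.
with torch.no_grad():
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)

# Fuse using alpha = gamma / sqrt(running_var + eps).
with torch.no_grad():
    alpha = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
    fused.weight.copy_(conv.weight * alpha.reshape(-1, 1, 1, 1))
    fused.bias.copy_(alpha * (conv.bias - bn.running_mean) + bn.bias)

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    diff = (bn(conv(x)) - fused(x)).abs().max()
assert diff < 1e-5  # fused layer reproduces conv -> bn up to float error
```

The same check applies to `fuse_conv_and_bn` above, since Method 1 computes exactly these quantities.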

More references:

  1. Keras-inference-time-optimizer
  2. Pytorch BN Fuse
  3. PyTorch docs: torch.quantization.fuse_modules