DingXiaoH / ACNet

ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks
MIT License

Plug-in version implementation #40

zjykzj closed this issue 3 years ago.

zjykzj commented 3 years ago

@DingXiaoH Nice work! I have written a plug-in implementation of ACBlock. I hope it helps you and others.

Implementation

The complete implementation is divided into two files:

  1. asymmetric_convolution_block.py
  2. conv_helper.py

File asymmetric_convolution_block.py implements ACBlock:

# -*- coding: utf-8 -*-

"""
@date: 2021/2/1 7:10 PM
@file: asymmetric_convolution_block.py
@author: zj
@description: 
"""

import torch.nn as nn

class AsymmetricConvolutionBlock(nn.Module):
    """
    Implemented with reference to [ACNet/acnet/acb.py](https://github.com/DingXiaoH/ACNet/blob/master/acnet/acb.py)
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                 padding_mode='zeros', use_affine=True, reduce_gamma=False, use_last_bn=False, gamma_init=None):
        super(AsymmetricConvolutionBlock, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.padding_mode = padding_mode

        self.square_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                     kernel_size=(kernel_size, kernel_size), stride=stride,
                                     padding=padding, dilation=dilation, groups=groups, bias=False,
                                     padding_mode=padding_mode)
        self.square_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)

        center_offset_from_origin_border = padding - kernel_size // 2
        ver_pad_or_crop = (padding, center_offset_from_origin_border)
        hor_pad_or_crop = (center_offset_from_origin_border, padding)
        if center_offset_from_origin_border >= 0:
            self.ver_conv_crop_layer = nn.Identity()
            ver_conv_padding = ver_pad_or_crop
            self.hor_conv_crop_layer = nn.Identity()
            hor_conv_padding = hor_pad_or_crop
        else:
            self.ver_conv_crop_layer = CropLayer(crop_set=ver_pad_or_crop)
            ver_conv_padding = (0, 0)
            self.hor_conv_crop_layer = CropLayer(crop_set=hor_pad_or_crop)
            hor_conv_padding = (0, 0)
        self.ver_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size, 1),
                                  stride=stride,
                                  padding=ver_conv_padding, dilation=dilation, groups=groups, bias=False,
                                  padding_mode=padding_mode)

        self.hor_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, kernel_size),
                                  stride=stride,
                                  padding=hor_conv_padding, dilation=dilation, groups=groups, bias=False,
                                  padding_mode=padding_mode)
        self.ver_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)
        self.hor_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)

        if reduce_gamma:
            assert not use_last_bn
            self.init_gamma(1.0 / 3)

        if use_last_bn:
            assert not reduce_gamma
            self.last_bn = nn.BatchNorm2d(num_features=out_channels, affine=True)

        if gamma_init is not None:
            assert not reduce_gamma
            self.init_gamma(gamma_init)

    def init_gamma(self, gamma_value):
        nn.init.constant_(self.square_bn.weight, gamma_value)
        nn.init.constant_(self.ver_bn.weight, gamma_value)
        nn.init.constant_(self.hor_bn.weight, gamma_value)
        print('init gamma of square, ver and hor as ', gamma_value)

    def single_init(self):
        nn.init.constant_(self.square_bn.weight, 1.0)
        nn.init.constant_(self.ver_bn.weight, 0.0)
        nn.init.constant_(self.hor_bn.weight, 0.0)
        print('init gamma of square as 1, ver and hor as 0')

    def forward(self, input):
        square_outputs = self.square_conv(input)
        square_outputs = self.square_bn(square_outputs)
        vertical_outputs = self.ver_conv_crop_layer(input)
        vertical_outputs = self.ver_conv(vertical_outputs)
        vertical_outputs = self.ver_bn(vertical_outputs)
        horizontal_outputs = self.hor_conv_crop_layer(input)
        horizontal_outputs = self.hor_conv(horizontal_outputs)
        horizontal_outputs = self.hor_bn(horizontal_outputs)
        result = square_outputs + vertical_outputs + horizontal_outputs
        if hasattr(self, 'last_bn'):
            return self.last_bn(result)
        return result

class CropLayer(nn.Module):

    #   E.g., (-1, 0) means this layer should crop the first and last rows of the feature map. And (0, -1) crops the first and last columns
    def __init__(self, crop_set):
        super(CropLayer, self).__init__()
        self.rows_to_crop = - crop_set[0]
        self.cols_to_crop = - crop_set[1]
        assert self.rows_to_crop >= 0
        assert self.cols_to_crop >= 0

    def forward(self, input):
        if self.rows_to_crop == 0 and self.cols_to_crop == 0:
            return input
        elif self.rows_to_crop > 0 and self.cols_to_crop == 0:
            return input[:, :, self.rows_to_crop:-self.rows_to_crop, :]
        elif self.rows_to_crop == 0 and self.cols_to_crop > 0:
            return input[:, :, :, self.cols_to_crop:-self.cols_to_crop]
        else:
            return input[:, :, self.rows_to_crop:-self.rows_to_crop, self.cols_to_crop:-self.cols_to_crop]
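The crop branch keeps the 3x1 and 1x3 kernels aligned with the centre of the square kernel when padding < kernel_size // 2. The sketch below checks this for padding=0, kernel_size=3, using a hand-written NumPy cross-correlation instead of PyTorch (an assumption made to keep it framework-free; `conv2d` and `crop` are hypothetical helpers, not part of the code above, and BatchNorm is left out). If the alignment is right, the sum of the three branch outputs should equal one 3x3 convolution whose kernel has the vertical kernel added to its centre column and the horizontal kernel added to its centre row.

```python
import numpy as np


def conv2d(x, w, pad=(0, 0)):
    """Stride-1, bias-free cross-correlation with zero padding (like nn.Conv2d).
    x: (C_in, H, W), w: (C_out, C_in, kh, kw), pad: (pad_h, pad_w)."""
    ph, pw = pad
    xp = np.pad(x, ((0, 0), (ph, ph), (pw, pw)))
    _, _, kh, kw = w.shape
    h_out, w_out = xp.shape[1] - kh + 1, xp.shape[2] - kw + 1
    out = np.zeros((w.shape[0], h_out, w_out))
    for i in range(kh):
        for j in range(kw):
            # add the contribution of kernel tap (i, j), summed over input channels
            out += np.einsum("oc,chw->ohw", w[:, :, i, j], xp[:, i:i + h_out, j:j + w_out])
    return out


def crop(x, rows, cols):
    """Like CropLayer: remove `rows` rows and `cols` columns from each border."""
    h, w = x.shape[1:]
    return x[:, rows:h - rows, cols:w - cols]


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8, 8))
k_sq = rng.standard_normal((6, 4, 3, 3))
k_ver = rng.standard_normal((6, 4, 3, 1))
k_hor = rng.standard_normal((6, 4, 1, 3))

# padding=0, kernel_size=3: center_offset_from_origin_border = 0 - 3 // 2 = -1
sq_out = conv2d(x, k_sq)                 # square branch, no padding
ver_out = conv2d(crop(x, 0, 1), k_ver)   # ver_pad_or_crop = (0, -1): crop 1 column per side
hor_out = conv2d(crop(x, 1, 0), k_hor)   # hor_pad_or_crop = (-1, 0): crop 1 row per side
branch_sum = sq_out + ver_out + hor_out

# Merge the asymmetric kernels into the centre column / centre row of the square kernel
fused = k_sq.copy()
fused[:, :, :, 1:2] += k_ver
fused[:, :, 1:2, :] += k_hor
print(np.allclose(branch_sum, conv2d(x, fused)))
```

Writing the slice end as `h - rows` avoids the `x[a:-0]` pitfall (an empty slice) that CropLayer guards against with its explicit branches.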

File conv_helper.py inserts ACBlocks into a model and fuses them back into standard convolutions:

# -*- coding: utf-8 -*-

"""
@date: 2020/12/4 4:11 PM
@file: conv_helper.py
@author: zj
@description: 
"""

import torch
import torch.nn as nn
from .layers.asymmetric_convolution_block import AsymmetricConvolutionBlock

def get_conv(cfg):
    """
    Args:
        cfg (CfgNode): model building configs, details are in the comments of
            the config file.
    Returns:
        nn.Module: the conv layer.
    """
    conv_type = cfg.MODEL.CONV.TYPE
    if conv_type == "Conv2d":
        return nn.Conv2d
    elif conv_type == "Conv3d":
        return nn.Conv3d
    else:
        raise NotImplementedError(
            "Conv type {} is not supported".format(conv_type)
        )

def insert_acblock(model: nn.Module):
    items = list(model.named_children())
    idx = 0
    while idx < len(items):
        name, module = items[idx]
        if isinstance(module, nn.Conv2d) and module.kernel_size[0] > 1:
            # Replace the standard convolution with an ACBlock
            in_channels = module.in_channels
            out_channels = module.out_channels
            kernel_size = module.kernel_size
            stride = module.stride
            padding = module.padding
            dilation = module.dilation
            groups = module.groups
            padding_mode = module.padding_mode

            acblock = AsymmetricConvolutionBlock(in_channels,
                                                 out_channels,
                                                 kernel_size[0],
                                                 stride,
                                                 padding=padding[0],
                                                 padding_mode=padding_mode,
                                                 dilation=dilation,
                                                 groups=groups)
            model.add_module(name, acblock)
            # If the conv layer is followed by a BN layer, remove that BN layer (replace it with nn.Identity)
            # See [About BN layer #35](https://github.com/DingXiaoH/ACNet/issues/35)
            if (idx + 1) < len(items) and isinstance(items[idx + 1][1], nn.BatchNorm2d):
                new_layer = nn.Identity()
                model.add_module(items[idx + 1][0], new_layer)
        else:
            insert_acblock(module)
        idx += 1

def fuse_acblock(model: nn.Module, eps=1e-5):
    for name, module in model.named_children():
        if isinstance(module, AsymmetricConvolutionBlock):
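            # Replace the ACBlock with a single standard convolution
            # (note: an optional last_bn, i.e. use_last_bn=True, is not folded here)
            # Get the NxN conv weight and its BN's weight, bias, running mean and running std (sqrt(running_var + eps))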
            square_conv_weight = module.square_conv.weight
            square_bn_weight = module.square_bn.weight
            square_bn_bias = module.square_bn.bias
            square_bn_running_mean = module.square_bn.running_mean
            square_bn_running_std = torch.sqrt(module.square_bn.running_var + eps)
            # Get the Nx1 conv weight and its BN's weight, bias, running mean and running std
            vertical_conv_weight = module.ver_conv.weight
            vertical_bn_weight = module.ver_bn.weight
            vertical_bn_bias = module.ver_bn.bias
            vertical_bn_running_mean = module.ver_bn.running_mean
            vertical_bn_running_std = torch.sqrt(module.ver_bn.running_var + eps)
            # Get the 1xN conv weight and its BN's weight, bias, running mean and running std
            horizontal_conv_weight = module.hor_conv.weight
            horizontal_bn_weight = module.hor_bn.weight
            horizontal_bn_bias = module.hor_bn.bias
            horizontal_bn_running_mean = module.hor_bn.running_mean
            horizontal_bn_running_std = torch.sqrt(module.hor_bn.running_var + eps)
            # Compute the fused bias
            fused_bias = square_bn_bias + vertical_bn_bias + horizontal_bn_bias \
                         - square_bn_running_mean * square_bn_weight / square_bn_running_std \
                         - vertical_bn_running_mean * vertical_bn_weight / vertical_bn_running_std \
                         - horizontal_bn_running_mean * horizontal_bn_weight / horizontal_bn_running_std
            # Compute the fused kernel
            fused_kernel = _fuse_kernel(square_conv_weight, square_bn_weight, square_bn_running_std)
            _add_to_square_kernel(fused_kernel,
                                  _fuse_kernel(vertical_conv_weight, vertical_bn_weight, vertical_bn_running_std))
            _add_to_square_kernel(fused_kernel,
                                  _fuse_kernel(horizontal_conv_weight, horizontal_bn_weight, horizontal_bn_running_std))
            # Create a standard conv, assign the fused weight and bias, and put it back into the model
            fused_conv = nn.Conv2d(module.in_channels,
                                   module.out_channels,
                                   module.kernel_size,
                                   stride=module.stride,
                                   padding=module.padding,
                                   dilation=module.dilation,
                                   groups=module.groups,
                                   padding_mode=module.padding_mode
                                   )
            fused_conv.weight = nn.Parameter(fused_kernel)
            fused_conv.bias = nn.Parameter(fused_bias)
            model.add_module(name, fused_conv)
        else:
            fuse_acblock(module, eps=eps)

def _fuse_kernel(kernel, gamma, std):
    b_gamma = torch.reshape(gamma, (kernel.shape[0], 1, 1, 1))
    b_gamma = b_gamma.repeat(1, kernel.shape[1], kernel.shape[2], kernel.shape[3])
    b_std = torch.reshape(std, (kernel.shape[0], 1, 1, 1))
    b_std = b_std.repeat(1, kernel.shape[1], kernel.shape[2], kernel.shape[3])
    return kernel * b_gamma / b_std

def _add_to_square_kernel(square_kernel, asym_kernel):
    asym_h = asym_kernel.shape[2]
    asym_w = asym_kernel.shape[3]
    square_h = square_kernel.shape[2]
    square_w = square_kernel.shape[3]
    square_kernel[:, :, square_h // 2 - asym_h // 2: square_h // 2 - asym_h // 2 + asym_h,
    square_w // 2 - asym_w // 2: square_w // 2 - asym_w // 2 + asym_w] += asym_kernel
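fuse_acblock relies on two facts: an eval-mode BatchNorm after a bias-free convolution folds into the kernel (scaled per output channel by gamma / std) plus a bias (beta - running_mean * gamma / std), and convolution is linear, so the three scaled kernels can be summed once the asymmetric ones are centred in the square kernel. A minimal NumPy sketch of that arithmetic, assuming stride 1, groups=1 and padding=1 with random BN statistics; the helper names (`conv2d`, `bn_eval`, `random_bn`) are illustrative and not part of conv_helper.py:

```python
import numpy as np

EPS = 1e-5


def conv2d(x, w, pad=(0, 0)):
    """Stride-1, bias-free cross-correlation with zero padding (like nn.Conv2d)."""
    ph, pw = pad
    xp = np.pad(x, ((0, 0), (ph, ph), (pw, pw)))
    _, _, kh, kw = w.shape
    h_out, w_out = xp.shape[1] - kh + 1, xp.shape[2] - kw + 1
    out = np.zeros((w.shape[0], h_out, w_out))
    for i in range(kh):
        for j in range(kw):
            out += np.einsum("oc,chw->ohw", w[:, :, i, j], xp[:, i:i + h_out, j:j + w_out])
    return out


def bn_eval(y, gamma, beta, mean, var, eps=EPS):
    """BatchNorm2d in eval mode (running statistics) on a (C, H, W) array."""
    scale = gamma / np.sqrt(var + eps)
    return scale[:, None, None] * (y - mean[:, None, None]) + beta[:, None, None]


def random_bn(rng, c):
    """Hypothetical BN parameters: gamma, beta, running_mean, running_var (> 0)."""
    return (rng.standard_normal(c), rng.standard_normal(c),
            rng.standard_normal(c), rng.uniform(0.5, 2.0, c))


rng = np.random.default_rng(1)
c_in, c_out = 4, 6
x = rng.standard_normal((c_in, 8, 8))
branches = {  # kernel and padding of each branch, as in AsymmetricConvolutionBlock with padding=1
    "square": (rng.standard_normal((c_out, c_in, 3, 3)), (1, 1)),
    "ver": (rng.standard_normal((c_out, c_in, 3, 1)), (1, 0)),
    "hor": (rng.standard_normal((c_out, c_in, 1, 3)), (0, 1)),
}
bns = {name: random_bn(rng, c_out) for name in branches}

# Training-time structure: three conv + BN branches, summed
block_out = sum(bn_eval(conv2d(x, w, pad), *bns[name]) for name, (w, pad) in branches.items())

fused_kernel = np.zeros((c_out, c_in, 3, 3))
fused_bias = np.zeros(c_out)
for name, (w, _) in branches.items():
    gamma, beta, mean, var = bns[name]
    std = np.sqrt(var + EPS)
    scaled = w * (gamma / std)[:, None, None, None]  # same scaling as _fuse_kernel
    kh, kw = w.shape[2:]
    top, left = 3 // 2 - kh // 2, 3 // 2 - kw // 2     # centring, as in _add_to_square_kernel
    fused_kernel[:, :, top:top + kh, left:left + kw] += scaled
    fused_bias += beta - mean * gamma / std            # same terms as fused_bias

fused_out = conv2d(x, fused_kernel, (1, 1)) + fused_bias[:, None, None]
print(np.allclose(block_out, fused_out))
```

Because the folding uses running statistics, it is only exact in eval mode, which matches the tests below calling model.eval() before comparing outputs.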

Test

# -*- coding: utf-8 -*-

"""
@date: 2021/2/1 8:01 PM
@file: test_asymmetric_convolution_block.py
@author: zj
@description: 
"""

import torch
import torch.nn as nn
from torchvision.models import resnet50

from zcls.model.layers.asymmetric_convolution_block import AsymmetricConvolutionBlock
from zcls.model.conv_helper import insert_acblock, fuse_acblock

def test_asymmetric_convolution_block():
    in_channels = 32
    out_channels = 64
    dilation = 1
    groups = 1

    # output spatial size == input spatial size
    kernel_size = 3
    stride = 1
    padding = 1
    acblock = AsymmetricConvolutionBlock(in_channels,
                                         out_channels,
                                         kernel_size,
                                         stride=stride,
                                         padding=padding,
                                         dilation=dilation,
                                         groups=groups)

    data = torch.randn(1, in_channels, 56, 56)
    outputs = acblock.forward(data)

    _, _, h, w = data.shape[:4]
    _, _, h2, w2 = outputs.shape[:4]
    assert h == h2 and w == w2

    # downsampling
    kernel_size = 3
    stride = 2
    padding = 1
    acblock = AsymmetricConvolutionBlock(in_channels,
                                         out_channels,
                                         kernel_size,
                                         stride=stride,
                                         padding=padding,
                                         dilation=dilation,
                                         groups=groups)

    data = torch.randn(1, in_channels, 56, 56)
    outputs = acblock.forward(data)

    _, _, h, w = data.shape[:4]
    _, _, h2, w2 = outputs.shape[:4]
    assert h / 2 == h2 and w / 2 == w2

    # downsampling + grouped convolution
    kernel_size = 3
    stride = 2
    padding = 1
    groups = 8
    acblock = AsymmetricConvolutionBlock(in_channels,
                                         out_channels,
                                         kernel_size,
                                         stride=stride,
                                         padding=padding,
                                         dilation=dilation,
                                         groups=groups)

    data = torch.randn(1, in_channels, 56, 56)
    outputs = acblock.forward(data)

    _, _, h, w = data.shape[:4]
    _, _, h2, w2 = outputs.shape[:4]
    assert h / 2 == h2 and w / 2 == w2

def test_acb_helper():
    in_channels = 32
    out_channels = 64
    dilation = 1

    # downsampling + grouped convolution
    kernel_size = 3
    stride = 2
    padding = 1
    groups = 8

    model = nn.Sequential(
        nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size,
                      stride=stride, padding=padding, dilation=dilation, groups=groups),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    )
    model.eval()
    print(model)

    data = torch.randn(1, in_channels, 56, 56)
    insert_acblock(model)
    model.eval()
    train_outputs = model(data)
    print(model)

    fuse_acblock(model, eps=1e-5)
    model.eval()
    eval_outputs = model(data)
    print(model)

    print(torch.sum((train_outputs - eval_outputs) ** 2))
    print(torch.allclose(train_outputs, eval_outputs, atol=1e-6))
    assert torch.allclose(train_outputs, eval_outputs, atol=1e-6)

def test_resnet50_acb():
    model = resnet50()
    model.eval()
    # print(model)

    data = torch.randn(1, 3, 224, 224)
    insert_acblock(model)
    model.eval()
    train_outputs = model(data)
    # print(model)

    fuse_acblock(model, eps=1e-5)
    model.eval()
    eval_outputs = model(data)
    # print(model)

    print(torch.sum((train_outputs - eval_outputs) ** 2))
    print(torch.allclose(train_outputs, eval_outputs, atol=1e-5))
    assert torch.allclose(train_outputs, eval_outputs, atol=1e-5)

if __name__ == '__main__':
    print('*' * 10)
    test_asymmetric_convolution_block()
    print('*' * 10)
    test_acb_helper()
    print('*' * 10)
    test_resnet50_acb()
############################### output
**********
**********
Sequential(
  (0): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=8)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
)
Sequential(
  (0): Sequential(
    (0): AsymmetricConvolutionBlock(
      (square_conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=8, bias=False)
      (square_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (ver_conv_crop_layer): Identity()
      (hor_conv_crop_layer): Identity()
      (ver_conv): Conv2d(32, 64, kernel_size=(3, 1), stride=(2, 2), padding=(1, 0), groups=8, bias=False)
      (hor_conv): Conv2d(32, 64, kernel_size=(1, 3), stride=(2, 2), padding=(0, 1), groups=8, bias=False)
      (ver_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hor_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): Identity()
    (2): ReLU(inplace=True)
  )
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
)
Sequential(
  (0): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=8)
    (1): Identity()
    (2): ReLU(inplace=True)
  )
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
)
tensor(5.4800e-10, grad_fn=<SumBackward0>)
True
**********
tensor(1.0568e-08, grad_fn=<SumBackward0>)
True

Process finished with exit code 0

How to use

Build and use your models as usual. If you want to add ACBlocks, do it like this:

from zcls.model.conv_helper import insert_acblock
...
...
if cfg.MODEL.CONV.ACBLOCK is True:
    insert_acblock(model)                             # -----------------> HERE

model = model.to(device=device)
if du.get_world_size() > 1:
    model = DDP(model, device_ids=[device], output_device=device, find_unused_parameters=True)

return model

Then train the model and save its parameters as usual. To fuse the ACBlocks into standard convolutions, call fuse_acblock.

For the complete implementation, see ZJCV/ZCls.

DingXiaoH commented 3 years ago

Thanks! I will release ACNet v2 in about a month and update the whole repo. At that time I will add a link to your repo.

zjykzj commented 3 years ago


My pleasure! By the way, I have also reproduced RepVGG; I hope it helps you and others: DingXiaoH/RepVGG#24

Hi @DingXiaoH, this is a simple and intuitive implementation! I have written a plug-in version of RepVGGBlock. I hope it helps you and others.

This plug-in version provides the following features:

    The training-time model and the inference-time model are kept separate;
    You can apply RepVGGBlock to other models;
    You can use RepVGGBlock and ACBlock together during training, in either order.
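The train/inference separation of RepVGGBlock rests on the same kernel-merging idea as ACBlock: with stride 1 and in_channels == out_channels, the 1x1 branch becomes the centre tap of a 3x3 kernel, and the identity branch becomes a 3x3 kernel with a 1 at the centre of each channel-to-itself slice. The NumPy sketch below checks this with BatchNorm omitted (an assumption; in RepVGG each branch's BN would first be folded into its kernel and bias). `conv2d` is a hypothetical helper, not code from the linked issue.

```python
import numpy as np


def conv2d(x, w, pad=(0, 0)):
    """Stride-1, bias-free cross-correlation with zero padding (like nn.Conv2d)."""
    ph, pw = pad
    xp = np.pad(x, ((0, 0), (ph, ph), (pw, pw)))
    _, _, kh, kw = w.shape
    h_out, w_out = xp.shape[1] - kh + 1, xp.shape[2] - kw + 1
    out = np.zeros((w.shape[0], h_out, w_out))
    for i in range(kh):
        for j in range(kw):
            out += np.einsum("oc,chw->ohw", w[:, :, i, j], xp[:, i:i + h_out, j:j + w_out])
    return out


rng = np.random.default_rng(2)
c = 4  # the identity branch needs in_channels == out_channels (and stride 1)
x = rng.standard_normal((c, 8, 8))
k3 = rng.standard_normal((c, c, 3, 3))
k1 = rng.standard_normal((c, c, 1, 1))

# Training-time block: 3x3 branch + 1x1 branch + identity branch
train_out = conv2d(x, k3, (1, 1)) + conv2d(x, k1) + x

# Identity as a 3x3 kernel: output channel o copies input channel o at the centre tap
k_id = np.zeros((c, c, 3, 3))
k_id[np.arange(c), np.arange(c), 1, 1] = 1.0
# Zero-pad the 1x1 kernel to 3x3 so it sits at the centre, then sum all three
fused = k3 + np.pad(k1, ((0, 0), (0, 0), (1, 1), (1, 1))) + k_id
deploy_out = conv2d(x, fused, (1, 1))
print(np.allclose(train_out, deploy_out))
```

Since both ACBlock and RepVGGBlock end up as a single 3x3 kernel after fusion, this is also why the two can be combined during training in either order.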

...
...