zhanghang1989 / ResNeSt

ResNeSt: Split-Attention Networks
https://arxiv.org/abs/2004.08955
Apache License 2.0
3.24k stars 496 forks source link

reimplement of split_attention_conv2d and why don't want to add BN2/ReLU in Bottleneck? #128

Open zjykzj opened 3 years ago

zjykzj commented 3 years ago

hi @zhanghang1989 ,First of all, thank you very much for providing such an imaginative model

I refer to the source code implementation of ResNetSt and reproduce a new implementation of SplitAttentionConv2d. The implementation architecture may be clearer

# -*- coding: utf-8 -*-

"""
@date: 2021/1/4 上午11:32
@file: split_attention_conv2d.py
@author: zj
@description: 
"""
from abc import ABC

import torch

import torch.nn as nn

from ..init_helper import init_weights

class SplitAttentionConv2d(nn.Module, ABC):
    """
    ResNetSt的SplitAttention实现,参考:
    1. https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/backbones/resnest.py
    2. https://github.com/zhanghang1989/ResNeSt/blob/73b43ba63d1034dbf3e96b3010a8f2eb4cc3854f/resnest/torch/splat.py
    部分参考./selective_kernel_conv2d.py实现
    """

    def __init__(self,
                 # 输入通道数
                 in_channels,
                 # 输出通道数
                 out_channels,
                 # 每个group中的分离数
                 radix=2,
                 # cardinality
                 groups=1,
                 # 中间层衰减率
                 reduction_rate=4,
                 # 默认中间层最小通道数
                 default_channels: int = 32,
                 # 维度
                 dimension: int = 2
                 ):
        super(SplitAttentionConv2d, self).__init__()

        # split
        self.split = nn.Sequential(
            nn.Conv2d(in_channels, out_channels * radix, kernel_size=3, stride=1, padding=1, bias=False,
                      groups=groups * radix),
            nn.BatchNorm2d(out_channels * radix),
            nn.ReLU(inplace=True)
        )
        # fuse
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        inner_channels = max(out_channels // reduction_rate, default_channels)
        self.compact = nn.Sequential(
            nn.Conv2d(out_channels, inner_channels, kernel_size=1, stride=1, padding=0, bias=False,
                      groups=groups),
            nn.BatchNorm2d(inner_channels),
            nn.ReLU(inplace=True)
        )
        # select
        self.select = nn.Conv2d(inner_channels, out_channels * radix, kernel_size=1, stride=1, bias=False,
                                groups=groups)
        self.softmax = nn.Softmax(dim=0)
        self.dimension = dimension
        self.out_channels = out_channels
        self.radix = radix

        init_weights(self.modules())

    def forward(self, x):
        # N, C, H, W = x.shape[:4]
        # split
        out = self.split(x)
        split_out = torch.stack(torch.split(out, self.out_channels, dim=1))
        # fuse
        u = torch.sum(split_out, dim=0)
        s = self.pool(u)
        z = self.compact(s)
        # select
        c = self.select(z)
        split_c = torch.stack(torch.split(c, self.out_channels, dim=1))
        softmax_c = self.softmax(split_c)

        v = torch.sum(split_out.mul(softmax_c), dim=0)
        return v.contiguous()

and one of my question is why there is no need to add bn2/relu in Bottleneck when radix>0,Is it obtained through experiments?

FrancescoSaverioZuppichini commented 3 years ago

Same question, posting my implementation for completeness:

class SplitAtt(nn.Module):
    def __init__(self, in_features: int, features: int, radix: int, groups: int):
        """Implementation of Split Attention proposed in `"ResNeSt: Split-Attention Networks" <https://arxiv.org/abs/2004.08955>`_
        Grouped convolution have been proved to be impirically better (ResNetXt). The main idea is to apply an attention group-wise. 
        Einops is used to improve the readibility of this module
        Args:
            in_features (int): number of input features
            features (int): attention's features
            radix (int): number of subgroups (`radix`) in the groups
            groups (int): number of groups, each group contains `radix` subgroups
        """
        super().__init__()
        self.radix, self.groups = radix, groups
        self.att = nn.Sequential(
            # this produces U^{/hat}
            Reduce('b (r k c) h w -> b (k c) h w',
                   reduction='mean', r=radix, k=groups),
            # eq 1
            nn.AdaptiveAvgPool2d(1),
            # the two following conv layers are G in the paper
            ConvBnAct(in_features, features, kernel_size=1,
                      groups=groups, activation=ReLUInPlace, bias=True),
            nn.Conv2d(features, in_features * radix,
                      kernel_size=1, groups=groups),
            Rearrange('b (r k c) h w -> b r (k c) h w', r=radix, k=groups),
            nn.Softmax(dim=1) if radix > 1 else nn.Sigmoid(),
            Rearrange('b r (k c) h w -> b (r k c) h w', r=radix, k=groups)
        )

    def forward(self, x: Tensor) -> Tensor:
        att = self.att(x)
        # eq 2, scale using att and sum-up over the radix axis
        x *= att 
        x = reduce(x, 'b (r k c) h w -> b (k c) h w',
                   reduction='mean', r=self.radix, k=self.groups)
        return x

btw the bias in the first conv is useless but it is present in the original implementation, I guess it is an error

[Edit] After thinking about it, I think it makes sense because when radix > 1 softmax is applied (in rSoftmax) while when radix=0 sigmoid is used making it the same as SE. But there shouldn't be a batchnorm and a ReLU

FrancescoSaverioZuppichini commented 3 years ago

Hi @zhanghang1989, may I ask you a question? I don't really get it why bn + relu is applied after the 1 conv in the attention module but no bn + relu is applied in the second one. Thanks :)

zhanghang1989 commented 3 years ago

Hi @FrancescoSaverioZuppichini , the bn+relu is applied to the first conv, because it adds non-linearity between two convs (otherwise it is equivalent to a single one). There is no bn+relu for the second conv, because the softmax is a kind of non-linearity or activation function.

zjykzj commented 3 years ago

hi @zhanghang1989 @FrancescoSaverioZuppichini. i modified code to match the same results compare with origin code.

code

# -*- coding: utf-8 -*-

"""
@date: 2021/1/4 上午11:32
@file: split_attention_conv2d.py
@author: zj
@description: 
"""
from abc import ABC

import torch

import torch.nn as nn
import torch.nn.functional as F

from ..init_helper import init_weights

class SplitAttentionConv2d(nn.Module, ABC):
    """
    ResNetSt的SplitAttention实现,参考:
    1. https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/backbones/resnest.py
    2. https://github.com/zhanghang1989/ResNeSt/blob/73b43ba63d1034dbf3e96b3010a8f2eb4cc3854f/resnest/torch/splat.py
    部分参考./selective_kernel_conv2d.py实现
    """

    def __init__(self,
                 # 输入通道数
                 in_channels,
                 # 输出通道数
                 out_channels,
                 # 每个group中的分离数
                 radix=2,
                 # cardinality
                 groups=1,
                 # 中间层衰减率
                 reduction_rate=4,
                 # 默认中间层最小通道数
                 default_channels: int = 32,
                 # 维度
                 dimension: int = 2
                 ):
        super(SplitAttentionConv2d, self).__init__()

        # split
        self.split = nn.Sequential(
            nn.Conv2d(in_channels, out_channels * radix, kernel_size=3, stride=1, padding=1, bias=False,
                      groups=groups * radix),
            nn.BatchNorm2d(out_channels * radix),
            nn.ReLU(inplace=True)
        )
        # self.conv1 = nn.Conv2d(in_channels, out_channels * radix, kernel_size=3, stride=1, padding=1, bias=False,
        #                        groups=groups * radix)
        # self.bn1 = nn.BatchNorm2d(out_channels * radix)
        # self.relu = nn.ReLU(inplace=True)
        # fuse
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        inner_channels = max(out_channels // reduction_rate, default_channels)
        self.compact = nn.Sequential(
            nn.Conv2d(out_channels, inner_channels, kernel_size=1, stride=1, padding=0, bias=False,
                      groups=groups),
            nn.BatchNorm2d(inner_channels),
            nn.ReLU(inplace=True)
        )
        # select
        self.select = nn.Conv2d(inner_channels, out_channels * radix, kernel_size=1, stride=1,
                                groups=groups)
        # self.softmax = nn.Softmax(dim=0)
        self.rsoftmax = rSoftMax(radix, groups)
        self.dimension = dimension
        self.out_channels = out_channels
        self.radix = radix
        self.groups = groups

        init_weights(self.modules())

    def forward(self, x):
        N, C, H, W = x.shape[:4]
        # split
        out = self.split(x)
        # out = self.conv1(x)
        # out = self.bn1(out)
        # out = self.relu(out)
        split_out = torch.stack(torch.split(out, self.out_channels, dim=1))
        # fuse
        u = torch.sum(split_out, dim=0)
        s = self.pool(u)
        z = self.compact(s)
        # select
        c = self.select(z)
        softmax_c = self.rsoftmax(c).view(N, -1, 1, 1)

        attens = torch.split(softmax_c, self.out_channels, dim=1)
        v = sum([att * split for (att, split) in zip(attens, split_out)])
        return v.contiguous()
        # split_c = torch.stack(torch.split(c, self.out_channels, dim=1))
        # softmax_c = self.softmax(split_c)
        #
        # v = torch.sum(split_out.mul(softmax_c), dim=0)
        # return v.contiguous()

class rSoftMax(nn.Module):
    def __init__(self, radix, cardinality):
        super().__init__()
        self.radix = radix
        self.cardinality = cardinality

    def forward(self, x):
        batch = x.size(0)
        if self.radix > 1:
            x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
            x = F.softmax(x, dim=1)
            x = x.reshape(batch, -1)
        else:
            x = torch.sigmoid(x)
        return x

test

# -*- coding: utf-8 -*-

"""
@date: 2021/1/1 下午9:19
@file: test_selective_kernel_conv2d.py
@author: zj
@description: 
"""

import numpy as np
import torch
import torch.nn as nn
from resnest.torch.splat import SplAtConv2d

from zcls.model.layers.split_attention_conv2d import SplitAttentionConv2d

def init_weights(modules):
    for m in modules:
        if isinstance(m, nn.Conv2d):
            nn.init.constant_(m.weight, 0.01)
            # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            # nn.init.normal_(m.weight, 0, 0.01)
            nn.init.constant_(m.weight, 0.01)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

def get_custom_res():
    # 保证训练时获取的随机数都是一样的
    init_seed = 1
    torch.manual_seed(init_seed)
    torch.cuda.manual_seed(init_seed)
    np.random.seed(init_seed)

    num = 3
    in_channels = 64
    out_channels = 128
    data = torch.randn(num, in_channels, 56, 56)

    # custom
    model = SplitAttentionConv2d(in_channels,
                                 out_channels,
                                 radix=2,
                                 groups=32,
                                 reduction_rate=16
                                 )
    init_weights(model.modules())
    print(model)
    outputs_c = model(data)
    print(outputs_c.shape)

    return outputs_c

def get_official_res():
    # 保证训练时获取的随机数都是一样的
    init_seed = 1
    torch.manual_seed(init_seed)
    torch.cuda.manual_seed(init_seed)
    np.random.seed(init_seed)

    num = 3
    in_channels = 64
    out_channels = 128
    data = torch.randn(num, in_channels, 56, 56)

    # official
    model = SplAtConv2d(in_channels,
                        out_channels,
                        3,
                        norm_layer=nn.BatchNorm2d,
                        bias=False,
                        padding=1,
                        radix=2,
                        groups=32,
                        reduction_factor=16
                        )
    init_weights(model.modules())
    print(model)
    outputs_o = model(data)
    print(outputs_o.shape)

    return outputs_o

def compare():
    outputs_c = get_custom_res()
    outputs_o = get_official_res()

    res = torch.allclose(outputs_c, outputs_o)
    print(res)

def test_split_attention_conv2d():
    num = 3
    in_channels = 64
    out_channels = 128
    data = torch.randn(num, in_channels, 56, 56)

    # 不进行分组
    model = SplitAttentionConv2d(in_channels,
                                 out_channels,
                                 radix=2,
                                 groups=1,
                                 reduction_rate=4
                                 )
    print(model)
    outputs = model(data)
    print(outputs.shape)

    assert outputs.shape == (num, out_channels, 56, 56)

    # 不进行radix
    model = SplitAttentionConv2d(in_channels,
                                 out_channels,
                                 radix=1,
                                 groups=1,
                                 reduction_rate=16
                                 )
    print(model)
    outputs = model(data)
    print(outputs.shape)

    assert outputs.shape == (num, out_channels, 56, 56)

    # 同时实现radix和group
    model = SplitAttentionConv2d(in_channels,
                                 out_channels,
                                 radix=2,
                                 groups=32,
                                 reduction_rate=16
                                 )
    print(model)
    outputs = model(data)
    print(outputs.shape)

    assert outputs.shape == (num, out_channels, 56, 56)

if __name__ == '__main__':
    compare()
    test_split_attention_conv2d()

result

SplitAttentionConv2d(
  (split): Sequential(
    (0): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (pool): AdaptiveAvgPool2d(output_size=(1, 1))
  (compact): Sequential(
    (0): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1), groups=32, bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (select): Conv2d(32, 256, kernel_size=(1, 1), stride=(1, 1), groups=32)
  (rsoftmax): rSoftMax()
)
torch.Size([3, 128, 56, 56])
SplAtConv2d(
  (conv): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
  (bn0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (fc1): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1), groups=32)
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Conv2d(32, 256, kernel_size=(1, 1), stride=(1, 1), groups=32)
  (rsoftmax): rSoftMax()
)
torch.Size([3, 128, 56, 56])
True
SplitAttentionConv2d(
  (split): Sequential(
    (0): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2, bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (pool): AdaptiveAvgPool2d(output_size=(1, 1))
  (compact): Sequential(
    (0): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (select): Conv2d(32, 256, kernel_size=(1, 1), stride=(1, 1))
  (rsoftmax): rSoftMax()
)
torch.Size([3, 128, 56, 56])
SplitAttentionConv2d(
  (split): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (pool): AdaptiveAvgPool2d(output_size=(1, 1))
  (compact): Sequential(
    (0): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (select): Conv2d(32, 128, kernel_size=(1, 1), stride=(1, 1))
  (rsoftmax): rSoftMax()
)
torch.Size([3, 128, 56, 56])
SplitAttentionConv2d(
  (split): Sequential(
    (0): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (pool): AdaptiveAvgPool2d(output_size=(1, 1))
  (compact): Sequential(
    (0): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1), groups=32, bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (select): Conv2d(32, 256, kernel_size=(1, 1), stride=(1, 1), groups=32)
  (rsoftmax): rSoftMax()
)
torch.Size([3, 128, 56, 56])
FrancescoSaverioZuppichini commented 3 years ago

Nice! I have also updated mine, I have explicitly rearrange the input tensor to have the radix axis after the batch. I have tried to clone the pretrained weights but I have terrible results, so probably it is not correct, but I cannot see the error. I don't know why in the first try I used mean instead of sum :)


class SplitAtt(nn.Module):
    def __init__(self, in_features: int, features: int, radix: int, groups: int):
        """Implementation of Split Attention proposed in `"ResNeSt: Split-Attention Networks" <https://arxiv.org/abs/2004.08955>`_

        Grouped convolution have been proved to be impirically better (ResNetXt). The main idea is to apply an attention group-wise. 

        Einops is used to improve the readibility of this module

        Args:
            in_features (int): number of input features
            features (int): attention's features
            radix (int): number of subgroups (`radix`) in the groups
            groups (int): number of groups, each group contains `radix` subgroups
        """
        super().__init__()
        self.radix, self.groups = radix, groups
        self.att = nn.Sequential(
            # this produces U^{/hat}
            Reduce('b r (k c) h w-> b (k c) h w',
                   reduction='sum', r=radix, k=groups),
            # eq 1
            nn.AdaptiveAvgPool2d(1),
            # the two following conv layers are G in the paper
            ConvBnAct(in_features, features, kernel_size=1,
                      groups=groups, activation=ReLUInPlace, bias=True),
            nn.Conv2d(features, in_features * radix,
                      kernel_size=1, groups=groups),
            Rearrange('b (r k c) h w -> b r k c h w', r=radix, k=groups),
            nn.Softmax(dim=1) if radix > 1 else nn.Sigmoid(),
            Rearrange('b r k c h w -> b r (k c) h w', r=radix, k=groups)
        )

    def forward(self, x: Tensor) -> Tensor:
        x = rearrange(x, 'b (r k c) h w -> b r (k c) h w', r=self.radix, k=self.groups)
        att = self.att(x)
        # eq 2, scale using att and sum-up over the radix axis
        x *= att
        x = reduce(x, 'b r (k c) h w -> b (k c) h w',
                   reduction='sum', r=self.radix, k=self.groups)
        return x

The bottleneck block

class ResNeStBottleneckBlock(ResNetXtBottleNeckBlock):
    def __init__(self, in_features: int, out_features: int, stride: int = 1, radix: int = 2, groups: int = 1,
                 fast: bool = True, reduction: int = 4, activation: nn.Module = ReLUInPlace, drop_block_p: float = 0, **kwargs):
        super().__init__(in_features, out_features, reduction=reduction,
                         activation=activation, stride=stride, groups=groups, **kwargs)
        att_features = max(self.features * radix // reduction, 32)
        pool = nn.AvgPool2d(kernel_size=3, stride=2,
                            padding=1) if stride == 2 else nn.Identity()
        self.block = nn.Sequential(
            ConvBnDropAct(in_features, self.features, activation=activation,
                          p=drop_block_p, kernel_size=1),
            pool if fast else nn.Identity(),
            ConvBnDropAct(self.features, self.features * radix, activation=activation,
                          p=drop_block_p,  kernel_size=3, groups=groups * radix),
            SplitAtt(self.features, att_features, radix, groups),
            pool if not fast else nn.Identity(),
            ConvBnDropAct(self.features, out_features, activation=activation,
                          p=drop_block_p,  kernel_size=1),
        )

Example

ResNeStBottleneckBlock(64, 64, radix=2, groups=1, base_width=64)
ResNeStBottleneckBlock(
  (block): Sequential(
    (0): ConvBnDropAct(
      (conv): Conv2dPad(64, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (reg): DropBlock(p=0)
      (act): ReLU(inplace=True)
    )
    (1): Identity()
    (2): ConvBnDropAct(
      (conv): Conv2dPad(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2, bias=False)
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (reg): DropBlock(p=0)
      (act): ReLU(inplace=True)
    )
    (3): SplitAtt(
      (att): Sequential(
        (0): Reduce('b r (k c) h w-> b (k c) h w', 'sum', r=2, k=1)
        (1): AdaptiveAvgPool2d(output_size=1)
        (2): ConvBnAct(
          (conv): Conv2dPad(16, 32, kernel_size=(1, 1), stride=(1, 1))
          (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): ReLU(inplace=True)
        )
        (3): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
        (4): Rearrange('b (r k c) h w -> b r k c h w', r=2, k=1)
        (5): Softmax(dim=1)
        (6): Rearrange('b r k c h w -> b r (k c) h w', r=2, k=1)
      )
    )
    (4): Identity()
    (5): ConvBnDropAct(
      (conv): Conv2dPad(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (reg): DropBlock(p=0)
      (act): ReLU(inplace=True)
    )
  )
  (shortcut): Identity()
  (act): ReLU(inplace=True)
)