zjykzj opened this issue 3 years ago
Same question, posting my implementation for completeness:
from torch import nn, Tensor
from einops import reduce
from einops.layers.torch import Rearrange, Reduce

# ConvBnAct and ReLUInPlace are custom helper blocks (not shown here)


class SplitAtt(nn.Module):
    def __init__(self, in_features: int, features: int, radix: int, groups: int):
        """Implementation of Split Attention proposed in `"ResNeSt: Split-Attention Networks" <https://arxiv.org/abs/2004.08955>`_
        Grouped convolutions have been shown to be empirically better (ResNeXt). The main idea is to apply the attention group-wise.
        Einops is used to improve the readability of this module.
        Args:
            in_features (int): number of input features
            features (int): attention's features
            radix (int): number of subgroups (`radix`) in the groups
            groups (int): number of groups, each group contains `radix` subgroups
        """
        super().__init__()
        self.radix, self.groups = radix, groups
        self.att = nn.Sequential(
            # this produces \hat{U}
            Reduce('b (r k c) h w -> b (k c) h w',
                   reduction='mean', r=radix, k=groups),
            # eq 1
            nn.AdaptiveAvgPool2d(1),
            # the two following conv layers are G in the paper
            ConvBnAct(in_features, features, kernel_size=1,
                      groups=groups, activation=ReLUInPlace, bias=True),
            nn.Conv2d(features, in_features * radix,
                      kernel_size=1, groups=groups),
            Rearrange('b (r k c) h w -> b r (k c) h w', r=radix, k=groups),
            nn.Softmax(dim=1) if radix > 1 else nn.Sigmoid(),
            Rearrange('b r (k c) h w -> b (r k c) h w', r=radix, k=groups)
        )

    def forward(self, x: Tensor) -> Tensor:
        att = self.att(x)
        # eq 2, scale using att and sum up over the radix axis
        x *= att
        x = reduce(x, 'b (r k c) h w -> b (k c) h w',
                   reduction='mean', r=self.radix, k=self.groups)
        return x
By the way, the bias in the first conv is useless (a batchnorm follows it), but it is present in the original implementation; I guess it is an error.
[Edit] After thinking about it, I think it makes sense: when radix > 1 a softmax is applied (in rSoftMax), while when radix = 1 a sigmoid is used, making it the same as SE. But there shouldn't be a batchnorm and a ReLU.
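A quick sanity check of that point (my own hedged sketch, not part of the original code): with a single split, a softmax over the radix axis is identically one, so the gate would be a no-op; a sigmoid gives the SE-style per-channel gate instead.

import torch

# radix == 1: softmax over the radix axis collapses to 1, sigmoid gives an SE-style gate
logits = torch.randn(1, 1, 8)        # (batch, radix=1, channels)
print(torch.softmax(logits, dim=1))  # all ones -> attention would be a no-op
print(torch.sigmoid(logits))         # per-channel gate in (0, 1), as in SE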
Hi @zhanghang1989, may I ask you a question? I don't really get why bn + relu is applied after the first conv in the attention module, but no bn + relu is applied after the second one. Thanks :)
Hi @FrancescoSaverioZuppichini, the bn + relu is applied after the first conv because it adds non-linearity between the two convs (otherwise they would be equivalent to a single one). There is no bn + relu after the second conv because the softmax is itself a kind of non-linearity / activation function.
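To make that concrete, here is a short hedged sketch (mine, not from the reply above): two stacked 1x1 convolutions with nothing in between fold into a single linear map, which is why the bn + relu after the first conv matters.

import torch
import torch.nn as nn

# Without a non-linearity in between, conv2(conv1(x)) is a single linear map:
# folding the two weight matrices into one conv reproduces it exactly.
c1 = nn.Conv2d(8, 4, kernel_size=1, bias=False)
c2 = nn.Conv2d(4, 16, kernel_size=1, bias=False)
merged = nn.Conv2d(8, 16, kernel_size=1, bias=False)
with torch.no_grad():
    merged.weight.copy_((c2.weight.squeeze() @ c1.weight.squeeze())[..., None, None])
x = torch.randn(2, 8, 5, 5)
print(torch.allclose(c2(c1(x)), merged(x), atol=1e-6))  # True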
Hi @zhanghang1989 @FrancescoSaverioZuppichini, I modified my code so that it produces the same results as the original code.
# -*- coding: utf-8 -*-

"""
@date: 2021/1/4 11:32 AM
@file: split_attention_conv2d.py
@author: zj
@description:
"""

from abc import ABC

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..init_helper import init_weights


class SplitAttentionConv2d(nn.Module, ABC):
    """
    SplitAttention implementation for ResNeSt. References:
    1. https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/backbones/resnest.py
    2. https://github.com/zhanghang1989/ResNeSt/blob/73b43ba63d1034dbf3e96b3010a8f2eb4cc3854f/resnest/torch/splat.py
    Partly based on the implementation in ./selective_kernel_conv2d.py
    """
    def __init__(self,
                 # number of input channels
                 in_channels,
                 # number of output channels
                 out_channels,
                 # number of splits (radix) in each group
                 radix=2,
                 # cardinality
                 groups=1,
                 # reduction rate of the intermediate layer
                 reduction_rate=4,
                 # default minimum number of channels of the intermediate layer
                 default_channels: int = 32,
                 # dimension
                 dimension: int = 2
                 ):
        super(SplitAttentionConv2d, self).__init__()
        # split
        self.split = nn.Sequential(
            nn.Conv2d(in_channels, out_channels * radix, kernel_size=3, stride=1, padding=1, bias=False,
                      groups=groups * radix),
            nn.BatchNorm2d(out_channels * radix),
            nn.ReLU(inplace=True)
        )
        # self.conv1 = nn.Conv2d(in_channels, out_channels * radix, kernel_size=3, stride=1, padding=1, bias=False,
        #                        groups=groups * radix)
        # self.bn1 = nn.BatchNorm2d(out_channels * radix)
        # self.relu = nn.ReLU(inplace=True)
        # fuse
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        inner_channels = max(out_channels // reduction_rate, default_channels)
        self.compact = nn.Sequential(
            nn.Conv2d(out_channels, inner_channels, kernel_size=1, stride=1, padding=0, bias=False,
                      groups=groups),
            nn.BatchNorm2d(inner_channels),
            nn.ReLU(inplace=True)
        )
        # select
        self.select = nn.Conv2d(inner_channels, out_channels * radix, kernel_size=1, stride=1,
                                groups=groups)
        # self.softmax = nn.Softmax(dim=0)
        self.rsoftmax = rSoftMax(radix, groups)

        self.dimension = dimension
        self.out_channels = out_channels
        self.radix = radix
        self.groups = groups

        init_weights(self.modules())
    def forward(self, x):
        N, C, H, W = x.shape[:4]

        # split
        out = self.split(x)
        # out = self.conv1(x)
        # out = self.bn1(out)
        # out = self.relu(out)
        split_out = torch.stack(torch.split(out, self.out_channels, dim=1))
        # fuse
        u = torch.sum(split_out, dim=0)
        s = self.pool(u)
        z = self.compact(s)
        # select
        c = self.select(z)
        softmax_c = self.rsoftmax(c).view(N, -1, 1, 1)

        attens = torch.split(softmax_c, self.out_channels, dim=1)
        v = sum([att * split for (att, split) in zip(attens, split_out)])
        return v.contiguous()

        # split_c = torch.stack(torch.split(c, self.out_channels, dim=1))
        # softmax_c = self.softmax(split_c)
        #
        # v = torch.sum(split_out.mul(softmax_c), dim=0)
        # return v.contiguous()
class rSoftMax(nn.Module):
    def __init__(self, radix, cardinality):
        super().__init__()
        self.radix = radix
        self.cardinality = cardinality

    def forward(self, x):
        batch = x.size(0)
        if self.radix > 1:
            x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
            x = F.softmax(x, dim=1)
            x = x.reshape(batch, -1)
        else:
            x = torch.sigmoid(x)
        return x
# -*- coding: utf-8 -*-

"""
@date: 2021/1/1 9:19 PM
@file: test_selective_kernel_conv2d.py
@author: zj
@description:
"""

import numpy as np
import torch
import torch.nn as nn

from resnest.torch.splat import SplAtConv2d
from zcls.model.layers.split_attention_conv2d import SplitAttentionConv2d


def init_weights(modules):
    for m in modules:
        if isinstance(m, nn.Conv2d):
            nn.init.constant_(m.weight, 0.01)
            # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            # nn.init.normal_(m.weight, 0, 0.01)
            nn.init.constant_(m.weight, 0.01)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
def get_custom_res():
    # make sure the same random numbers are drawn for both runs
    init_seed = 1
    torch.manual_seed(init_seed)
    torch.cuda.manual_seed(init_seed)
    np.random.seed(init_seed)

    num = 3
    in_channels = 64
    out_channels = 128
    data = torch.randn(num, in_channels, 56, 56)

    # custom
    model = SplitAttentionConv2d(in_channels,
                                 out_channels,
                                 radix=2,
                                 groups=32,
                                 reduction_rate=16
                                 )
    init_weights(model.modules())
    print(model)

    outputs_c = model(data)
    print(outputs_c.shape)
    return outputs_c


def get_official_res():
    # make sure the same random numbers are drawn for both runs
    init_seed = 1
    torch.manual_seed(init_seed)
    torch.cuda.manual_seed(init_seed)
    np.random.seed(init_seed)

    num = 3
    in_channels = 64
    out_channels = 128
    data = torch.randn(num, in_channels, 56, 56)

    # official
    model = SplAtConv2d(in_channels,
                        out_channels,
                        3,
                        norm_layer=nn.BatchNorm2d,
                        bias=False,
                        padding=1,
                        radix=2,
                        groups=32,
                        reduction_factor=16
                        )
    init_weights(model.modules())
    print(model)

    outputs_o = model(data)
    print(outputs_o.shape)
    return outputs_o


def compare():
    outputs_c = get_custom_res()
    outputs_o = get_official_res()

    res = torch.allclose(outputs_c, outputs_o)
    print(res)
def test_split_attention_conv2d():
    num = 3
    in_channels = 64
    out_channels = 128
    data = torch.randn(num, in_channels, 56, 56)

    # no grouping
    model = SplitAttentionConv2d(in_channels,
                                 out_channels,
                                 radix=2,
                                 groups=1,
                                 reduction_rate=4
                                 )
    print(model)

    outputs = model(data)
    print(outputs.shape)
    assert outputs.shape == (num, out_channels, 56, 56)

    # no radix (radix = 1)
    model = SplitAttentionConv2d(in_channels,
                                 out_channels,
                                 radix=1,
                                 groups=1,
                                 reduction_rate=16
                                 )
    print(model)

    outputs = model(data)
    print(outputs.shape)
    assert outputs.shape == (num, out_channels, 56, 56)

    # use radix and groups at the same time
    model = SplitAttentionConv2d(in_channels,
                                 out_channels,
                                 radix=2,
                                 groups=32,
                                 reduction_rate=16
                                 )
    print(model)

    outputs = model(data)
    print(outputs.shape)
    assert outputs.shape == (num, out_channels, 56, 56)


if __name__ == '__main__':
    compare()
    test_split_attention_conv2d()
SplitAttentionConv2d(
  (split): Sequential(
    (0): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (pool): AdaptiveAvgPool2d(output_size=(1, 1))
  (compact): Sequential(
    (0): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1), groups=32, bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (select): Conv2d(32, 256, kernel_size=(1, 1), stride=(1, 1), groups=32)
  (rsoftmax): rSoftMax()
)
torch.Size([3, 128, 56, 56])
SplAtConv2d(
  (conv): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
  (bn0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (fc1): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1), groups=32)
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Conv2d(32, 256, kernel_size=(1, 1), stride=(1, 1), groups=32)
  (rsoftmax): rSoftMax()
)
torch.Size([3, 128, 56, 56])
True
SplitAttentionConv2d(
  (split): Sequential(
    (0): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2, bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (pool): AdaptiveAvgPool2d(output_size=(1, 1))
  (compact): Sequential(
    (0): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (select): Conv2d(32, 256, kernel_size=(1, 1), stride=(1, 1))
  (rsoftmax): rSoftMax()
)
torch.Size([3, 128, 56, 56])
SplitAttentionConv2d(
  (split): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (pool): AdaptiveAvgPool2d(output_size=(1, 1))
  (compact): Sequential(
    (0): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (select): Conv2d(32, 128, kernel_size=(1, 1), stride=(1, 1))
  (rsoftmax): rSoftMax()
)
torch.Size([3, 128, 56, 56])
SplitAttentionConv2d(
  (split): Sequential(
    (0): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (pool): AdaptiveAvgPool2d(output_size=(1, 1))
  (compact): Sequential(
    (0): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1), groups=32, bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (select): Conv2d(32, 256, kernel_size=(1, 1), stride=(1, 1), groups=32)
  (rsoftmax): rSoftMax()
)
torch.Size([3, 128, 56, 56])
Nice! I have also updated mine: I explicitly rearrange the input tensor so that the radix axis comes right after the batch axis. I have tried to clone the pretrained weights but I get terrible results, so it is probably not correct, but I cannot see the error. I don't know why I used mean instead of sum in the first try :)
class SplitAtt(nn.Module):
    def __init__(self, in_features: int, features: int, radix: int, groups: int):
        """Implementation of Split Attention proposed in `"ResNeSt: Split-Attention Networks" <https://arxiv.org/abs/2004.08955>`_
        Grouped convolutions have been shown to be empirically better (ResNeXt). The main idea is to apply the attention group-wise.
        Einops is used to improve the readability of this module.
        Args:
            in_features (int): number of input features
            features (int): attention's features
            radix (int): number of subgroups (`radix`) in the groups
            groups (int): number of groups, each group contains `radix` subgroups
        """
        super().__init__()
        self.radix, self.groups = radix, groups
        self.att = nn.Sequential(
            # this produces \hat{U}
            Reduce('b r (k c) h w -> b (k c) h w',
                   reduction='sum', r=radix, k=groups),
            # eq 1
            nn.AdaptiveAvgPool2d(1),
            # the two following conv layers are G in the paper
            ConvBnAct(in_features, features, kernel_size=1,
                      groups=groups, activation=ReLUInPlace, bias=True),
            nn.Conv2d(features, in_features * radix,
                      kernel_size=1, groups=groups),
            Rearrange('b (r k c) h w -> b r k c h w', r=radix, k=groups),
            nn.Softmax(dim=1) if radix > 1 else nn.Sigmoid(),
            Rearrange('b r k c h w -> b r (k c) h w', r=radix, k=groups)
        )

    def forward(self, x: Tensor) -> Tensor:
        x = rearrange(x, 'b (r k c) h w -> b r (k c) h w', r=self.radix, k=self.groups)
        att = self.att(x)
        # eq 2, scale using att and sum up over the radix axis
        x *= att
        x = reduce(x, 'b r (k c) h w -> b (k c) h w',
                   reduction='sum', r=self.radix, k=self.groups)
        return x
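For reference, a small standalone shape walk-through of the forward pass above (a hedged sketch with made-up sizes, using only einops and a random stand-in for the attention branch):

import torch
from einops import rearrange, reduce

# assume radix=2, groups=1 and 32 channels per split -> the block input has 64 channels
r, k, c = 2, 1, 32
x = torch.randn(4, r * k * c, 8, 8)                         # output of the 3x3 grouped conv
x = rearrange(x, 'b (r k c) h w -> b r (k c) h w', r=r, k=k)
att = torch.softmax(torch.randn(4, r, k * c, 1, 1), dim=1)  # stand-in for self.att(x)
out = reduce(x * att, 'b r (k c) h w -> b (k c) h w', reduction='sum', r=r, k=k)
print(out.shape)  # torch.Size([4, 32, 8, 8]) -> eq 2: weighted sum over the radix axis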
The bottleneck block
class ResNeStBottleneckBlock(ResNetXtBottleNeckBlock):
    def __init__(self, in_features: int, out_features: int, stride: int = 1, radix: int = 2, groups: int = 1,
                 fast: bool = True, reduction: int = 4, activation: nn.Module = ReLUInPlace, drop_block_p: float = 0, **kwargs):
        super().__init__(in_features, out_features, reduction=reduction,
                         activation=activation, stride=stride, groups=groups, **kwargs)
        att_features = max(self.features * radix // reduction, 32)
        pool = nn.AvgPool2d(kernel_size=3, stride=2,
                            padding=1) if stride == 2 else nn.Identity()
        self.block = nn.Sequential(
            ConvBnDropAct(in_features, self.features, activation=activation,
                          p=drop_block_p, kernel_size=1),
            pool if fast else nn.Identity(),
            ConvBnDropAct(self.features, self.features * radix, activation=activation,
                          p=drop_block_p, kernel_size=3, groups=groups * radix),
            SplitAtt(self.features, att_features, radix, groups),
            pool if not fast else nn.Identity(),
            ConvBnDropAct(self.features, out_features, activation=activation,
                          p=drop_block_p, kernel_size=1),
        )
Example
ResNeStBottleneckBlock(64, 64, radix=2, groups=1, base_width=64)
ResNeStBottleneckBlock(
  (block): Sequential(
    (0): ConvBnDropAct(
      (conv): Conv2dPad(64, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (reg): DropBlock(p=0)
      (act): ReLU(inplace=True)
    )
    (1): Identity()
    (2): ConvBnDropAct(
      (conv): Conv2dPad(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2, bias=False)
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (reg): DropBlock(p=0)
      (act): ReLU(inplace=True)
    )
    (3): SplitAtt(
      (att): Sequential(
        (0): Reduce('b r (k c) h w -> b (k c) h w', 'sum', r=2, k=1)
        (1): AdaptiveAvgPool2d(output_size=1)
        (2): ConvBnAct(
          (conv): Conv2dPad(16, 32, kernel_size=(1, 1), stride=(1, 1))
          (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): ReLU(inplace=True)
        )
        (3): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
        (4): Rearrange('b (r k c) h w -> b r k c h w', r=2, k=1)
        (5): Softmax(dim=1)
        (6): Rearrange('b r k c h w -> b r (k c) h w', r=2, k=1)
      )
    )
    (4): Identity()
    (5): ConvBnDropAct(
      (conv): Conv2dPad(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (reg): DropBlock(p=0)
      (act): ReLU(inplace=True)
    )
  )
  (shortcut): Identity()
  (act): ReLU(inplace=True)
)
Hi @zhanghang1989, first of all, thank you very much for providing such an imaginative model.
I referred to the source code of ResNeSt and reproduced a new implementation of SplitAttentionConv2d; its architecture may be clearer.
One question I have: why is there no need to add bn2/relu in the Bottleneck when radix > 0? Was this found through experiments?