FrancescoSaverioZuppichini / glasses

High-quality Neural Networks for Computer Vision 😎
https://francescosaveriozuppichini.github.io/glasses-webapp/
MIT License

Implement SK Module in Glasses #283

Open rentainhe opened 2 years ago

rentainhe commented 2 years ago

Paper

Reference

TODO

codecov-commenter commented 2 years ago

Codecov Report

Merging #283 (61df784) into develop (34300a7) will increase coverage by 0.03%. The diff coverage is 93.18%.

Impacted file tree graph

@@             Coverage Diff             @@
##           develop     #283      +/-   ##
===========================================
+ Coverage    97.28%   97.31%   +0.03%     
===========================================
  Files           86       87       +1     
  Lines         3056     3203     +147     
===========================================
+ Hits          2973     3117     +144     
- Misses          83       86       +3     
Impacted Files               Coverage Δ
glasses/nn/att/CBAM.py       100.00% <ø> (ø)
glasses/nn/att/ECA.py        100.00% <ø> (ø)
glasses/utils/Storage.py     95.40% <ø> (+0.16%) :arrow_up:
glasses/nn/att/utils.py      93.75% <83.33%> (-6.25%) :arrow_down:
glasses/nn/att/SK.py         93.10% <93.10%> (ø)
glasses/nn/att/__init__.py   100.00% <100.00%> (ø)
glasses/nn/att/se.py         100.00% <100.00%> (ø)
test/test_att.py             100.00% <100.00%> (ø)
test/test_auto.py            100.00% <0.00%> (ø)
... and 20 more

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Powered by Codecov. Last update d0089dc...61df784.

rentainhe commented 2 years ago

Recording the older code here for reference:

import torch
import torch.nn as nn
from typing import Union, List

from glasses.nn.att.utils import make_divisible
from ..blocks import ConvBnAct
from einops.layers.torch import Rearrange, Reduce

def _kernel_valid(k):
    # every kernel size must be odd and >= 3; recurse into lists/tuples
    if isinstance(k, (list, tuple)):
        for ki in k:
            _kernel_valid(ki)
        return
    assert k >= 3 and k % 2

class SelectiveKernelAtt(nn.Module):
    def __init__(
        self,
        features: int,
        num_paths: int = 2,
        mid_features: int = 32,
        act_layer: nn.Module = nn.ReLU,
        norm_layer: nn.Module = nn.BatchNorm2d,
    ):
        super().__init__()
        self.num_paths = num_paths
        self.att = nn.Sequential(
            # fuse the paths and global-average-pool: (b, n, c, h, w) -> (b, c, 1, 1)
            Reduce("b n c h w -> b c h w", reduction="sum"),
            Reduce("b c h w -> b c 1 1", reduction="mean"),
            # squeeze to mid_features, then expand to one score per path and channel
            nn.Conv2d(features, mid_features, kernel_size=1, bias=False),
            norm_layer(mid_features),
            act_layer(inplace=True),
            nn.Conv2d(mid_features, features * num_paths, kernel_size=1, bias=False),
            # reshape to (b, num_paths, features, 1, 1) and softmax over the paths
            Rearrange('b (n c) h w -> b n c h w', n=num_paths, c=features),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        assert x.shape[1] == self.num_paths
        x = self.att(x)
        return x

class SelectiveKernel(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int = None,
        kernel_size: Union[List, int] = None,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        reduction: int = 16,
        reduction_divisor: int = 8,
        reduced_features: int = None,
        keep_3x3: bool = True,
        activation: nn.Module = nn.ReLU,
        normalization: nn.Module = nn.BatchNorm2d,
    ):
        super().__init__()
        out_features = out_features or in_features
        kernel_size = kernel_size or [3, 5]
        _kernel_valid(kernel_size)
        if not isinstance(kernel_size, list):
            kernel_size = [kernel_size] * 2
        if keep_3x3:
            # emulate the larger kernels with dilated 3x3 convolutions
            dilation = [dilation * (k - 1) // 2 for k in kernel_size]
            kernel_size = [3] * len(kernel_size)
        else:
            dilation = [dilation] * len(kernel_size)
        self.num_paths = len(kernel_size)
        self.in_features = in_features
        # no trailing comma here, otherwise out_features becomes a tuple
        self.out_features = out_features
        groups = min(out_features, groups)

        self.paths = nn.ModuleList([
            ConvBnAct(in_features = in_features, 
                      out_features = out_features, 
                      activation = activation, 
                      normalization=normalization,
                      mode = "same",
                      stride=stride,
                      kernel_size=k, 
                      dilation=d)
            for k, d in zip(kernel_size, dilation)
        ])

        attn_features = reduced_features or make_divisible(out_features // reduction, divisor=reduction_divisor)
        self.attn = SelectiveKernelAtt(out_features, self.num_paths, attn_features)

    def forward(self, x):
        x_paths = [op(x) for op in self.paths]  # b, c, h, w
        x = torch.stack(x_paths, dim=1)  # b, n, c, h, w
        x_attn = self.attn(x)
        x = x * x_attn
        return torch.sum(x, dim=1)
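
A quick smoke test for the recorded module (just a sketch; it assumes glasses' ConvBnAct and make_divisible behave as used above, and the tensor sizes are made up):

# hypothetical smoke test, not part of the PR
x = torch.randn(2, 64, 32, 32)
sk = SelectiveKernel(in_features=64, out_features=64)  # default paths: 3x3 and dilated 3x3
out = sk(x)
print(out.shape)  # expected: torch.Size([2, 64, 32, 32])
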
rentainhe commented 2 years ago

Thank you for the PR. Let's

  • add typing
  • remove bad practices such as:

    if not isinstance(kernel_size, list):
        kernel_size = [kernel_size] * 2
    if keep_3x3:
        dilation = [1 * (k - 1) // 2 for k in kernel_size]
        kernel_size = [3] * len(kernel_size)
    else:
        dilation = [1 * (k - 1) // 2 for k in kernel_size]

  • decouple each part of the module
  • let the user pass a block, defaulting to ConvBnAct (see the sketch after this list)
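
A rough sketch of how the kernel/dilation handling could be pulled out into its own helper, with the path block passed in by the user (hypothetical names, not the final glasses API):

from typing import List, Tuple, Union

def _normalize_kernels(
    kernel_size: Union[int, List[int]], dilation: int, keep_3x3: bool
) -> Tuple[List[int], List[int]]:
    # hypothetical helper: normalise kernel sizes and derive matching dilations
    kernels = [kernel_size] * 2 if isinstance(kernel_size, int) else list(kernel_size)
    if keep_3x3:
        # emulate larger kernels with dilated 3x3 convolutions
        dilations = [dilation * (k - 1) // 2 for k in kernels]
        kernels = [3] * len(kernels)
    else:
        dilations = [dilation] * len(kernels)
    return kernels, dilations

print(_normalize_kernels([3, 5], dilation=1, keep_3x3=True))  # ([3, 3], [1, 2])

The paths could then be built from a user-supplied block argument (e.g. block: nn.Module = ConvBnAct), so SelectiveKernel only composes block instances instead of hard-coding ConvBnAct.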

Sure, I will update my code tonight, thanks for reviewing!