chinhsuanwu / coatnet-pytorch

A PyTorch implementation of "CoAtNet: Marrying Convolution and Attention for All Data Sizes"
https://arxiv.org/abs/2106.04803
MIT License

Inconsistencies in MBConv, with corrected code provided #19

Open swarajnanda2021 opened 8 months ago

swarajnanda2021 commented 8 months ago

I've found some computational inconsistencies in MBConv. The corrected code below works. I changed the stride of the projection (self.proj) to 1 and moved it out of the if downsample block, so it is created for both branches. The squeeze-and-excite block is now initialized with hidden_dim, the channel count of the tensor it actually receives (I've included my SqueezeAndExcite block here for completeness). Finally, I added the channel projection to the downsample=False branch of the MBConv forward method, so the shortcut also has oup channels.
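The shape constraint behind these changes can be checked in isolation: the residual sum only works if the shortcut and the conv branch end up with the same number of channels and the same spatial size. The sketch below assumes arbitrary example sizes (inp=64, oup=96, hidden_dim=256, a 31x31 input) and builds only the shape-changing layers of each branch, using the same kernel/stride/padding settings as the code in this comment. BatchNorm, GELU and the SE block are left out because they do not change shapes.

```python
import torch
import torch.nn as nn

# Example sizes only (assumptions, not taken from the repo).
inp, oup, hidden_dim = 64, 96, 256
x = torch.randn(2, inp, 31, 31)  # odd spatial size to exercise rounding

# Shortcut branch with downsample=True: 3x3 stride-2 max-pool, then 1x1 stride-1 projection.
pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
proj = nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)
shortcut = proj(pool(x))

# Conv branch (expansion > 1), keeping only the layers that change shape:
# stride-2 1x1 expansion, stride-1 3x3 depthwise, 1x1 projection to oup.
expand = nn.Conv2d(inp, hidden_dim, kernel_size=1, stride=2, padding=0, bias=False)
depthwise = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1,
                      groups=hidden_dim, bias=False)
project = nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, padding=0, bias=False)
main = project(depthwise(expand(x)))

# The expansion == 1 branch downsamples with a stride-2 3x3 depthwise conv instead.
dw_down = nn.Conv2d(inp, inp, kernel_size=3, stride=2, padding=1, groups=inp, bias=False)

print(shortcut.shape, main.shape, dw_down(x).shape)  # all three are 16x16 spatially
out = shortcut + main  # well-defined only because both branches are (2, 96, 16, 16)

# Without the max-pool, the shortcut stays 31x31 and the residual sum fails.
try:
    proj(x) + main
except RuntimeError as err:
    print("shape mismatch:", err)
```

With kernel_size=3, stride=2, padding=1 the pool outputs floor((H - 1) / 2) + 1 = ceil(H / 2) rows, the same as a 1x1 stride-2 conv with no padding (and a 3x3 stride-2 conv with padding 1), which is why the branches line up even for odd H.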


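import torch.nn as nn


class PreNorm(nn.Module):
    # PreNorm is used by MBConv below but was not shown in this comment; this
    # minimal version is an assumption: apply the norm layer, then the wrapped module.
    def __init__(self, norm, model, dimension):
        super().__init__()
        self.norm = norm(dimension)
        self.model = model

    def forward(self, x):
        return self.model(self.norm(x))


class SqueezeAndExcite(nn.Module):
    def __init__(self, in_channels, expansion=0.25):  # keep the reduction ratio fixed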
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, int(in_channels * expansion)),
            nn.GELU(),
            nn.Linear(int(in_channels * expansion), in_channels),
            nn.Sigmoid()
        )

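    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avgpool(x).view(b, c)  # squeeze: global average pool to (b, c)
        y = self.fc(y).view(b, c, 1, 1)  # excite: per-channel gates in (0, 1)
        return x * y.expand_as(x)  # rescale each channel of x by its gate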

class MBConv(nn.Module):
    def __init__(self, inp, oup, expansion, downsample):
        super().__init__()
        self.downsample = downsample
        stride = 1 if not downsample else 2
        hidden_dim = int(expansion * inp)

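        # Shortcut: when downsampling, a 3x3 stride-2 max-pool halves the spatial
        # size (rounding up), matching the stride-2 conv inside self.conv.
        if self.downsample:
            self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # A 1x1 stride-1 projection maps inp -> oup channels on both shortcut paths.
        self.proj = nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)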

        if expansion == 1:
            self.conv = nn.Sequential(
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=stride, 
                          padding=1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(oup)
            )
        else:
            self.conv = nn.Sequential(
                nn.Conv2d(inp, hidden_dim, kernel_size=1, stride=stride, padding=0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, 
                          padding=1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.GELU(),
                SqueezeAndExcite(hidden_dim, expansion=0.25),
                nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(oup)
            )

        self.conv = PreNorm(norm=nn.BatchNorm2d, model=self.conv, dimension=inp)

    def forward(self, x):
        # Residual sum: shortcut (pooled if downsampling, then projected) + conv branch.
        if self.downsample:
            return self.proj(self.pool(x)) + self.conv(x)
        else:
            return self.proj(x) + self.conv(x)

Uljibuh commented 7 months ago

Hi! Just curious: did you run into this issue? How did you solve it?

https://github.com/chinhsuanwu/coatnet-pytorch/issues/20

swarajnanda2021 commented 7 months ago

I was implementing CoAtNet myself and looked to this repo for inspiration. It did not work, so while debugging I had to reread the paper several times. Eventually I understood the problems and found a fix. Of course, GPT-4 helped a lot here.

Uljibuh commented 7 months ago

How were the model's training results? Did you use downsampling? Which gives better results, with or without downsampling?