raoyongming / GFNet

[NeurIPS 2021] [T-PAMI] Global Filter Networks for Image Classification
https://gfnet.ivg-research.xyz/
MIT License

image size for ADE20K #4

Closed: ShoufaChen closed this issue 3 years ago

ShoufaChen commented 3 years ago

Hi, Yongming

What image size did you use for training and validation on ADE20K?

I noticed that PVT uses 512x512 for training and a different scale for testing. However, since the parameters of the Global Filter depend on the image size, how do you handle the scale change?

Thanks in advance.

raoyongming commented 3 years ago

We use 512x512 images during training. During inference, we directly upsample the filter so that it fits larger images. The results in Figure 6 also show that our model adapts to different image sizes better than transformers. The following is our implementation for the image segmentation experiments.


import torch
import torch.nn as nn
import torch.nn.functional as F


class GlobalFilter(nn.Module):
    def __init__(self, dim, h=14, w=8):
        super().__init__()
        # Learnable filter stored as a real tensor of shape (h, w, dim, 2);
        # w is the one-sided spectral width produced by rfft2, i.e. W // 2 + 1.
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)
        self.w = w
        self.h = h
        self.dim = dim

    def forward(self, x):
        # x: (B, H, W, C) spatial feature map
        B, a, b, C = x.shape
        x = x.to(torch.float32)

        # 2D real FFT over the spatial dimensions -> (B, a, b // 2 + 1, C)
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')

        weight = self.complex_weight
        if not self.training:
            # At test time the input resolution may differ from the training
            # resolution, so bicubically upsample the filter to the spectral
            # size of the current input.
            _, H, W, _ = x.shape
            Hw, Ww = self.complex_weight.shape[:2]
            if Hw != H or Ww != W:
                weight = weight.view(1, Hw, Ww, 2 * self.dim).permute(0, 3, 1, 2)
                weight = F.interpolate(
                    weight, size=(H, W), mode='bicubic', align_corners=True).permute(0, 2, 3, 1).reshape(H, W, self.dim, 2).contiguous()
        # Interpret the trailing dimension of size 2 as (real, imag)
        weight = torch.view_as_complex(weight)
        x = x * weight
        # Inverse real FFT back to the spatial domain at the original size
        x = torch.fft.irfft2(x, s=(a, b), dim=(1, 2), norm='ortho')

        return x
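
For reference, here is a minimal usage sketch (not from the repository; the channel dimension and feature-map sizes below are made up for illustration) showing that the same module handles a larger resolution in eval mode via the filter-upsampling branch:

# Hypothetical sizes: the filter is created for 14x14 feature maps
# (w = 14 // 2 + 1 = 8) and then applied in eval mode to 32x32 inputs,
# which triggers the interpolation path above.
gf = GlobalFilter(dim=64, h=14, w=8)

x_train = torch.randn(2, 14, 14, 64)   # (B, H, W, C) at the training resolution
print(gf(x_train).shape)               # torch.Size([2, 14, 14, 64])

gf.eval()
x_test = torch.randn(2, 32, 32, 64)    # larger feature map at test time
print(gf(x_test).shape)                # filter upsampled to 32 x 17 -> torch.Size([2, 32, 32, 64])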
ShoufaChen commented 3 years ago

Thanks for your reply.