raoyongming / GFNet

[NeurIPS 2021] [T-PAMI] Global Filter Networks for Image Classification
https://gfnet.ivg-research.xyz/
MIT License
445 stars, 41 forks

Flexible input size #16

Closed: repers closed this issue 2 years ago

repers commented 2 years ago

Hi, I came across your work and thought it was a very interesting concept. Currently, the network only accepts a fixed input size. Is there a way to support flexible input sizes? I realize the main constraint is the following line, where the complex weight is defined at initialization time:

self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)

Is there a way to modify this line so that the network can accept inputs of different sizes? Thanks!
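A quick way to see why the weight is tied to one input size: `torch.fft.rfft2` on an H x W map returns an H x (W // 2 + 1) frequency grid, and the filter must match that grid element by element. The sketch below (illustrative sizes only) prints the grid for a few inputs. The default h=14, w=8 corresponds to a 14 x 14 token grid, for example a 224 x 224 image split into 16 x 16 patches (an assumption about the classification setup).

```python
import torch

# rfft2 over the last two axes keeps the full height axis but only W // 2 + 1
# columns along the width axis (the rest are redundant for real inputs).
for h, w in [(14, 14), (32, 48), (17, 25)]:
    x = torch.randn(1, 8, h, w)                       # (B, C, H, W) real feature map
    X = torch.fft.rfft2(x, dim=(2, 3), norm="ortho")  # complex spectrum
    print((h, w), "->", tuple(X.shape[2:]))
# (14, 14) -> (14, 8)
# (32, 48) -> (32, 25)
# (17, 25) -> (17, 13)
```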

raoyongming commented 2 years ago

Hi, thanks for your interest in our work. For tasks that require flexible input sizes, such as object detection and semantic segmentation, we adaptively resize the complex weights in the forward pass. We find that simply interpolating the complex weights (bilinearly) works well on these tasks. The GF layer can be modified as follows:

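import torch
import torch.nn as nn
import torch.nn.functional as F

class GlobalFilter(nn.Module):
    def __init__(self, dim, h=14, w=8):
        super().__init__()
        # One complex weight (stored as a real/imag pair) per channel and frequency.
        # The first axis must equal the channel count C of the input; dim // 2 would
        # only broadcast if the layer receives dim // 2 channels.
        self.complex_weight = nn.Parameter(torch.randn(dim, h, w, 2, dtype=torch.float32) * 0.02)

    def forward(self, x):
        x = x.to(torch.float32)
        B, C, a, b = x.shape                                   # channels-first feature map
        x = torch.fft.rfft2(x, dim=(2, 3), norm='ortho')      # (B, C, a, b // 2 + 1)

        weight = self.complex_weight
        if not weight.shape[1:3] == x.shape[2:4]:
            # Resize the filter to the current frequency grid: move the real/imag axis
            # to the batch position so that (h, w) is interpolated bilinearly.
            weight = F.interpolate(weight.permute(3,0,1,2), size=x.shape[2:4], mode='bilinear', align_corners=True).permute(1,2,3,0)

        weight = torch.view_as_complex(weight.contiguous())   # (C, a, b // 2 + 1)

        x = x * weight
        x = torch.fft.irfft2(x, s=(a, b), dim=(2, 3), norm='ortho')
        return x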

Here, complex_weight is still initialized with a fixed size, as in the classification task.
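To check that one fixed-size weight really serves several input sizes, here is a minimal functional sketch of the same forward pass. The name `global_filter` is hypothetical, and it assumes the weight is stored as (C, h, w, 2) with channels-first inputs, as in the snippet above.

```python
import torch
import torch.nn.functional as F

def global_filter(x, complex_weight):
    """Hypothetical functional GF layer.

    x: real tensor (B, C, H, W); complex_weight: real tensor (C, h, w, 2).
    """
    B, C, H, W = x.shape
    X = torch.fft.rfft2(x.float(), dim=(2, 3), norm="ortho")  # (B, C, H, W // 2 + 1)
    weight = complex_weight
    if weight.shape[1:3] != X.shape[2:4]:
        # (C, h, w, 2) -> (2, C, h, w): interpolate (h, w), then move real/imag back last.
        weight = F.interpolate(weight.permute(3, 0, 1, 2), size=X.shape[2:4],
                               mode="bilinear", align_corners=True).permute(1, 2, 3, 0)
    weight = torch.view_as_complex(weight.contiguous())       # (C, H, W // 2 + 1)
    return torch.fft.irfft2(X * weight, s=(H, W), dim=(2, 3), norm="ortho")

C = 16
complex_weight = torch.randn(C, 14, 8, 2) * 0.02  # fixed size, as for classification
for H, W in [(14, 14), (20, 30), (33, 17)]:
    y = global_filter(torch.randn(2, C, H, W), complex_weight)
    print((H, W), tuple(y.shape))
```

Because `irfft2` is given `s=(H, W)`, the output always has the input's spatial size, including odd widths. A weight of all ones (real part 1, imaginary part 0) interpolates to all ones, so the layer then reduces to an orthonormal FFT round trip and returns its input.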

repers commented 2 years ago

Thanks a lot for this, I will try this out and see how it works!

repers commented 2 years ago

Hi, I took a network that previously contained ViT blocks, replaced them with GFNet layers (adding the snippet above to support flexible input sizes), and trained it. Performance dropped significantly, by around 0.7 dB (on an image-based task), and training was about 2x slower. My understanding was that Fourier operations are faster than standard attention, but that doesn't seem to be the case for me. Am I missing something? Do GFNet-based models require a different training procedure than ViT to match its performance?

raoyongming commented 2 years ago

In my experience, it helps to add normalization layers before and after the GF layers, which may stabilize training and improve performance. Also, if you replace the whole self-attention block (qkv projections, matmul, softmax and output projection) with a single GF layer, the model's complexity (including its parameter count) will be greatly reduced, which can lead to lower performance.
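A rough way to see the drop in capacity is to count parameters. The sketch below compares a GF layer's weight (C * h * w * 2 real numbers) with a standard `torch.nn.MultiheadAttention` layer (qkv and output projections, 4C^2 + 4C parameters with biases). The values C = 320 and a 14 x 14 grid are illustrative assumptions, and the count ignores extras such as PVT-v2's spatial-reduction layers.

```python
import torch.nn as nn

def gf_params(C, h, w):
    # One complex weight (real + imag) per channel and frequency bin.
    return C * h * w * 2

def mhsa_params(C, num_heads=8):
    # Standard multi-head self-attention: in_proj (qkv) + out_proj, with biases.
    attn = nn.MultiheadAttention(embed_dim=C, num_heads=num_heads)
    return sum(p.numel() for p in attn.parameters())

C, h, w = 320, 14, 8  # illustrative: 320 channels on a 14x14 token grid
print("GF layer:", gf_params(C, h, w))    # 71680
print("Self-attention:", mhsa_params(C))  # 410880
```

The GF weight grows linearly in C while the attention projections grow quadratically, so the gap widens for wider stages.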

It is a bit strange that the GF layers are much slower than self-attention. Did you directly replace the standard self-attention with our GF layer? As I understand it, low-level vision tasks usually need to process high-resolution feature maps, which makes directly applying self-attention to such features very difficult.
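The expected speed advantage comes from asymptotics: global self-attention over N = H * W tokens costs on the order of N^2 * C for the score and value products, while an FFT-based filter costs on the order of C * N * log2(N). The sketch below is an order-of-magnitude estimate under those assumptions (constant factors and projections omitted); it is not a runtime benchmark.

```python
import math

def attention_cost(H, W, C):
    # Multiply-adds for Q @ K^T and softmax(...) @ V over N = H * W tokens.
    N = H * W
    return 2 * N * N * C

def global_filter_cost(H, W, C):
    # Rough estimate: forward + inverse FFT ~ C * N * log2(N) each,
    # plus one complex multiply per frequency bin. Constant factors omitted.
    N = H * W
    return 2 * C * N * math.log2(N) + C * H * (W // 2 + 1)

for H, W in [(14, 14), (64, 64), (128, 128)]:
    ratio = attention_cost(H, W, 64) / global_filter_cost(H, W, 64)
    print(f"{H}x{W}: attention / GF ~ {ratio:.0f}x")
```

Wall-clock time also depends on kernel efficiency and numeric precision (the snippets in this thread cast to float32 before the FFT), so a layer with fewer operations is not guaranteed to train faster in practice.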

repers commented 2 years ago

Hi, thanks for your response. I thought I had replied to this, but it seems I forgot to press the comment button. I did indeed replace the standard self-attention with the GF layer. I used the PVT-v2 code and replaced its attention block with the following code:

class GlobalFilter(nn.Module):
    def __init__(self, dim, h=14, w=8):
        super().__init__()
        self.complex_weight = nn.Parameter(torch.rand(h, w, dim, 2, dtype=torch.float32) * 0.02)
        self.w = h
        self.h = w

    def forward(self, x, h,w): #(H,W, images dimensions used for other operations)
        B, N, C = x.shape
        a = h
        b = w
        x = x.view(B, C, a, b)
        x = x.to(torch.float32)
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        weight = self.complex_weight
        if not weight.shape[1:3] == x.shape[2:4]:
            weight = F.interpolate(weight.permute(3,0,1,2), size=x.shape[2:4], mode='bilinear', align_corners=True).permute(1,2,3,0)
        weight = torch.view_as_complex(weight.contiguous())
        x = x * weight
        x = torch.fft.irfft2(x, s=(a, b), dim=(2, 3), norm='ortho')
        x = x.reshape(B, N, C)
        return x

The configuration was the same as PVT in terms of depth, MLP heads, etc.; only the attention was changed. A norm layer was used (as in the default PVT configuration), applied as follows:

        x = x + self.drop_path(self.filter(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))

Should I add another norm layer as such:

        x = x + self.norm3(self.drop_path(self.filter(self.norm1(x), H, W)))

I'm not sure what else can be done: the approach is conceptually promising, but my implementation has the issues discussed above. Thanks again for your help!

raoyongming commented 2 years ago

PVT-v2 does not use the standard attention layer; its modified attention block performs self-attention on downsampled feature maps. Since a single GF layer has far fewer parameters than a self-attention layer, performance may drop. In my experience, it also seems necessary to use normalization layers before and after the GF layers.
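A minimal sketch of a block with normalization both before and after the token mixer is below. `SandwichNormBlock` and `IdentityMixer` are hypothetical names; the mixer is a placeholder to be swapped for a GF layer taking `(x, H, W)`, `drop_path` is an identity stand-in for stochastic depth, and the MLP is a plain two-layer MLP rather than PVT-v2's convolutional MLP.

```python
import torch
import torch.nn as nn

class SandwichNormBlock(nn.Module):
    """Hypothetical block: LayerNorm before *and* after the token mixer."""
    def __init__(self, dim, token_mixer, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)   # pre-norm, as in the PVT block
        self.filter = token_mixer        # e.g. a GF layer called as filter(x, H, W)
        self.norm3 = nn.LayerNorm(dim)   # extra post-norm on the mixer output
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio), nn.GELU(), nn.Linear(dim * mlp_ratio, dim)
        )
        self.drop_path = nn.Identity()   # stand-in for stochastic depth

    def forward(self, x, H, W):          # x: (B, N, C) with N == H * W
        # Post-norm sits inside drop_path so a dropped branch contributes exactly zero.
        x = x + self.drop_path(self.norm3(self.filter(self.norm1(x), H, W)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

class IdentityMixer(nn.Module):
    """Placeholder token mixer; replace with a real GF layer."""
    def forward(self, x, H, W):
        return x

block = SandwichNormBlock(64, IdentityMixer())
y = block(torch.randn(2, 20 * 30, 64), 20, 30)
print(y.shape)  # torch.Size([2, 600, 64])
```

Placing the post-norm inside `drop_path`, rather than wrapping `drop_path` as in `norm3(drop_path(...))`, is a design choice: with per-sample stochastic depth, normalizing a zeroed branch would still add the LayerNorm bias to the residual stream.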

To achieve better performance, I would recommend using our new block in HorNet. You can replace the whole PVT-v2 block (including SA and FFN) with the HorNet block.

repers commented 2 years ago

Thanks for your response. Does HorNet work only on fixed-size inputs, or does it accept inputs of any size like PVT?

raoyongming commented 2 years ago

You can refer to our object detection models, which are designed to handle arbitrary input sizes.

MiaoJieF commented 1 year ago

(Quoting repers' comment above, including the modified GlobalFilter code and the PVT block definition.)

Is there a problem with your interpolation code? Since the initial dimension order of 'self.complex_weight' has changed, it should also be changed during interpolation.
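Following up on that: besides the weight order, the snippet also mixes layouts. `x.view(B, C, a, b)` on `(B, N, C)` tokens scrambles channels and positions, `rfft2` runs over `dim=(1, 2)` while `irfft2` runs over `dim=(2, 3)`, and the final `reshape(B, N, C)` assumes a channels-last tensor. One way to make it consistent, sketched below under the assumption that tokens are stored row-major, is to keep a channels-last layout throughout so the `(h, w, dim, 2)` weight can stay as it is, and adjust only the permutes around `F.interpolate`. This is an illustrative sketch, not the GFNet authors' code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GlobalFilter(nn.Module):
    """Sketch of a GF layer for (B, N, C) tokens with a channels-last (h, w, dim, 2) weight."""
    def __init__(self, dim, h=14, w=8):
        super().__init__()
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)

    def forward(self, x, H, W):
        B, N, C = x.shape
        # Row-major tokens: (B, N, C) -> (B, H, W, C) keeps each token's channels together.
        x = x.reshape(B, H, W, C).to(torch.float32)
        x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")        # (B, H, W // 2 + 1, C)
        weight = self.complex_weight                             # (h, w, C, 2)
        if weight.shape[0:2] != x.shape[1:3]:
            # (h, w, C, 2) -> (2, C, h, w): batch = real/imag, channels = C.
            weight = weight.permute(3, 2, 0, 1)
            weight = F.interpolate(weight, size=x.shape[1:3], mode="bilinear", align_corners=True)
            # (2, C, H, W // 2 + 1) -> (H, W // 2 + 1, C, 2)
            weight = weight.permute(2, 3, 1, 0)
        weight = torch.view_as_complex(weight.contiguous())      # (H, W // 2 + 1, C)
        x = x * weight
        x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm="ortho")  # (B, H, W, C)
        return x.reshape(B, N, C)

gf = GlobalFilter(dim=64)
for H, W in [(14, 14), (20, 30)]:
    y = gf(torch.randn(2, H * W, 64), H, W)
    print(tuple(y.shape))
```

With this layout the FFT and inverse FFT act on the same spatial axes, and the weight broadcasts over the batch axis against the `(B, H, W // 2 + 1, C)` spectrum.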