raoyongming / GFNet

[NeurIPS 2021] [T-PAMI] Global Filter Networks for Image Classification
https://gfnet.ivg-research.xyz/
MIT License

The dimension for Global Filter #31

Open wangxinchao-bit opened 1 year ago

wangxinchao-bit commented 1 year ago

import torch
import torch.nn as nn
import torch.fft

class GlobalFilter(nn.Module):
    def __init__(self, dim, h=14, w=8):
        super().__init__()
        # Learnable filter; the last axis holds the (real, imaginary) parts
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)

    def forward(self, x):
        B, H, W, C = x.shape
        # Real 2D FFT over the spatial axes: (B, H, W, C) -> (B, H, W // 2 + 1, C)
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        weight = torch.view_as_complex(self.complex_weight)
        x = x * weight
        # Inverse FFT restores the original spatial size (H, W)
        x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')
        return x

Can you give a simple test that shows how the dimensions change?

ShiZican commented 11 months ago

https://github.com/raoyongming/GFNet/issues/13#issuecomment-1100144292
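For a quick check of the shapes, here is a minimal sketch that runs the module from this issue on a random tensor and records the shape at each step. The helper name `trace_shapes` and the sizes (batch 2, 14x14 feature map, 64 channels) are assumptions chosen for illustration, not values from the repository. The point it shows is that `rfft2` keeps only the non-redundant half of the last transformed axis, so a width of 14 becomes 14 // 2 + 1 = 8 in the frequency domain, which is why the default filter is `h=14, w=8`.

```python
import torch
import torch.nn as nn
import torch.fft


class GlobalFilter(nn.Module):
    """Copy of the module posted in this issue, with __init__ restored."""

    def __init__(self, dim, h=14, w=8):
        super().__init__()
        self.complex_weight = nn.Parameter(
            torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02
        )

    def forward(self, x):
        B, H, W, C = x.shape
        x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")
        weight = torch.view_as_complex(self.complex_weight)
        x = x * weight
        x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm="ortho")
        return x


def trace_shapes(batch=2, size=14, dim=64):
    """Hypothetical helper: record the tensor shape at each step."""
    # The filter width must match the rfft2 output width: size // 2 + 1
    layer = GlobalFilter(dim, h=size, w=size // 2 + 1)
    x = torch.randn(batch, size, size, dim)
    freq = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")
    weight = torch.view_as_complex(layer.complex_weight)
    out = layer(x)
    return {
        "input": tuple(x.shape),
        "freq": tuple(freq.shape),
        "weight": tuple(weight.shape),
        "output": tuple(out.shape),
    }


for name, shape in trace_shapes().items():
    print(name, shape)
# input (2, 14, 14, 64)
# freq (2, 14, 8, 64)
# weight (14, 8, 64)
# output (2, 14, 14, 64)
```

The complex weight of shape (14, 8, 64) broadcasts over the batch axis of the (2, 14, 8, 64) spectrum, and `irfft2` with `s=(H, W)` returns a real tensor with the input's shape. For a different feature-map size, `h` and `w` would need to be set to the map height and `width // 2 + 1` respectively.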