raoyongming / GFNet

[NeurIPS 2021] [T-PAMI] Global Filter Networks for Image Classification
https://gfnet.ivg-research.xyz/
MIT License

parameters #13

Closed 123456789-qwer closed 2 years ago

123456789-qwer commented 2 years ago

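import torch
import torch.nn as nn

class GlobalFilter(nn.Module):
    def __init__(self, dim, h=14, w=8):
        super().__init__()
        # Learnable frequency-domain filter; the trailing 2 holds (real, imag).
        self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)

    def forward(self, x):
        B, H, W, C = x.shape
        # 2D real FFT over the spatial axes (H, W).
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        weight = torch.view_as_complex(self.complex_weight)
        x = x * weight
        # Inverse transform back to the original spatial size.
        x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')
        return x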

Thank you very much for your work. I have some questions: what is the meaning of "h=14, w=8" and of "s=(H, W), dim=(1, 2)"?

raoyongming commented 2 years ago

Hi, thanks for your interest in our paper.

For a real signal of shape HxW, the rFFT has only Hx(W//2+1) independent components, because the spectrum of a real signal is conjugate-symmetric. In torch.fft.rfft2(x, dim=(1, 2), norm='ortho'), dim=(1, 2) selects the two spatial axes (H and W) of the BxHxWxC input, so a real Bx14x14xC feature map becomes a Bx14x8xC complex tensor (14//2+1 = 8). That is why the filter has shape 14x8xdimx2, i.e. h=14 and w=8, with the trailing 2 holding the real and imaginary parts that torch.view_as_complex combines into a single complex weight. In the other direction, a Bx14x8xC complex feature is consistent with either a Bx14x14xC or a Bx14x15xC real output (14//2+1 = 15//2+1 = 8), so the original width cannot be recovered from the input shape alone. We therefore pass the output size explicitly: s=(H, W)=(14, 14).
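
The shape bookkeeping described above can be checked with a small script. This is only a sketch: the batch size 2, channel count 3, and all variable names are made up for illustration and are not taken from the GFNet code.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes for illustration: B=2, H=W=14, C=3.
x = torch.randn(2, 14, 14, 3)

# Real FFT over the spatial axes; the last transformed axis shrinks to W//2 + 1.
freq = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
freq_shape = tuple(freq.shape)

# The kept half matches the first W//2 + 1 columns of the full complex FFT.
full = torch.fft.fft2(x, dim=(1, 2), norm='ortho')
half_matches_full = torch.allclose(full[:, :, :8], freq, atol=1e-4)

# A filter stored as (h, w, dim, 2) becomes a complex (h, w, dim) tensor
# that broadcasts over the batch axis.
weight = torch.view_as_complex(torch.randn(14, 8, 3, 2) * 0.02)
filtered_shape = tuple((freq * weight).shape)

# The same half-spectrum can be inverted to width 14 or width 15,
# which is why the output size s is passed explicitly.
even = torch.fft.irfft2(freq, s=(14, 14), dim=(1, 2), norm='ortho')
odd = torch.fft.irfft2(freq, s=(14, 15), dim=(1, 2), norm='ortho')
even_shape = tuple(even.shape)
odd_shape = tuple(odd.shape)

# With the correct s, the round trip recovers the input up to float error.
round_trip_ok = torch.allclose(even, x, atol=1e-4)

print(freq_shape, filtered_shape, even_shape, odd_shape)
```

If s is omitted, irfft2 assumes an even output length of 2*(n-1) along the last transformed axis, which happens to be 14 here. Passing s=(H, W) keeps the module correct for odd widths as well.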