Open wangxinchao-bit opened 1 year ago
import torch import torch.nn as nn import torch.fft
class GlobalFilter(nn.Module):
    """Learnable element-wise filter applied in the 2-D frequency domain.

    The input is transformed with a real 2-D FFT over dimensions 1 and 2,
    multiplied by a learnable complex weight, and transformed back to the
    original spatial size.

    Shape walk-through for an input ``x`` of shape ``(B, H, W, C)``:

    * ``rfft2`` over dims ``(1, 2)``  -> ``(B, H, W // 2 + 1, C)`` (complex)
    * multiply by weight ``(h, w, dim)`` (broadcast over the batch dim)
    * ``irfft2`` with ``s=(H, W)``    -> ``(B, H, W, C)`` (real)

    For the multiplication to broadcast, the constructor arguments must
    satisfy ``h == H``, ``w == W // 2 + 1`` and ``dim == C``. The defaults
    ``h=14, w=8`` therefore match a 14 x 14 spatial grid.

    Args:
        dim: Size of the last (channel) dimension of the input.
        h: Size of input dimension 1.
        w: Size of dimension 2 after the real FFT, i.e. ``W // 2 + 1``.
    """

    def __init__(self, dim, h=14, w=8):
        # The dunder names are required: a method called plain ``init`` is
        # never invoked on construction, so nn.Module would stay
        # uninitialised and the parameter would never be registered.
        super().__init__()
        # The trailing dimension of size 2 holds (real, imaginary) parts so
        # the parameter is stored as float32 and viewed as complex in
        # ``forward``. Values are drawn from a normal distribution scaled
        # by 0.02.
        self.complex_weight = nn.Parameter(
            torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02
        )

    def forward(self, x):
        """Filter ``x`` in the frequency domain.

        Args:
            x: Real tensor of shape ``(B, H, W, C)``.

        Returns:
            Real tensor with the same shape as ``x``.
        """
        # Unpacking also enforces that the input is 4-D.
        _, H, W, _ = x.shape
        x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
        weight = torch.view_as_complex(self.complex_weight)
        x = x * weight
        # ``s=(H, W)`` restores the original spatial size; without it an
        # odd ``W`` could not be recovered from the half-spectrum.
        x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')
        return x
Can you provide a simple test that shows how the tensor dimensions change?
https://github.com/raoyongming/GFNet/issues/13#issuecomment-1100144292
import torch import torch.nn as nn import torch.fft
class GlobalFilter(nn.Module):
    """Module holding a learnable weight of shape ``(h, w, dim, 2)``.

    Only the constructor is present in this snippet. The trailing dimension
    of size 2 presumably stores (real, imaginary) pairs, as the attribute
    name ``complex_weight`` suggests.

    Args:
        dim: Size of the third dimension of the weight.
        h: Size of the first dimension of the weight.
        w: Size of the second dimension of the weight.
    """

    def __init__(self, dim, h=14, w=8):
        # The dunder names are required: a method called plain ``init`` is
        # never invoked on construction, so nn.Module would stay
        # uninitialised and the parameter would never be registered.
        super().__init__()
        # Values are drawn from a normal distribution scaled by 0.02.
        self.complex_weight = nn.Parameter(
            torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02
        )
Can you provide a simple test that shows how the tensor dimensions change?