Hi, thanks for your interest in our paper.
The rFFT of a real signal of shape HxW has only Hx(W//2+1) independent components, because the spectrum of a real signal is conjugate-symmetric. So for a real input feature of shape Bx14x14xC, torch.fft.rfft2(x, dim=(1, 2), norm='ortho') yields a Bx14x8xC complex tensor: dim=(1, 2) selects the two spatial axes to transform, and 14//2+1 = 8. That is why the filter has shape 14x8xdimx2 (h=14, w=8, and the trailing 2 holds the real and imaginary parts of each complex weight). In the other direction, a Bx14x8xC complex feature is consistent with either a Bx14x14xC or a Bx14x15xC real output from irfft, since 14//2+1 = 15//2+1 = 8. We therefore specify the output shape explicitly as s=(H, W)=(14, 14).
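The shape bookkeeping described above can be checked with a short script. This is only a sketch: the batch size, channel count, and variable names are made up for illustration, and it assumes a PyTorch version that provides the torch.fft module.

```python
import torch

# Hypothetical sizes chosen for illustration: B=2 samples, C=3 channels.
B, H, W, C = 2, 14, 14, 3
x = torch.randn(B, H, W, C)

# Real 2D FFT over the two spatial axes (dims 1 and 2).
# The last transformed axis keeps only W // 2 + 1 = 8 frequency bins.
freq = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
freq_shape = tuple(freq.shape)

# The same 8 bins are compatible with a width of 14 or 15,
# so s tells irfft2 which real output size to produce.
even = torch.fft.irfft2(freq, s=(14, 14), dim=(1, 2), norm='ortho')
odd = torch.fft.irfft2(freq, s=(14, 15), dim=(1, 2), norm='ortho')

# With the matching size, the round trip recovers the input up to float error.
roundtrip_ok = torch.allclose(even, x, atol=1e-4)

print(freq_shape, tuple(even.shape), tuple(odd.shape), roundtrip_ok)
```

Only s=(14, 14) reproduces the original feature; the 14x15 output is a valid real signal too, just not the one that was transformed.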
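    import torch
    import torch.nn as nn

    class GlobalFilter(nn.Module):
        def __init__(self, dim, h=14, w=8):
            super().__init__()
            # Learnable frequency-domain filter; the last axis (size 2) stores real and imaginary parts.
            self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)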
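To see why the weight needs the 14x8 shape, here is a rough sketch of how such a filter could be applied in a forward pass. The class name GlobalFilterSketch and the forward method are assumptions added for illustration; they are not quoted from this thread and may differ from the released code.

```python
import torch
import torch.nn as nn

class GlobalFilterSketch(nn.Module):
    """Illustrative only: element-wise filtering in the frequency domain."""

    def __init__(self, dim, h=14, w=8):
        super().__init__()
        # h x w matches the rfft2 output of a 14x14 feature map (w = 14 // 2 + 1).
        self.complex_weight = nn.Parameter(
            torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02
        )

    def forward(self, x):
        # x: real tensor of shape (B, H, W, C)
        B, H, W, C = x.shape
        freq = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')   # (B, H, W//2+1, C), complex
        weight = torch.view_as_complex(self.complex_weight)   # (h, w, dim), complex
        freq = freq * weight                                   # broadcasts over the batch axis
        # s=(H, W) pins the output to the input's spatial size.
        return torch.fft.irfft2(freq, s=(H, W), dim=(1, 2), norm='ortho')

# Hypothetical usage with B=2 and dim=4.
layer = GlobalFilterSketch(dim=4)
inp = torch.randn(2, 14, 14, 4)
out = layer(inp)
```

The multiplication only broadcasts if the weight's leading axes equal the spectrum's 14x8 spatial shape, which is where h=14 and w=8 come from.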
Thank you very much for your work. I have some questions: what do "h=14, w=8" and "s=(H, W), dim=(1, 2)" mean?