kkkls / FFTformer

[CVPR 2023] Efficient Frequency Domain-based Transformer for High-Quality Image Deblurring
MIT License

Question about the FSAS module #22

Open lijun2005 opened 10 months ago

lijun2005 commented 10 months ago
import torch
import torch.nn as nn
from einops import rearrange

# Note: LayerNorm below is the repo's custom LayerNorm (the 'WithBias' variant
# defined elsewhere in the same file), not torch.nn.LayerNorm.

class FSAS(nn.Module):
    def __init__(self, dim, bias):
        super(FSAS, self).__init__()

        # 1x1 conv expands channels to 6*dim, then a 3x3 depthwise conv mixes them locally
        self.to_hidden = nn.Conv2d(dim, dim * 6, kernel_size=1, bias=bias)
        self.to_hidden_dw = nn.Conv2d(dim * 6, dim * 6, kernel_size=3, stride=1, padding=1, groups=dim * 6, bias=bias)

        self.project_out = nn.Conv2d(dim * 2, dim, kernel_size=1, bias=bias)

        self.norm = LayerNorm(dim * 2, LayerNorm_type='WithBias')

        self.patch_size = 8

    def forward(self, x):
        hidden = self.to_hidden(x)

        # q, k, v each get 2*dim channels
        q, k, v = self.to_hidden_dw(hidden).chunk(3, dim=1)

        # split the spatial dims into non-overlapping 8x8 patches
        q_patch = rearrange(q, 'b c (h patch1) (w patch2) -> b c h w patch1 patch2', patch1=self.patch_size,
                            patch2=self.patch_size)
        k_patch = rearrange(k, 'b c (h patch1) (w patch2) -> b c h w patch1 patch2', patch1=self.patch_size,
                            patch2=self.patch_size)
        q_fft = torch.fft.rfft2(q_patch.float())
        k_fft = torch.fft.rfft2(k_patch.float())

        # elementwise product in the frequency domain, then back to the spatial domain
        out = q_fft * k_fft
        out = torch.fft.irfft2(out, s=(self.patch_size, self.patch_size))
        out = rearrange(out, 'b c h w patch1 patch2 -> b c (h patch1) (w patch2)', patch1=self.patch_size,
                        patch2=self.patch_size)

        out = self.norm(out)

        # v is modulated elementwise by the normalized q-k interaction map
        output = v * out
        output = self.project_out(output)

        return output
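For reference, a minimal usage sketch (my own, not from the repo), assuming the class above and the repo's LayerNorm are in scope; note that H and W must be divisible by patch_size = 8:

import torch

x = torch.randn(1, 48, 64, 64)   # B, C, H, W with H and W divisible by 8
fsas = FSAS(dim=48, bias=False)
out = fsas(x)
print(out.shape)                 # torch.Size([1, 48, 64, 64])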

After reading the FSAS code, I find that this does not really count as self-attention. The output is computed as

output = v * out

which looks more like a gating mechanism in the spatial domain, doesn't it?
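To make the point concrete: by the convolution theorem, the frequency-domain elementwise product that FSAS computes corresponds to a circular convolution of each q patch with the matching k patch in the spatial domain, and that result then multiplicatively gates v. A standalone numerical check of this equivalence (my own sketch, not repo code):

import torch

torch.manual_seed(0)
P = 8                                   # same as FSAS patch_size
a = torch.randn(P, P)
b = torch.randn(P, P)

# FFT route, as in FSAS: elementwise product in the frequency domain
fft_result = torch.fft.irfft2(torch.fft.rfft2(a) * torch.fft.rfft2(b), s=(P, P))

# Direct circular (periodic) convolution for comparison
direct = torch.zeros(P, P)
for i in range(P):
    for j in range(P):
        for m in range(P):
            for n in range(P):
                direct[i, j] += a[m, n] * b[(i - m) % P, (j - n) % P]

print(torch.allclose(fft_result, direct, atol=1e-3))   # True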

lijun2005 commented 10 months ago

Also, judging from the original paper, the FSAS module only replaces the spatial-domain matrix multiplication between Q and K. But softmax is also a core operation of self-attention, and neither the paper nor the code discusses the role of softmax in FSAS or provides an ablation comparing results with and without it.

kkkls commented 10 months ago

Hi, several works have already shown that good results can be achieved without softmax, for example "Transformer Quality in Linear Time". The core of self-attention is not softmax; softmax is just a nonlinear function and can be replaced with other functions.
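For illustration, a toy sketch (mine, not FFTformer code) contrasting standard softmax attention with a softmax-free variant that uses squared ReLU as the nonlinearity, in the spirit of the gated attention in "Transformer Quality in Linear Time":

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, N, D = 2, 16, 32                              # batch, tokens, channels
q = torch.randn(B, N, D)
k = torch.randn(B, N, D)
v = torch.randn(B, N, D)

scores = q @ k.transpose(-2, -1) / D ** 0.5      # (B, N, N)

# Standard softmax attention
out_softmax = F.softmax(scores, dim=-1) @ v

# Softmax-free attention: squared ReLU in place of softmax
# (dividing by the sequence length keeps the output scale comparable)
out_relu2 = (F.relu(scores) ** 2 / N) @ v

print(out_softmax.shape, out_relu2.shape)        # both (B, N, D)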