rebeccaeexu / RRID

[ECCV 2024] Image Demoireing in RAW and sRGB Domains

Does the FSM block work in the frequency domain? #3

Closed jeahun00 closed 1 month ago

jeahun00 commented 1 month ago

First of all, thank you for sharing your wonderful work. 🔥

While reviewing the code, I came up with a question about the FSM block. As shown in the figure below, the features in the FSM block pass through a DCT.

[image: FSM block diagram]

I understand that the conv_dct_0 module corresponds to the DCT block in the figure. However, it does not seem to include an FFT or any cosine operation, which is what a DCT (a transform into the frequency domain) is typically built from. Is it possible to implement the DCT using only convolution operations?

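import torch
import torch.nn as nn
import torch.nn.functional as F

# dct_weight and directional_inverse_dct_layer are not shown in this excerpt.
class FSM(nn.Module):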
    def __init__(self, num_feat=40, block_size=8):
        super(FSM, self).__init__()
        self.conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_dct_0 = directional_dct_layer(in_c=num_feat, out_c=num_feat)

        self.dct_weight = dct_weight(in_c=num_feat, out_c=num_feat)
        self.in_c = num_feat
        self.bs = block_size

        self.after_rdct = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.act = nn.GELU()

    def forward(self, x):
        _, _, h, w = x.size()
        dct_feat = self.act(self.conv(x))
        # mode 0
        dct_feat_0 = self.conv_dct_0(dct_feat)
        out_0 = directional_inverse_dct_layer(dct_feat_0, bs=self.bs, mode=0)
        out_0 = F.interpolate(out_0, size=(h, w), mode='bilinear', align_corners=False)

        dct_weight = self.dct_weight(x)
        out = torch.mul(out_0, dct_weight)

        out = self.after_rdct(out)
        return out

class directional_dct_layer(nn.Module):
    def __init__(self, in_c=40, out_c=40):
        super(directional_dct_layer, self).__init__()
        self.ddct_conv = nn.Conv2d(in_c, in_c, 3, 1, 1)
        self.act = nn.GELU()
        self.conv_dct = nn.Conv2d(in_c, out_c, kernel_size=3, stride=8, padding=2, dilation=2, groups=in_c)

    def forward(self, x):
        out = self.act(self.ddct_conv(x)) + x
        out = self.conv_dct(out)
        return out

rebeccaeexu commented 1 month ago

An IDCT/IFFT transforms the frequency domain back to the spatial domain. Because we apply a fixed IDCT, the preceding convolutional layers learn, through training, to perform the spatial-to-frequency (s2f) conversion that matches the IDCT's frequency-to-spatial (f2s) operation. In our experiments, this proved to be an efficient and viable approach.
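
For reference, a blockwise DCT is a linear map, so it can be written exactly as a strided convolution with fixed cosine kernels; this is why a learned convolution can in principle take on the spatial-to-frequency role. The following is a minimal sketch, assuming an orthonormal 8x8 DCT-II on a single channel. The helper names dct_matrix, dct_conv_kernels and blockwise_dct_conv are hypothetical and are not part of the RRID code.

```python
import math
import torch
import torch.nn.functional as F

def dct_matrix(n: int = 8) -> torch.Tensor:
    """Orthonormal DCT-II matrix D of shape (n, n); row k is the k-th cosine basis."""
    k = torch.arange(n, dtype=torch.float64).unsqueeze(1)  # frequency index
    i = torch.arange(n, dtype=torch.float64).unsqueeze(0)  # spatial index
    d = torch.cos(math.pi * (2 * i + 1) * k / (2 * n)) * math.sqrt(2.0 / n)
    d[0] = d[0] / math.sqrt(2.0)  # the DC row carries an extra 1/sqrt(2) factor
    return d

def dct_conv_kernels(n: int = 8) -> torch.Tensor:
    """Fixed kernels of shape (n*n, 1, n, n); kernel u*n+v is outer(D[u], D[v])."""
    d = dct_matrix(n)
    return torch.einsum("ui,vj->uvij", d, d).reshape(n * n, 1, n, n)

def blockwise_dct_conv(x: torch.Tensor, n: int = 8) -> torch.Tensor:
    """Map x of shape (B, 1, H, W), H and W divisible by n, to (B, n*n, H/n, W/n)."""
    return F.conv2d(x, dct_conv_kernels(n).to(x.dtype), stride=n)

x = torch.randn(1, 1, 16, 16, dtype=torch.float64)
coeffs = blockwise_dct_conv(x)

# Reference: DCT of the top-left 8x8 block via the matrix form D @ X @ D^T.
d = dct_matrix(8)
reference = d @ x[0, 0, :8, :8] @ d.T
matches = torch.allclose(coeffs[0, :, 0, 0].reshape(8, 8), reference)
print(matches)
```

Here the kernels are fixed, whereas conv_dct in the quoted code is a learnable depthwise convolution with different hyperparameters (kernel_size=3, dilation=2), so this only illustrates the general equivalence, not the layer's trained weights.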

jeahun00 commented 1 month ago

Thank you for your response. I still have some doubts, so I would like to ask more questions.

It is possible to transform from the frequency domain to the spatial domain using IDCT/IFFT. So, as I understand it, training the network can induce the convolutional layers to learn the counterpart of the IDCT, that is, a mapping from the spatial domain to the frequency domain.

Is my understanding correct?

rebeccaeexu commented 1 month ago

Yes
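
As an illustration of what "symmetric to the IDCT" means, the sketch below pairs a fixed blockwise DCT (a strided convolution) with its inverse (a transposed convolution using the same kernels) and checks that the round trip recovers the input. It assumes an orthonormal 8x8 DCT-II on a single channel; dct_kernels is a hypothetical helper, and this is not the repository's directional_inverse_dct_layer, whose implementation is not shown in this thread.

```python
import math
import torch
import torch.nn.functional as F

def dct_kernels(n: int = 8) -> torch.Tensor:
    """Orthonormal 2D DCT-II basis as conv kernels of shape (n*n, 1, n, n)."""
    k = torch.arange(n, dtype=torch.float64).unsqueeze(1)  # frequency index
    i = torch.arange(n, dtype=torch.float64).unsqueeze(0)  # spatial index
    d = torch.cos(math.pi * (2 * i + 1) * k / (2 * n)) * math.sqrt(2.0 / n)
    d[0] = d[0] / math.sqrt(2.0)  # the DC row carries an extra 1/sqrt(2) factor
    return torch.einsum("ui,vj->uvij", d, d).reshape(n * n, 1, n, n)

kernels = dct_kernels(8)
x = torch.randn(2, 1, 32, 32, dtype=torch.float64)

# Spatial -> frequency: one output channel per (u, v) coefficient of each 8x8 block.
coeffs = F.conv2d(x, kernels, stride=8)

# Frequency -> spatial: the transposed convolution with the same kernels is the IDCT,
# because the basis is orthonormal.
restored = F.conv_transpose2d(coeffs, kernels, stride=8)

round_trip_ok = torch.allclose(restored, x)
print(coeffs.shape, round_trip_ok)
```

In the setup described in this thread, the frequency-to-spatial half stays fixed while the spatial-to-frequency half is learned, so the learned convolution need not equal the exact DCT shown here.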

jeahun00 commented 1 month ago

Thank you for your kind response! 👍