megvii-research / NAFNet

The state-of-the-art image restoration model without nonlinear activation functions.

Some questions about NAFBlock #125

Closed: cheng221 closed this issue 10 months ago

cheng221 commented 1 year ago

Hello! Thanks for your great work advancing image restoration. Recently, I ran some experiments on the skip connection in NAFBlock. In the code, the skip connection has the form y = inp + x * self.beta, where the learnable parameter self.beta (one value per channel, initialized to zero) scales the output of the convolutional branch. When I replace y = inp + x * self.beta with y = inp + x, performance drops by about 0.2 dB PSNR in my settings. Why does this change affect performance? Also, is there any literature or blog post that discusses this design? Thanks for taking the time to answer my question!

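import torch
import torch.nn as nn


class SimpleGate(nn.Module):
    # Split the channels into two halves and multiply them (halves the channel count)
    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2


class LayerNorm2d(nn.Module):
    # Layer norm over the channel dimension of an (N, C, H, W) tensor;
    # a plain-PyTorch sketch of the repo's LayerNorm2d helper
    def __init__(self, channels, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x):
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + self.eps).sqrt()
        return self.weight.view(1, -1, 1, 1) * y + self.bias.view(1, -1, 1, 1)


class NAFBlock(nn.Module):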
    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
        super().__init__()
        dw_channel = c * DW_Expand
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel,
                               bias=True)
        self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        # Simplified Channel Attention
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
                      groups=1, bias=True),
        )

        # SimpleGate
        self.sg = SimpleGate()

        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()

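        # Learnable per-channel scales on the two residual branches, initialized to zero,
        # so at initialization the block returns its input unchanged
        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)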

    def forward(self, inp):
        x = inp

        x = self.norm1(x)

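        # 1x1 expand -> 3x3 depthwise conv -> SimpleGate (halves channels) -> channel attention -> 1x1 projection
        x = self.conv1(x)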
        x = self.conv2(x)
        x = self.sg(x)
        x = x * self.sca(x)
        x = self.conv3(x)

        x = self.dropout1(x)

        # First residual connection: branch output scaled per channel by beta
        y = inp + x * self.beta

        # Feed-forward branch (1x1 expand -> SimpleGate -> 1x1 projection), later scaled by gamma
        x = self.conv4(self.norm2(y))
        x = self.sg(x)
        x = self.conv5(x)

        x = self.dropout2(x)

        return y + x * self.gamma
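The two skip variants compared in the question can be contrasted at initialization with a minimal sketch. It assumes a single 3x3 convolution as a stand-in for NAFBlock's branch; `ResidualBranch` and `learnable_scale` are hypothetical names for illustration, not part of the NAFNet code. With the zero-initialized scale the block is exactly the identity at the start of training, while the ablated form y = inp + x already adds the untrained branch output.

```python
import torch
import torch.nn as nn


class ResidualBranch(nn.Module):
    """Hypothetical stand-in for one half of NAFBlock: y = inp + f(inp) [* beta]."""

    def __init__(self, c, learnable_scale=True):
        super().__init__()
        self.f = nn.Conv2d(c, c, kernel_size=3, padding=1)  # placeholder branch
        # Zero-initialized per-channel scale, mirroring NAFBlock's self.beta
        self.beta = nn.Parameter(torch.zeros(1, c, 1, 1)) if learnable_scale else None

    def forward(self, inp):
        x = self.f(inp)
        if self.beta is not None:
            return inp + x * self.beta  # NAFBlock form
        return inp + x  # ablated form from the question


torch.manual_seed(0)
inp = torch.randn(2, 8, 16, 16)
scaled = ResidualBranch(8, learnable_scale=True)
plain = ResidualBranch(8, learnable_scale=False)

with torch.no_grad():
    print(torch.equal(scaled(inp), inp))  # True: identity at initialization
    print(torch.equal(plain(inp), inp))   # False: untrained branch output is added
    # How far the ablated block moves its input at initialization
    print((plain(inp) - inp).std().item())
```

This only shows the difference at initialization; it does not by itself explain the 0.2 dB gap, which also depends on the training setup.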

tseets42 commented 10 months ago

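This is called SkipInit, and it helps stabilize training (https://paperswithcode.com/method/skipinit). Because self.beta and self.gamma are initialized to zero, each residual branch starts switched off, so every NAFBlock begins as an identity mapping and learns during training how much of each branch to add.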
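A minimal sketch of the mechanism, assuming a single convolution as the residual branch, a made-up regression target, and plain SGD (none of this is NAFNet's actual training setup). Since the gradient reaching the branch is multiplied by beta, the branch weights get exactly zero gradient on the first step while beta itself gets a nonzero one; once beta has moved off zero, the branch starts learning too.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
c = 4
branch = nn.Conv2d(c, c, kernel_size=3, padding=1)  # stand-in residual branch
beta = nn.Parameter(torch.zeros(1, c, 1, 1))        # zero-initialized per-channel scale
opt = torch.optim.SGD(list(branch.parameters()) + [beta], lr=0.1)

x = torch.randn(2, c, 8, 8)
target = torch.randn(2, c, 8, 8)                    # hypothetical regression target


def step():
    opt.zero_grad()
    y = x + branch(x) * beta                        # y = inp + f(inp) * beta
    loss = (y - target).pow(2).mean()
    loss.backward()
    # Record gradient magnitudes before the parameter update
    grads = (branch.weight.grad.abs().sum().item(), beta.grad.abs().sum().item())
    opt.step()
    return grads


w_grad_1, beta_grad_1 = step()
w_grad_2, beta_grad_2 = step()
print(w_grad_1, beta_grad_1)  # branch weight gradient is exactly 0.0 on step 1
print(w_grad_2, beta_grad_2)  # both nonzero once beta has moved off zero
```

In other words, training starts from a well-behaved identity network and grows each branch's contribution gradually, which is the stabilizing effect the SkipInit answer refers to.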

cheng221 commented 10 months ago

Thank you kindly for your answer; it has been incredibly helpful. Have a good day!