Closed: Senwang98 closed this issue 3 years ago.
Hi @njulj, I use 6 RFDBs and replaced ESA with CCA, but the parameter count is 544386, which is lower than the figure reported in your paper. I am not sure whether there is something wrong in my code. BTW, I have already set nf = 48.
import torch
import torch.nn as nn

from model import block as B  # when run from inside model/, use: import block as B


def make_model(args, parent=False):
    """Build an RFDN with its default constructor arguments.

    ``args`` and ``parent`` are accepted to match the caller's ``make_model``
    convention but are not read here.
    """
    # NOTE(review): ``args`` is ignored, so any scale/width configured by the
    # caller does not reach the model — confirm this is intended.
    model = RFDN()
    return model


class RFDN(nn.Module):
    """Residual Feature Distillation Network.

    Pipeline: 3x3 feature conv -> ``num_modules`` chained RFDB blocks ->
    1x1 fusion of all block outputs -> 3x3 conv plus a long skip from the
    first features -> pixel-shuffle upsampler.

    Args:
        in_nc: channels of the input image.
        nf: feature width used throughout the trunk.
        num_modules: number of RFDB blocks; they are registered as
            ``B1`` ... ``B{num_modules}`` so checkpoint keys stay stable.
        out_nc: channels of the output image.
        upscale: upsampling factor handed to ``B.pixelshuffle_block``.
    """

    def __init__(self, in_nc=3, nf=48, num_modules=6, out_nc=3, upscale=4):
        super(RFDN, self).__init__()
        self.fea_conv = B.conv_layer(in_nc, nf, kernel_size=3)

        # Register blocks as attributes B1..Bn (not an nn.ModuleList) so the
        # state_dict keys match the original hand-written B1..B6 layout.
        self.num_modules = num_modules
        for i in range(1, num_modules + 1):
            setattr(self, f"B{i}", B.RFDB(in_channels=nf))

        # 1x1 fusion of every block output concatenated along channels; its
        # input width must equal nf * (number of blocks actually used).
        self.c = B.conv_block(nf * num_modules, nf, kernel_size=1, act_type='lrelu')
        self.LR_conv = B.conv_layer(nf, nf, kernel_size=3)

        upsample_block = B.pixelshuffle_block
        # Honour the ``upscale`` argument (it was previously hard-coded to 4).
        self.upsampler = upsample_block(nf, out_nc, upscale_factor=upscale)
        # Stored for set_scale(); forward() does not read it.
        self.scale_idx = 2

    def forward(self, input):
        """Super-resolve ``input`` through the RFDB trunk and upsampler."""
        out_fea = self.fea_conv(input)

        # Chain every block in order (B1 -> B2 -> ... -> Bn) and keep each
        # output for the fusion conv. Previously B4 was reused for the 5th
        # and 6th stages, leaving B5 and B6 dead.
        block_outputs = []
        x = out_fea
        for i in range(1, self.num_modules + 1):
            x = getattr(self, f"B{i}")(x)
            block_outputs.append(x)

        out_B = self.c(torch.cat(block_outputs, dim=1))
        # Long residual connection back to the shallow features.
        out_lr = self.LR_conv(out_B) + out_fea
        output = self.upsampler(out_lr)
        return output

    def set_scale(self, scale_idx):
        """Store ``scale_idx``; it is not used by ``forward``."""
        self.scale_idx = scale_idx
class RFDB(nn.Module):
    """Residual feature distillation block with a CCA attention tail.

    Three stages each produce a 1x1 "distilled" branch and a 3x3 refined
    branch (refined output + its input, then activation). A final 3x3 conv
    closes the chain. The four distilled maps are concatenated, fused back to
    ``in_channels`` by a 1x1 conv and re-weighted by ``CCALayer`` (this
    variant uses CCA in place of ESA).

    Args:
        in_channels: channel width of the input and the output feature map.
        distillation_rate: accepted but not used; the distilled width is
            fixed at ``in_channels // 2``.
    """

    def __init__(self, in_channels, distillation_rate=0.25):
        super().__init__()
        half = in_channels // 2
        self.dc = self.distilled_channels = half
        self.rc = self.remaining_channels = in_channels

        # Attribute names below are state_dict keys: keep them (and their
        # creation order) unchanged so existing checkpoints still load.
        self.c1_d = conv_layer(in_channels, half, 1)
        self.c1_r = conv_layer(in_channels, in_channels, 3)
        self.c2_d = conv_layer(in_channels, half, 1)
        self.c2_r = conv_layer(in_channels, in_channels, 3)
        self.c3_d = conv_layer(in_channels, half, 1)
        self.c3_r = conv_layer(in_channels, in_channels, 3)
        self.c4 = conv_layer(in_channels, half, 3)
        self.act = activation('lrelu', neg_slope=0.05)
        # Four distilled maps of width ``half`` are fused back to in_channels.
        self.c5 = conv_layer(half * 4, in_channels, 1)
        self.cca = CCALayer(in_channels)

    def forward(self, input):
        """Run the three distill/refine stages, the tail conv, fusion and CCA."""
        stages = (
            (self.c1_d, self.c1_r),
            (self.c2_d, self.c2_r),
            (self.c3_d, self.c3_r),
        )
        distilled = []
        feat = input
        for distill, refine in stages:
            distilled.append(self.act(distill(feat)))
            # The sum is a fresh tensor, so an in-place activation cannot
            # clobber ``feat``.
            feat = self.act(refine(feat) + feat)
        distilled.append(self.act(self.c4(feat)))

        fused = self.c5(torch.cat(distilled, dim=1))
        return self.cca(fused)


class CCALayer(nn.Module):
    """Contrast-aware channel attention.

    A per-channel descriptor ``stdv_channels(x) + global average pool of x``
    passes through a bottleneck of ``channel // reduction`` 1x1 convs and a
    sigmoid. The result gates ``x`` multiplicatively.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        # NOTE(review): stdv_channels is defined elsewhere in this module;
        # presumably a per-channel spatial standard deviation — confirm.
        self.contrast = stdv_channels
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        squeezed = channel // reduction
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, squeezed, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, channel, 1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled by its channel-attention gate."""
        gate = self.conv_du(self.contrast(x) + self.avg_pool(x))
        return x * gate
Hi @njulj, I use 6 RFDBs and replaced ESA with CCA, but the parameter count is 544386, which is lower than the figure reported in your paper. I am not sure whether there is something wrong in my code. BTW, I have already set nf = 48.