Open Senwang98 opened 3 years ago
Hi, @Paper99 I am trying to reproduce RCAN based on your code. My code:
import math

import torch
from torch import nn as nn

# from basicsr.models.archs.arch_util import Upsample, make_layer


def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same block.

    Args:
        basic_block (nn.Module): nn.Module class for the basic block.
        num_basic_block (int): Number of blocks to stack.
        **kwarg: Keyword arguments forwarded to each block's constructor.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = [basic_block(**kwarg) for _ in range(num_basic_block)]
    return nn.Sequential(*layers)


class Upsample(nn.Sequential):
    """Upsample module based on Conv2d + PixelShuffle.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    Raises:
        ValueError: If ``scale`` is neither a power of two nor 3.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n (scale=1 yields identity)
            for _ in range(int(math.log2(scale))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. '
                             'Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)


class ChannelAttention(nn.Module):
    """Channel attention used in RCAN.

    Squeeze (global average pool) -> excitation (two 1x1 convs with a
    bottleneck) -> sigmoid gate that rescales each channel of the input.

    Args:
        num_feat (int): Channel number of intermediate features.
        squeeze_factor (int): Channel squeeze factor. Default: 16.
    """

    def __init__(self, num_feat, squeeze_factor=16):
        super(ChannelAttention, self).__init__()
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0),
            nn.Sigmoid())

    def forward(self, x):
        y = self.attention(x)
        return x * y


class RCAB(nn.Module):
    """Residual Channel Attention Block (RCAB) used in RCAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        squeeze_factor (int): Channel squeeze factor. Default: 16.
        res_scale (float): Scale the residual. Default: 1.
    """

    def __init__(self, num_feat, squeeze_factor=16, res_scale=1):
        super(RCAB, self).__init__()
        self.res_scale = res_scale
        self.rcab = nn.Sequential(
            nn.Conv2d(num_feat, num_feat, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(num_feat, num_feat, 3, 1, 1),
            ChannelAttention(num_feat, squeeze_factor))

    def forward(self, x):
        res = self.rcab(x) * self.res_scale
        return res + x


class ResidualGroup(nn.Module):
    """Residual Group of RCABs, with a trailing conv and a group skip.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_block (int): Block number in the body network.
        squeeze_factor (int): Channel squeeze factor. Default: 16.
        res_scale (float): Scale the residual. Default: 1.
    """

    def __init__(self, num_feat, num_block, squeeze_factor=16, res_scale=1):
        super(ResidualGroup, self).__init__()
        self.residual_group = make_layer(
            RCAB,
            num_block,
            num_feat=num_feat,
            squeeze_factor=squeeze_factor,
            res_scale=res_scale)
        self.conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

    def forward(self, x):
        res = self.conv(self.residual_group(x))
        return res + x


class RCAN(nn.Module):
    """Residual Channel Attention Networks.

    Paper: Image Super-Resolution Using Very Deep Residual Channel
    Attention Networks.
    Ref git repo: https://github.com/yulunzhang/RCAN.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_group (int): Number of ResidualGroup. Default: 10.
        num_block (int): Number of RCAB in ResidualGroup. Default: 16.
        squeeze_factor (int): Channel squeeze factor. Default: 16.
        upscale (int): Upsampling factor. Support 2^n and 3. Default: 2.
        res_scale (float): Used to scale the residual in residual block.
            Default: 1.
        img_range (float): Image range. Default: 255.
        rgb_mean (tuple[float]): Image mean in RGB order.
            Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K.
    """

    def __init__(self,
                 num_in_ch=3,
                 num_out_ch=3,
                 num_feat=64,
                 num_group=10,
                 num_block=16,
                 squeeze_factor=16,
                 upscale=2,
                 res_scale=1,
                 img_range=255.,
                 rgb_mean=(0.4488, 0.4371, 0.4040)):
        super(RCAN, self).__init__()
        self.img_range = img_range
        # Register the dataset mean as a buffer so .to(device)/.cuda() moves
        # it with the model (a plain tensor attribute would stay on CPU).
        # persistent=False keeps it out of state_dict, so checkpoints saved
        # with the old plain-attribute version still load with strict=True.
        # len(rgb_mean) instead of a hard-coded 3 keeps non-RGB inputs working.
        self.register_buffer(
            'mean',
            torch.Tensor(rgb_mean).view(1, len(rgb_mean), 1, 1),
            persistent=False)

        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(
            ResidualGroup,
            num_group,
            num_feat=num_feat,
            num_block=num_block,
            squeeze_factor=squeeze_factor,
            res_scale=res_scale)
        self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.upsample = Upsample(upscale, num_feat)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

    def forward(self, x):
        # NOTE(review): the input is expected to be normalized to [0, 1]
        # (the mean subtracted below is ~0.44); it is rescaled to
        # [0, img_range] internally and mapped back before returning.
        # Feeding raw 0-255 data here inflates activations by ~255x and
        # is a classic cause of exploding losses.
        mean = self.mean.type_as(x)  # match input dtype without mutating state
        x = (x - mean) * self.img_range

        x = self.conv_first(x)
        res = self.conv_after_body(self.body(x))
        res += x  # long skip connection over all residual groups

        x = self.conv_last(self.upsample(res))
        x = x / self.img_range + mean
        return x
This code runs, but the loss is very high (about 1e30). I am confused about this — can you give me any suggestions? My train.json:
{ "mode": "sr", "use_cl": false, // "use_cl": true, "gpu_ids": [1], "scale": 2, "is_train": true, "use_chop": true, "rgb_range": 255, "self_ensemble": false, "save_image": false, "datasets": { "train": { "mode": "LRHR", "dataroot_HR": "/home/wangsen/ws/dataset/DIV2K/Augment/DIV2K_train_HR_aug/x2", "dataroot_LR": "/home/wangsen/ws/dataset/DIV2K/Augment/DIV2K_train_LR_aug/x2", "data_type": "npy", "n_workers": 8, "batch_size": 16, "LR_size": 48, "use_flip": true, "use_rot": true, "noise": "." }, "val": { "mode": "LRHR", "dataroot_HR": "./results/HR/Set5/x2", "dataroot_LR": "./results/LR/LRBI/Set5/x2", "data_type": "img" } }, "networks": { "which_model": "RCAN", "num_features": 64, "in_channels": 3, "out_channels": 3, "res_scale": 1, "num_resgroups":10, "num_resblocks":20, "num_reduction":16 }, "solver": { "type": "ADAM", "learning_rate": 0.0002, "weight_decay": 0, "lr_scheme": "MultiStepLR", "lr_steps": [200, 400, 600, 800], "lr_gamma": 0.5, "loss_type": "l1", "manual_seed": 0, "num_epochs": 1000, "skip_threshold": 3, "split_batch": 1, "save_ckp_step": 100, "save_vis_step": 1, "pretrain": null, // "pretrain": "resume", "pretrained_path": "./experiments/RCAN_in3f64_x4/epochs/last_ckp.pth", "cl_weights": [1.0, 1.0, 1.0, 1.0] } }
Hi, @Paper99 I am trying to reproduce RCAN based on your code. My code:
This code runs, but the loss is very high (about 1e30). I am confused about this — can you give me any suggestions? My train.json: