Paper99 / SRFBN_CVPR19

PyTorch code for our paper "Feedback Network for Image Super-Resolution" (CVPR2019)

Problem about reproducing RCAN using your project #58

Open Senwang98 opened 3 years ago

Senwang98 commented 3 years ago

Hi @Paper99, I am trying to reproduce RCAN based on your code. My code:

import torch
from torch import nn as nn
import math
# from basicsr.models.archs.arch_util import Upsample, make_layer

def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.
    Args:
        basic_block (nn.module): nn.module class for basic block.
        num_basic_block (int): number of blocks.
    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)

class Upsample(nn.Sequential):
    """Upsample module.
    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. '
                             'Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)

class ChannelAttention(nn.Module):
    """Channel attention used in RCAN.
    Args:
        num_feat (int): Channel number of intermediate features.
        squeeze_factor (int): Channel squeeze factor. Default: 16.
    """

    def __init__(self, num_feat, squeeze_factor=16):
        super(ChannelAttention, self).__init__()
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0),
            nn.Sigmoid())

    def forward(self, x):
        y = self.attention(x)
        return x * y

class RCAB(nn.Module):
    """Residual Channel Attention Block (RCAB) used in RCAN.
    Args:
        num_feat (int): Channel number of intermediate features.
        squeeze_factor (int): Channel squeeze factor. Default: 16.
        res_scale (float): Scale the residual. Default: 1.
    """

    def __init__(self, num_feat, squeeze_factor=16, res_scale=1):
        super(RCAB, self).__init__()
        self.res_scale = res_scale

        self.rcab = nn.Sequential(
            nn.Conv2d(num_feat, num_feat, 3, 1, 1), nn.ReLU(True),
            nn.Conv2d(num_feat, num_feat, 3, 1, 1),
            ChannelAttention(num_feat, squeeze_factor))

    def forward(self, x):
        res = self.rcab(x) * self.res_scale
        return res + x

class ResidualGroup(nn.Module):
    """Residual Group of RCAB.
    Args:
        num_feat (int): Channel number of intermediate features.
        num_block (int): Block number in the body network.
        squeeze_factor (int): Channel squeeze factor. Default: 16.
        res_scale (float): Scale the residual. Default: 1.
    """

    def __init__(self, num_feat, num_block, squeeze_factor=16, res_scale=1):
        super(ResidualGroup, self).__init__()

        self.residual_group = make_layer(
            RCAB,
            num_block,
            num_feat=num_feat,
            squeeze_factor=squeeze_factor,
            res_scale=res_scale)
        self.conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

    def forward(self, x):
        res = self.conv(self.residual_group(x))
        return res + x

class RCAN(nn.Module):
    """Residual Channel Attention Networks.
    Paper: Image Super-Resolution Using Very Deep Residual Channel Attention
        Networks
    Ref git repo: https://github.com/yulunzhang/RCAN.
    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        num_group (int): Number of ResidualGroup. Default: 10.
        num_block (int): Number of RCAB in ResidualGroup. Default: 16.
        squeeze_factor (int): Channel squeeze factor. Default: 16.
        upscale (int): Upsampling factor. Support 2^n and 3.
            Default: 4.
        res_scale (float): Used to scale the residual in residual block.
            Default: 1.
        img_range (float): Image range. Default: 255.
        rgb_mean (tuple[float]): Image mean in RGB orders.
            Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
    """

    def __init__(self,
                 num_in_ch=3,
                 num_out_ch=3,
                 num_feat=64,
                 num_group=10,
                 num_block=16,
                 squeeze_factor=16,
                 upscale=2,
                 res_scale=1,
                 img_range=255.,
                 rgb_mean=(0.4488, 0.4371, 0.4040)):
        super(RCAN, self).__init__()

        self.img_range = img_range
        self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)

        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(
            ResidualGroup,
            num_group,
            num_feat=num_feat,
            num_block=num_block,
            squeeze_factor=squeeze_factor,
            res_scale=res_scale)
        self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.upsample = Upsample(upscale, num_feat)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

    def forward(self, x):
        # print(x.shape)
        self.mean = self.mean.type_as(x)

        # subtract the DIV2K RGB mean (~0.44 per channel) and rescale by
        # img_range; this is undone again just before returning
        x = (x - self.mean) * self.img_range
        x = self.conv_first(x)
        res = self.conv_after_body(self.body(x))
        res += x

        x = self.conv_last(self.upsample(res))
        x = x / self.img_range + self.mean
        # print(x.shape)
        # exit()
        return x
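
For reference, this is the quick smoke test I run after defining the network above (a minimal sketch of my own; it assumes the input patch is already in [0, 1], since forward() subtracts an RGB mean of about 0.44 before multiplying by img_range):

# Build a 2x model and push one random 48x48 patch through it
# (48 matches the LR_size in my config below); only shapes are checked.
model = RCAN(upscale=2)
model.eval()
with torch.no_grad():
    lr = torch.rand(1, 3, 48, 48)            # assumed to be in [0, 1]
    sr = model(lr)
print(sr.shape)                               # expect torch.Size([1, 3, 96, 96])
print(float(sr.min()), float(sr.max()))       # stays roughly within [0, 1] at init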

This code can run, but the loss is very high (about 1e30). I am quite confused by this, can you give me some suggestions? My train.json:

{
    "mode": "sr",
    "use_cl": false,
    // "use_cl": true,
    "gpu_ids": [1],

    "scale": 2,
    "is_train": true,
    "use_chop": true,
    "rgb_range": 255,
    "self_ensemble": false,
    "save_image": false,

    "datasets": {
        "train": {
            "mode": "LRHR",
            "dataroot_HR": "/home/wangsen/ws/dataset/DIV2K/Augment/DIV2K_train_HR_aug/x2",
            "dataroot_LR": "/home/wangsen/ws/dataset/DIV2K/Augment/DIV2K_train_LR_aug/x2",
            "data_type": "npy",
            "n_workers": 8,
            "batch_size": 16,
            "LR_size": 48,
            "use_flip": true,
            "use_rot": true,
            "noise": "."
        },
        "val": {
            "mode": "LRHR",
            "dataroot_HR": "./results/HR/Set5/x2",
            "dataroot_LR": "./results/LR/LRBI/Set5/x2",
            "data_type": "img"
        }
    },

    "networks": {
        "which_model": "RCAN",
        "num_features": 64,
        "in_channels": 3,
        "out_channels": 3,
        "res_scale": 1,
        "num_resgroups":10, 
        "num_resblocks":20,
        "num_reduction":16
    },

    "solver": {
        "type": "ADAM",
        "learning_rate": 0.0002,
        "weight_decay": 0,
        "lr_scheme": "MultiStepLR",
        "lr_steps": [200, 400, 600, 800],
        "lr_gamma": 0.5,
        "loss_type": "l1",
        "manual_seed": 0,
        "num_epochs": 1000,
        "skip_threshold": 3,
        "split_batch": 1,
        "save_ckp_step": 100,
        "save_vis_step": 1,
        "pretrain": null,
        // "pretrain": "resume",
        "pretrained_path": "./experiments/RCAN_in3f64_x4/epochs/last_ckp.pth",
        "cl_weights": [1.0, 1.0, 1.0, 1.0]
    }
}
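
One more thing I am not sure about: my config sets "rgb_range": 255, so (if I understand the data pipeline correctly) the network sees patches in [0, 255], while the forward pass above subtracts a mean of about 0.44 and multiplies by img_range = 255, which looks like it expects inputs in [0, 1]. A small sketch (synthetic tensors, just to see the magnitudes involved) of how I compared the two cases:

mean = torch.Tensor((0.4488, 0.4371, 0.4040)).view(1, 3, 1, 1)
img_range = 255.

x01 = torch.rand(1, 3, 48, 48)          # patch in [0, 1]
x255 = x01 * 255.                       # patch as delivered with rgb_range = 255

for name, x in [("[0, 1] input", x01), ("[0, 255] input", x255)]:
    y = (x - mean) * img_range          # same normalization as RCAN.forward()
    print(name, float(y.abs().max()))   # roughly 255 vs roughly 65000 before conv_first

Could this range mismatch be the reason the loss blows up, or is the input already rescaled to [0, 1] somewhere in the SRFBN pipeline?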