xinntao / ESRGAN

ECCV18 Workshops - Enhanced SRGAN. Champion of the PIRM Challenge on Perceptual Super-Resolution. The training code is in BasicSR.
https://github.com/xinntao/BasicSR
Apache License 2.0
5.91k stars 1.05k forks source link

how to test after train #58

Closed yja1 closed 5 years ago

yja1 commented 5 years ago

I have made my own LMDB HR and LR datasets and trained a model with BasicSR: `python train.py -opt /BasicSR/codes/options/train/train_ESRGAN.json`

But how do I test my own images with my trained model?

Should I use BasicSR's test.py or ESRGAN's test.py? In fact, both failed.

xinntao commented 5 years ago

I think both are OK. But it may need some modifications according to your needs.

shiyangjing commented 5 years ago

> I have make myself lmdb HR and LR ,and run use code "BasicSR" python train.py -opt /BasicSR/codes/options/train/train_ESRGAN.json
>
> but how to test my data image use my trained model?
>
> in BasicSR test.py or ESRGAN test.py ? in fact all failed

If you test your images with ESRGAN's test.py and point it at your own model (`python test.py models/your_trained_model_path.pth`), does that work?

yja1 commented 5 years ago

```python
model_path = 'models_pth/115000_G.pth'
device = torch.device('cuda')  # if you want to run on CPU, change 'cuda' -> cpu
model = arch.RRDBNet(3, 3, 64, 23, gc=32)
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
```

Error:

```
    model.load_state_dict(torch.load(model_path), strict=True)
  File "/env_pyt0.4.1_py3.6_centernetobj/lib/python3.6/site-packages/torch/nn/modules/module.py", line 719, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for RRDBNet:
    Missing key(s) in state_dict: "conv_first.weight", "conv_first.bias", "RRDB_trunk.0.RDB1.conv1.weight", .........
```

xinntao commented 5 years ago

We have updated the network structures in the ESRGAN and BasicSR repos. You can use this script (the link was not preserved in this copy) to convert old models to the new structure, or you can download the previous releases.

WANG-1173 commented 5 years ago

Hi, I have the same problem. How did you solve it?

yja1 commented 5 years ago

Change RRDBNet_arch.py so that it matches your own model:

import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import block as B

def make_layer(block, n_layers):
    """Stack ``n_layers`` freshly built modules into an ``nn.Sequential``.

    Args:
        block: zero-argument callable that returns an ``nn.Module``. It is
            called once per layer, so the layers do not share parameters.
        n_layers: number of layers to create (0 yields an empty Sequential).

    Returns:
        nn.Sequential holding the created modules in creation order.
    """
    return nn.Sequential(*[block() for _ in range(n_layers)])

class ResidualDenseBlock_5C(nn.Module):
    """Residual dense block built from five 3x3 convolutions.

    conv1-conv4 each receive the concatenation of the block input and all
    previous intermediate outputs (dense connectivity). conv5 maps that
    concatenation back to ``nf`` channels. Its output is scaled by 0.2 and
    added to the block input.

    Args:
        nf: number of input/output feature channels.
        gc: growth channel, i.e. intermediate channels produced by conv1-conv4.
        bias: whether the convolutions have a bias term.
    """

    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5C, self).__init__()
        # Attribute names (conv1..conv5) form the state_dict keys; keep them stable.
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # NOTE: parameters keep PyTorch's default initialisation; the scaled
        # init (mutil.initialize_weights(..., 0.1)) from the original is not applied.

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        # No activation on the last conv before the scaled residual add.
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x

class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Runs the input through three ResidualDenseBlock_5C blocks in sequence,
    scales the result by 0.2 and adds it back to the input.

    Args:
        nf: number of feature channels (input and output).
        gc: growth channels passed to each inner dense block.
    """

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        # Names RDB1..RDB3 appear in state_dict keys (e.g. "...RDB1.conv1.weight").
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x

class RRDBNet(nn.Module):
    """ESRGAN generator assembled from the helpers in the ``block`` module (B).

    The whole network is held in a single ``self.model`` sequential:
    feature conv -> shortcut(nb RRDB blocks + LR conv) -> upsampler(s)
    -> HR_conv0 -> HR_conv1. The submodule layout determines the state_dict
    keys (all under ``model.``), so it must match the checkpoint being loaded.

    Args:
        in_nc: number of input image channels.
        out_nc: number of output image channels.
        nf: number of feature channels.
        nb: number of RRDB blocks in the trunk.
        gc: growth channels inside each RRDB block.
        upscale: upscaling factor; must be 3 or a power of two.
        norm_type: normalisation passed to the trunk and LR conv (None = none).
        act_type: activation name passed to the ``block`` helpers.
        mode: layer ordering passed to the LR conv block.
        upsample_mode: 'upconv' or 'pixelshuffle'.

    Raises:
        ValueError: if ``upscale`` is neither 3 nor a power of two.
        NotImplementedError: if ``upsample_mode`` is not recognised.
    """

    def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None,
                 act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
        super(RRDBNet, self).__init__()
        if upscale == 3:
            n_upscale = 1
        else:
            # math.log2 is exact for powers of two (math.log(x, 2) may not be),
            # and non-powers of two would otherwise be silently truncated.
            n_upscale = int(math.log2(upscale))
            if 2 ** n_upscale != upscale:
                raise ValueError(
                    'upscale must be 3 or a power of 2, got {}'.format(upscale))

        fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
        # Use the configured growth channels (previously hard-coded to 32).
        rb_blocks = [B.RRDB(nf, kernel_size=3, gc=gc, stride=1, bias=True, pad_type='zero',
                            norm_type=norm_type, act_type=act_type, mode='CNA')
                     for _ in range(nb)]
        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None,
                               mode=mode)

        if upsample_mode == 'upconv':
            # NOTE(review): spelled 'blcok' as referenced here; presumably matches
            # the name defined in block.py — verify before renaming.
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            # NOTE(review): a single module here (not a list); `*upsampler` below
            # presumes upsample_block returns an iterable nn.Sequential — confirm.
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        # Keep this exact layout so checkpoints saved with it load with strict=True.
        self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),
                                  *upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        return self.model(x)

# In test.py:
import torch

import RRDBNet_arch as arch

# Alternative: the released pretrained model.
# model_path = 'models/RRDB_ESRGAN_x4.pth'
# Your own trained generator checkpoint:
model_path = '/BasicSR/experiments/002_RRDB_ESRGAN_x4_DIV2K/models/100000_G.pth'
device = torch.device('cuda')  # use torch.device('cpu') to run without a GPU

model = arch.RRDBNet(3, 3, 64, 23, gc=32)
# map_location lets a checkpoint saved on GPU be loaded when device is the CPU.
model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
model.eval()
model = model.to(device)

WANG-1173 commented 5 years ago

Thank you very much!