sanghyun-son / EDSR-PyTorch

PyTorch version of the paper 'Enhanced Deep Residual Networks for Single Image Super-Resolution' (CVPRW 2017)
MIT License

poor performance when tested with the pretrained weight #241

Closed: Ethean closed this issue 4 years ago.

Ethean commented 4 years ago

Hi @thstkdgus35, thanks for your work! I tested the provided pretrained model on Set5 with the following settings:

        n_resblocks = 32
        n_feats = 256
        kernel_size = 3
        scale = 4
        res_scale = 0.1
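These settings correspond to the large EDSR configuration rather than the baseline one. As a quick sanity check that they fit the `edsr_x4` checkpoint, the sketch below counts the trainable convolution parameters. It assumes the layer layout of the repository's `EDSR` class (a head conv, `n_resblocks` residual blocks of two 3×3 convs each, a closing body conv, a pixel-shuffle upsampler and a tail conv) and ignores the fixed, non-trainable mean-shift layers; `conv_params` and `edsr_param_count` are hypothetical helper names.

```python
import math


def conv_params(in_ch, out_ch, k=3, bias=True):
    """Weights plus optional bias of a k x k Conv2d layer."""
    return in_ch * out_ch * k * k + (out_ch if bias else 0)


def edsr_param_count(n_resblocks, n_feats, scale, n_colors=3, k=3):
    """Trainable conv parameters of an EDSR-style network (assumed layout)."""
    head = conv_params(n_colors, n_feats, k)
    # each residual block: conv -> ReLU -> conv; plus one conv closing the body
    body = n_resblocks * 2 * conv_params(n_feats, n_feats, k)
    body += conv_params(n_feats, n_feats, k)
    if scale in (2, 4, 8):
        # log2(scale) stages of conv(n_feats -> 4 * n_feats) + PixelShuffle(2)
        up = int(math.log2(scale)) * conv_params(n_feats, 4 * n_feats, k)
    elif scale == 3:
        # one conv(n_feats -> 9 * n_feats) + PixelShuffle(3)
        up = conv_params(n_feats, 9 * n_feats, k)
    else:
        raise ValueError(f"unsupported scale: {scale}")
    tail = conv_params(n_feats, n_colors, k)
    return head + body + up + tail


print(edsr_param_count(32, 256, 4))  # large model settings from this issue
print(edsr_param_count(16, 64, 4))   # baseline settings
```

With the settings from this issue it gives 43,089,923 (about 43M, the figure the EDSR paper reports for its large model), while the baseline settings give 1,517,571 (about 1.5M). `res_scale` does not change the count; it only scales each block's residual branch.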

I also rewrote the test code for simplicity:

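import os
from math import log10
from os import listdir
from os.path import join

import torch
from PIL import Image
from torchvision.transforms import Compose, ToTensor

from pytorch_ssim import SSIM  # assumption: SSIM from the pytorch-ssim package
from edsr import EDSR          # hypothetical module: my EDSR class built with the settings above


def is_image_file(filename):
    return filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp'))


def data_transfer():
    # ToTensor converts a PIL image to a CxHxW float tensor in [0, 1]
    return Compose([ToTensor()])


def denormalize(img):
    # map [0, 1] back to [0, 255] for saving
    img = img.mul(255.0).clamp(0.0, 255.0)
    return img


def test():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    edsr = EDSR()
    # load_state_dict expects a dict, so the checkpoint file is loaded first
    edsr.load_state_dict(torch.load('pre_train_model/edsr_x4.pt', map_location=device))
    edsr.to(device).eval()
    lr_p = 'dataset/test_sr/set5l/'
    hr_p = 'dataset/test_sr/set5h/'
    lr_names = sorted(join(lr_p, i) for i in listdir(lr_p) if is_image_file(i))
    hr_names = sorted(join(hr_p, i) for i in listdir(hr_p) if is_image_file(i))
    psnr_total = 0
    ssim_total = 0
    save_path = 'test/edsr/'
    os.makedirs(save_path, exist_ok=True)
    for index in range(len(lr_names)):
        lr = data_transfer()(Image.open(lr_names[index]).convert('RGB')).unsqueeze(0).to(device)
        hr = data_transfer()(Image.open(hr_names[index]).convert('RGB')).unsqueeze(0).to(device)
        with torch.no_grad():
            sr = edsr(lr)
        # PSNR over all RGB channels and the full image, with peak value 1
        mse = ((sr - hr) ** 2).mean().item()
        psnr = 10 * log10(1 / mse)
        ssim = SSIM()(sr, hr).item()
        psnr_total += psnr
        ssim_total += ssim
        output = Image.fromarray(denormalize(sr.squeeze(0)).permute(1, 2, 0).byte().cpu().numpy())
        output.save(join(save_path, str(index + 1) + '.png'))
        print(index)
    print(psnr_total / len(lr_names), ssim_total / len(lr_names))


if __name__ == '__main__':
    test()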
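One thing worth checking in the script above is the input scale. torchvision's `ToTensor` maps uint8 pixels to [0, 1], whereas the repository's `--rgb_range` option defaults to 255 and its mean-shift layers are built from that value, so the released checkpoints were presumably trained on inputs in [0, 255]. If so, feeding [0, 1] tensors would be a plausible (unconfirmed) cause of the low scores. The numpy-only sketch below mirrors that kind of conversion; `to_model_input` and `to_uint8_image` are hypothetical helper names, and the expected range is an assumption.

```python
import numpy as np


def to_model_input(img_uint8, rgb_range=255):
    """HxWx3 uint8 RGB array -> 1x3xHxW float32 array in [0, rgb_range].

    Assumption: the checkpoint expects rgb_range=255 (the repo's default option),
    i.e. raw pixel values rather than ToTensor's [0, 1].
    """
    chw = np.ascontiguousarray(img_uint8.transpose(2, 0, 1)).astype(np.float32)
    chw *= rgb_range / 255.0
    return chw[np.newaxis]


def to_uint8_image(sr, rgb_range=255):
    """1x3xHxW model output in [0, rgb_range] -> HxWx3 uint8 array for saving."""
    pixel_range = 255.0 / rgb_range
    out = np.clip(sr[0] * pixel_range, 0, 255).round()
    return out.astype(np.uint8).transpose(1, 2, 0)
```

With PyTorch, `torch.from_numpy(to_model_input(np.array(Image.open(path).convert('RGB'))))` would give the tensor to feed the network, and `to_uint8_image(sr.cpu().numpy())` the image to save. Whether this resolves the issue depends on how the poster's own `EDSR()` class sets up its mean shift, which the thread does not show.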

With the test script above I get poor results: PSNR = 19.22 dB and SSIM = 0.441. The visual results are as follows:

[image: output 1]

[image: input img_001_SRF_4_LR]

What did I miss? I can't figure this out on my own. Any suggestions? Thanks in advance!
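When comparing against published Set5 numbers, note that the repository's benchmark evaluation appears to compute PSNR on the luma (Y) channel with a border of `scale` pixels shaved off, whereas the script above uses all RGB channels over the full image. The sketch below is modeled on that convention as far as we can tell; `calc_psnr_y` is a hypothetical name and the BT.601 weights are the ones commonly used for this purpose.

```python
import math

import numpy as np


def calc_psnr_y(sr, hr, scale, rgb_range=255.0):
    """PSNR (dB) on the luma channel, ignoring a `scale`-pixel border.

    sr, hr: arrays of shape (3, H, W) with RGB values in [0, rgb_range].
    """
    diff = (np.asarray(sr, dtype=np.float64) - np.asarray(hr, dtype=np.float64)) / rgb_range
    # BT.601 RGB->Y weights for [0, 1] input; the +16 offset cancels in a difference
    coeffs = np.array([65.738, 129.057, 25.064]).reshape(3, 1, 1) / 256.0
    diff_y = (diff * coeffs).sum(axis=0)
    # shave `scale` pixels on every side before averaging
    valid = diff_y[scale:-scale, scale:-scale]
    mse = float(np.mean(valid ** 2))
    if mse == 0.0:
        return float("inf")
    return -10.0 * math.log10(mse)
```

To use it with the script above, one would pass `sr.squeeze(0).cpu().numpy()` and `hr.squeeze(0).cpu().numpy()` with `rgb_range=1` (since those tensors are in [0, 1]). This only changes how the score is measured; a gap as large as the one between 19 dB and the well-over-30 dB Set5 ×4 results reported in the paper is unlikely to come from the metric alone.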

JJ-data-science commented 3 years ago

Hi @Ethean :)

I have exactly the same issue. How did you fix it?