xinntao / ESRGAN

ECCV18 Workshops - Enhanced SRGAN. Champion of the PIRM Challenge on Perceptual Super-Resolution. The training code is in BasicSR.
https://github.com/xinntao/BasicSR
Apache License 2.0
5.97k stars 1.06k forks source link

Suggestion for testing hires images #97

Closed Dejan1969 closed 4 years ago

Dejan1969 commented 4 years ago

You may want to implement some kind of tile splitting and joining so that high-resolution images fit in memory. Here is my solution:

tiletool.py

import numpy as np

class Tiletool:
    """Split an image into overlapping tiles and stitch upscaled tiles back together.

    Typical use::

        tt = Tiletool(upscale_factor=4, tile_size=512, padding_size=4)
        patches = tt.split_image(img)             # overlapping patches
        upscaled = [model(p) for p in patches]    # each enlarged by upscale_factor
        result = tt.join_image(upscaled)

    ``join_image`` relies on the shapes recorded by the most recent
    ``split_image`` call, so the two must be used in pairs on the same instance.
    """

    def __init__(self, upscale_factor, tile_size, padding_size):
        """
        Args:
            upscale_factor: integer factor by which every patch is enlarged
                between ``split_image`` and ``join_image``.
            tile_size: side length, in input pixels, of the core
                (non-overlapping) region of each patch.
            padding_size: number of extra context pixels added on every side
                of each tile; this overlap is discarded again when joining.

        Raises:
            ValueError: if ``upscale_factor`` or ``tile_size`` is not positive,
                or ``padding_size`` is negative.
        """
        if upscale_factor <= 0:
            raise ValueError(f"upscale_factor must be positive, got {upscale_factor}")
        if tile_size <= 0:
            raise ValueError(f"tile_size must be positive, got {tile_size}")
        if padding_size < 0:
            raise ValueError(f"padding_size must be non-negative, got {padding_size}")

        self.upscale_factor = upscale_factor
        self.tile_size = tile_size
        self.padding_size = padding_size

        # Shape of the most recently split (low-resolution) image.
        self.lr_image_shape = []
        # Shape of that image after extension to a tile multiple plus padding.
        self.p_shape = []

    def split_image(self, image_to_split):
        """Split an image into partially overlapping patches.

        The image is first extended (edge-replicated) at the bottom/right so
        its height and width are multiples of ``tile_size``. It is then
        edge-padded by ``padding_size`` on every side. Patches are returned in
        row-major order (left to right, then top to bottom). Each patch spans
        ``tile_size + 2 * padding_size`` pixels along both spatial axes.

        Args:
            image_to_split: numpy array of shape (H, W) or (H, W, C).

        Returns:
            list of numpy arrays (views into the padded image), one per tile.
        """
        self.lr_image_shape = image_to_split.shape
        height, width = self.lr_image_shape[:2]

        # -n % t is the distance to the next multiple of t (0 if n already is one).
        extend_h = -height % self.tile_size
        extend_w = -width % self.tile_size
        # Any axes after the two spatial ones (e.g. colour channels) stay untouched.
        channel_pad = [(0, 0)] * (image_to_split.ndim - 2)

        # Make the image divisible into whole tiles.
        extended_image = np.pad(
            image_to_split, [(0, extend_h), (0, extend_w)] + channel_pad, 'edge'
        )
        # Surround it with padding so every tile has context on all sides.
        pad = self.padding_size
        padded_image = np.pad(extended_image, [(pad, pad), (pad, pad)] + channel_pad, 'edge')
        self.p_shape = padded_image.shape

        span = self.tile_size + 2 * pad
        patches = []
        for top in range(0, height + extend_h, self.tile_size):
            for left in range(0, width + extend_w, self.tile_size):
                patches.append(padded_image[top:top + span, left:left + span])
        return patches

    def join_image(self, tile_list):
        """Reassemble an image from patches upscaled after ``split_image``.

        Each patch must be the corresponding ``split_image`` patch enlarged by
        ``upscale_factor``. Its scaled padding border is cropped away, and the
        remaining cores are placed back in row-major order. The result is then
        cropped to the upscaled size of the original image.

        Args:
            tile_list: sequence of upscaled patches, in the order returned by
                ``split_image``.

        Returns:
            float64 numpy array of shape
            (H * upscale_factor, W * upscale_factor) followed by the patches'
            trailing (channel) axes, if any.

        Raises:
            ValueError: if the number of patches does not match the tiling
                computed by the last ``split_image`` call.
        """
        scale = self.upscale_factor
        pad = self.padding_size * scale
        tile = self.tile_size * scale

        # Tile grid size, recovered from the padded shape. The padding must be
        # removed first: dividing the padded width by the tile size
        # over-counts whenever 2 * padding_size >= tile_size.
        n_rows = (self.p_shape[0] - 2 * self.padding_size) // self.tile_size
        n_cols = (self.p_shape[1] - 2 * self.padding_size) // self.tile_size
        if len(tile_list) != n_rows * n_cols:
            raise ValueError(
                f"expected {n_rows * n_cols} patches ({n_rows}x{n_cols}), got {len(tile_list)}"
            )

        if len(tile_list):
            channel_shape = tuple(np.shape(tile_list[0])[2:])
        else:
            channel_shape = tuple(self.lr_image_shape[2:])
        complete_image = np.zeros((n_rows * tile, n_cols * tile) + channel_shape)

        for index, patch in enumerate(tile_list):
            row, col = divmod(index, n_cols)
            # pad:pad + tile (rather than pad:-pad) still works when pad == 0.
            core = np.asarray(patch)[pad:pad + tile, pad:pad + tile]
            complete_image[row * tile:(row + 1) * tile, col * tile:(col + 1) * tile] = core

        out_h = self.lr_image_shape[0] * scale
        out_w = self.lr_image_shape[1] * scale
        return complete_image[:out_h, :out_w]

This is a partial implementation adapted from https://github.com/idealo/image-super-resolution; thanks to Francesco Cardinale (GitHub: cfrancesco).

And here is how it is called from test.py:

import os
import glob
import cv2
import numpy as np
import torch
import RRDBNet_arch as Arch
from tiletool import Tiletool

class ESRGAN:
    """RRDBNet-based ESRGAN upscaler that processes image files, tiling large inputs."""

    def __init__(self, model_path, device, scale_factor=4, tile_size=512, padding_size=4):
        """
        Args:
            model_path: path to an RRDBNet state dict (``.pth`` file).
            device: torch device the model runs on.
            scale_factor: upscale factor used when stitching tiles; must match
                the model's actual upscale factor.
            tile_size: maximum tile side length in input pixels. Images no
                larger than this in both dimensions are upscaled in one pass;
                a value <= 0 disables tiling entirely.
            padding_size: overlap, in input pixels, added around each tile.
        """
        self.padding_size = padding_size
        self.scale_factor = scale_factor
        self.tile_size = tile_size

        model = Arch.RRDBNet(3, 3, 64, 23, gc=32)
        # map_location lets a checkpoint saved from a GPU load on a CPU-only
        # device (and vice versa).
        model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
        model.eval()
        for param in model.parameters():
            param.requires_grad = False

        self.model = model.to(device)
        self.device = device

    def upscale(self, img):
        """Run the model on a single BGR image.

        Args:
            img: (H, W, 3) array in BGR channel order with values in [0, 255]
                (as produced by ``cv2.imread``).

        Returns:
            float32 array of shape (H', W', 3) in BGR order, holding values in
            [0, 255] rounded to whole numbers.
        """
        # BGR (OpenCV order) -> RGB, HWC -> CHW, scale to [0, 1].
        rgb = np.transpose(img[:, :, [2, 1, 0]] / 255.0, (2, 0, 1))
        img_lr = torch.from_numpy(rgb).float().unsqueeze(0).to(self.device)

        with torch.no_grad():
            result = self.model(img_lr)
        # squeeze(0) removes only the batch axis; a bare squeeze() would also
        # drop a spatial axis of size 1.
        output = result.squeeze(0).float().cpu().clamp_(0, 1).numpy()

        # RGB -> BGR, CHW -> HWC, back to [0, 255].
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
        return (output * 255.0).round()

    def process(self, input_path, output_path):
        """Upscale the image at ``input_path`` and write it to ``output_path``.

        Raises:
            OSError: if the input cannot be decoded as an image or the output
                cannot be written.
        """
        lr_img = cv2.imread(input_path, cv2.IMREAD_COLOR)
        if lr_img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"Could not read image: {input_path}")

        height, width = lr_img.shape[:2]
        fits_in_one_tile = height <= self.tile_size and width <= self.tile_size
        if self.tile_size <= 0 or fits_in_one_tile:
            # Tiling disabled, or the image is small enough to process directly.
            output_img = self.upscale(lr_img)
        else:
            # Upscale overlapping patches one by one, then stitch them together.
            tt = Tiletool(self.scale_factor, self.tile_size, self.padding_size)
            hr_tiles = [self.upscale(patch) for patch in tt.split_image(lr_img)]
            output_img = tt.join_image(hr_tiles)

        # Values are already clamped to [0, 255] and rounded, so the uint8 cast
        # is exact. cv2.imwrite reports failure only through its return value.
        if not cv2.imwrite(output_path, output_img.astype(np.uint8)):
            raise OSError(f"Could not write image: {output_path}")

def main():
    """Upscale every file in ``LR/`` with ESRGAN and write PNG results to ``results/``."""
    # Do not change scale_factor: it must match the x4 models named below.
    scale_factor = 4

    # Reduce tile_size (>= 32) and padding_size (>= 2) if you run OUT OF MEMORY.
    tile_size = 512
    padding_size = 4

    # Use the first GPU when available; otherwise fall back to the (much
    # slower) CPU - go out for lunch...
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Scale images down for testing using bicubic interpolation.
    # ImageMagick works fine:
    # $ convert hires.jpg -interpolate BiCubic -interpolative-resize 25% lowres_for_testing.jpg
    # and put the results in LR.
    input_folder = 'LR/*'
    output_folder = 'results/'

    # Select model RRDB_ESRGAN_x4.pth or RRDB_PSNR_x4.pth.
    model_path = 'models/RRDB_ESRGAN_x4.pth'
    model_name = os.path.basename(model_path)

    # cv2.imwrite fails (returning False) when the target directory is missing.
    os.makedirs(output_folder, exist_ok=True)

    print("Initializing ESRGAN using model '%s'" % model_name, flush=True)

    esrgan = ESRGAN(model_path, device, scale_factor, tile_size, padding_size)

    # Sorted for a deterministic processing order.
    for input_path in sorted(glob.glob(input_folder)):
        if not os.path.isfile(input_path):
            continue  # the glob also matches sub-directories
        input_name = os.path.basename(input_path)
        print('Upscaling', input_name, flush=True)
        stem = os.path.splitext(input_name)[0]
        output_path = os.path.join(output_folder, stem + '_' + model_name + '.png')
        try:
            esrgan.process(input_path, output_path)
        except OSError as err:
            # Report unreadable/unwritable files and keep processing the rest.
            print('Skipping', input_name + ':', err, flush=True)

if __name__ == '__main__':
    # The site-injected exit() is intended for the interactive shell and is
    # missing under `python -S`; raising SystemExit is the portable equivalent.
    raise SystemExit(main())