Closed Dejan1969 closed 4 years ago
You may want to implement some kind of splitting and joining. Here is my solution:
tiletool.py
import numpy as np


class Tiletool:
    """Split an image into overlapping tiles and rebuild the upscaled result.

    ``split_image`` cuts an image into square tiles of ``tile_size`` pixels,
    each surrounded by ``padding_size`` pixels of context. Once every tile
    has been upscaled by ``upscale_factor``, ``join_image`` strips the
    (scaled) padding again and assembles the tiles into one image.
    """

    def __init__(self, upscale_factor, tile_size, padding_size):
        """Store the tiling parameters.

        Args:
            upscale_factor: integer factor by which each tile is enlarged
                between ``split_image`` and ``join_image``.
            tile_size: edge length of a tile in the original image
                (without padding); must be at least 1.
            padding_size: width of the overlapping border added around
                every tile; must not be negative.

        Raises:
            ValueError: if ``tile_size`` or ``padding_size`` is out of range.
        """
        if tile_size < 1:
            raise ValueError("tile_size must be a positive integer")
        if padding_size < 0:
            raise ValueError("padding_size must not be negative")
        # upscale_factor : INT
        self.upscale_factor = upscale_factor
        # tile_size : INT size of the patches from the original image (without padding)
        self.tile_size = tile_size
        # padding_size : INT size of the overlapping area
        self.padding_size = padding_size
        # low resolution image shape (set by split_image)
        self.lr_image_shape = []
        # padded image shape (set by split_image)
        self.p_shape = []

    def split_image(self, image_to_split):
        """Split the image into partially overlapping patches.

        The image is first extended at the bottom/right (by repeating edge
        pixels) so that both dimensions are multiples of ``tile_size``, then
        surrounded by ``padding_size`` edge pixels on every side.

        Args:
            image_to_split: numpy array of the input image; the first two
                axes are tiled, any trailing axes are kept whole.

        Returns:
            A list of patches in row-major order, each with edge length
            ``tile_size + 2 * padding_size`` along the first two axes.
        """
        self.lr_image_shape = image_to_split.shape
        tile = self.tile_size
        pad = self.padding_size
        rows, cols = self.lr_image_shape[:2]

        # (-n) % tile is the distance to the next multiple of tile
        # (0 when n is already aligned).
        x_extend = (-rows) % tile
        y_extend = (-cols) % tile

        # Edge padding clamps coordinates, so extending to a tile multiple and
        # adding the overlap border can be done in a single pass.
        pad_width = [(pad, pad + x_extend), (pad, pad + y_extend)]
        pad_width += [(0, 0)] * (image_to_split.ndim - 2)
        padded_image = np.pad(image_to_split, pad_width, 'edge')
        self.p_shape = padded_image.shape

        window = tile + 2 * pad
        x_starts = range(0, self.p_shape[0] - 2 * pad, tile)
        y_starts = range(0, self.p_shape[1] - 2 * pad, tile)
        return [
            padded_image[x:x + window, y:y + window]
            for x in x_starts
            for y in y_starts
        ]

    def join_image(self, tile_list):
        """Reconstruct the image from the upscaled overlapping patches.

        All sizes recorded by ``split_image`` are multiplied by
        ``upscale_factor`` before use.

        Args:
            tile_list: list of patches collected after upscaling, in the
                order returned by ``split_image``.

        Returns:
            A float64 numpy array holding the joined image, cropped to the
            original image size times ``upscale_factor``.

        Raises:
            RuntimeError: if ``split_image`` has not been called first.
            ValueError: if the number or size of the tiles does not match
                the layout produced by ``split_image``.
        """
        if len(self.p_shape) < 2:
            raise RuntimeError("split_image must be called before join_image")

        # convert collect list to ndarray
        image_patches = np.asarray(tile_list)
        if image_patches.ndim < 3:
            raise ValueError("tile_list must contain at least one image tile")

        scale = self.upscale_factor
        pad = self.padding_size * scale

        # Crop the scaled padding with explicit end indices: a "pad:-pad"
        # slice would be empty when pad == 0.
        full_h, full_w = image_patches.shape[1:3]
        image_patches = image_patches[:, pad:full_h - pad, pad:full_w - pad]

        patch_size = self.tile_size * scale
        if image_patches.shape[1:3] != (patch_size, patch_size):
            raise ValueError(
                "unexpected tile size %s after removing padding, expected %s"
                % (image_patches.shape[1:3], (patch_size, patch_size))
            )

        # Derive the grid from the unpadded extent; counting tiles on the
        # padded size over-counts as soon as 2 * padding_size >= tile_size.
        n_rows = (self.p_shape[0] - 2 * self.padding_size) // self.tile_size
        n_cols = (self.p_shape[1] - 2 * self.padding_size) // self.tile_size
        if len(image_patches) != n_rows * n_cols:
            raise ValueError(
                "expected %d tiles, got %d" % (n_rows * n_cols, len(image_patches))
            )

        # Trailing axes (e.g. colour channels) are taken from the tiles
        # instead of being hard-coded.
        canvas_shape = (n_rows * patch_size, n_cols * patch_size)
        complete_image = np.zeros(canvas_shape + image_patches.shape[3:])

        for index, patch in enumerate(image_patches):
            row, col = divmod(index, n_cols)
            top = row * patch_size
            left = col * patch_size
            complete_image[top:top + patch_size, left:left + patch_size] = patch

        # Drop the area that only existed to make the image tile-aligned.
        out_rows = self.lr_image_shape[0] * scale
        out_cols = self.lr_image_shape[1] * scale
        return complete_image[:out_rows, :out_cols]
This is a partial implementation adapted from https://github.com/idealo/image-super-resolution, thanks to Francesco Cardinale (GitHub: cfrancesco).
And here is how it is called from test.py:
import glob
import os

import cv2
import numpy as np
import torch

import RRDBNet_arch as Arch
from tiletool import Tiletool


class ESRGAN:
    """Upscale images with an RRDBNet model, tiling large inputs to save memory."""

    def __init__(self, model_path, device, scale_factor=4, tile_size=512, padding_size=4):
        """Load the model weights and move the network to ``device``.

        Args:
            model_path: path of the ``.pth`` state dict to load.
            device: torch device the model runs on.
            scale_factor: upscaling factor passed on to ``Tiletool``.
            tile_size: edge length of the tiles; images that fit into a
                single tile, or any image when ``tile_size <= 0``, are
                processed in one piece.
            padding_size: overlap added around every tile.
        """
        self.padding_size = padding_size
        self.scale_factor = scale_factor
        self.tile_size = tile_size
        model = Arch.RRDBNet(3, 3, 64, 23, gc=32)
        # map_location allows weights saved on a GPU to be loaded on a
        # CPU-only machine (and vice versa).
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict, strict=True)
        model.eval()
        for param in model.parameters():
            param.requires_grad = False
        self.model = model.to(device)
        self.device = device

    def upscale(self, img):
        """Run the model on one image (or tile).

        Args:
            img: H x W x 3 array with values in 0..255. The channel order is
                reversed before and after the model; presumably this maps
                OpenCV's BGR layout to the RGB layout the model expects.

        Returns:
            The upscaled image as a float array rounded to whole values
            in 0..255, in the same channel order as the input.
        """
        img = img * 1.0 / 255
        img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
        img_lr = img.unsqueeze(0).to(self.device)
        with torch.no_grad():
            # squeeze(0) removes only the batch axis; a bare squeeze() would
            # also drop any spatial axis that happens to have length 1.
            output = self.model(img_lr).squeeze(0).float().cpu().clamp_(0, 1).numpy()
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
        return (output * 255.0).round()

    def process(self, input_path, output_path):
        """Upscale the image at ``input_path`` and write it to ``output_path``.

        Raises:
            OSError: if the input cannot be read as an image or the result
                cannot be written.
        """
        # read input image
        lr_img = cv2.imread(input_path, cv2.IMREAD_COLOR)
        if lr_img is None:
            # cv2.imread reports failure by returning None instead of raising
            raise OSError("Could not read image '%s'" % input_path)

        height, width = lr_img.shape[:2]
        fits_in_one_tile = height <= self.tile_size and width <= self.tile_size
        if self.tile_size <= 0 or fits_in_one_tile:
            # process small images (or everything when tiling is disabled)
            # directly without the use of tiles
            output_img = self.upscale(lr_img)
        else:
            # Splits the image into partially overlapping patches
            tt = Tiletool(self.scale_factor, self.tile_size, self.padding_size)
            collect_hr_tile = [self.upscale(patch) for patch in tt.split_image(lr_img)]
            # Reconstruct the image from overlapping patches
            output_img = tt.join_image(collect_hr_tile)

        # Hand OpenCV an explicit 8-bit image rather than relying on its
        # implicit conversion of float arrays.
        output_img = np.clip(output_img, 0, 255).astype(np.uint8)
        if not cv2.imwrite(output_path, output_img):
            raise OSError("Could not write image '%s'" % output_path)


def main():
    """Upscale every image found in the ``LR`` folder into ``results``."""
    # Do not change scale_factor
    scale_factor = 4
    # Reduce tile_size(>= 32) and padding_size(>= 2) if OUT OF MEMORY
    tile_size = 512
    padding_size = 4
    # use the GPU when available, otherwise fall back to the CPU
    # (and go out for lunch...)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # scale down images for testing using bicubic
    # ImageMagick work fine:
    # $ convert hires.jpg -interpolate BiCubic -interpolative-resize 25% lowres_for_testing.jpg
    # and put them in LR
    input_folder = 'LR/*'
    output_folder = 'results/'
    # select model RRDB_ESRGAN_x4.pth or RRDB_PSNR_x4.pth
    model_path = 'models/RRDB_ESRGAN_x4.pth'

    print("Initializing ESRGAN using model '%s'" % os.path.basename(model_path), flush=True)
    esrgan = ESRGAN(model_path, device, scale_factor, tile_size, padding_size)

    # cv2.imwrite does not create missing directories
    os.makedirs(output_folder, exist_ok=True)

    for input_path in sorted(glob.glob(input_folder)):
        if not os.path.isfile(input_path):
            continue
        input_name = os.path.basename(input_path)
        print('Upscaling', input_name, flush=True)
        input_name = os.path.splitext(input_name)[0]
        output_path = os.path.join(output_folder,
                                   input_name + '_' + os.path.basename(model_path) + '.png')
        try:
            esrgan.process(input_path, output_path)
        except OSError as err:
            # one unreadable file should not abort the whole batch
            print('Skipping %s: %s' % (input_path, err), flush=True)


if __name__ == '__main__':
    main()
You may want to implement some kind of splitting and joining. Here is my solution:
tiletool.py
This is a partial implementation adapted from https://github.com/idealo/image-super-resolution, thanks to Francesco Cardinale (GitHub: cfrancesco).
And here is how it is called from test.py: