chaiNNer-org / spandrel

Spandrel gives your project support for various PyTorch architectures meant for AI Super-Resolution, restoration, and inpainting. Based on the model support implemented in chaiNNer.
MIT License

may I ask how to use the tiling function #113

Open · whybfq opened 6 months ago

whybfq commented 6 months ago
```py
from spandrel import ImageModelDescriptor, ModelLoader
import torch

device = torch.device("cuda:0")

# load a model from disk
model = ModelLoader().load_from_file(r"path/to/model.pth")

# make sure it's an image to image model
assert isinstance(model, ImageModelDescriptor)

# send it to the GPU and put it in inference mode
model.to(device)
model.eval()

# use the model
def process(image: torch.Tensor) -> torch.Tensor:
    with torch.no_grad():
        return model(image)
```

May I ask how to use the tiling function after loading a model such as ESRGAN, or do I need to implement it myself? For example, for a large picture I would tile it into 512×512 tiles with a tile_pad of 10 and a pre_pad of 10.

RunDevelopment commented 6 months ago

Spandrel currently doesn't tile the input image for you, so you have to implement tiling yourself. However, we plan on adding image tiling later.
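In the meantime, a bare-bones version might look like the sketch below. Note that this is not a Spandrel API: `upscale_tiled` and its parameters are illustrative only, and it assumes the model preserves the channel count and upscales by exactly `scale` (no pre-pad or mod-pad handling):

```py
import torch

def upscale_tiled(model, img: torch.Tensor, scale: int, tile: int = 512, pad: int = 10) -> torch.Tensor:
    """Upscale a BCHW tensor tile by tile.

    `pad` is extra context read around each tile and cropped away again after
    inference (the same idea as tile_pad in Real-ESRGAN).
    """
    _, c, h, w = img.shape
    out = img.new_zeros(img.shape[0], c, h * scale, w * scale)
    for y in range(0, h, tile):
        for x in range(0, w, tile):
            # padded input window, clamped to the image borders
            y0, y1 = max(y - pad, 0), min(y + tile + pad, h)
            x0, x1 = max(x - pad, 0), min(x + tile + pad, w)
            with torch.no_grad():
                up = model(img[:, :, y0:y1, x0:x1])
            # crop the padding off the upscaled window and paste only the center
            cy, cx = (y - y0) * scale, (x - x0) * scale
            th = (min(y + tile, h) - y) * scale
            tw = (min(x + tile, w) - x) * scale
            out[:, :, y * scale : y * scale + th, x * scale : x * scale + tw] = (
                up[:, :, cy : cy + th, cx : cx + tw]
            )
    return out
```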

whybfq commented 5 months ago
Tiling code:

```py
import math

import cv2
import numpy as np
import torch
from torch.nn import functional as F

from spandrel import ModelLoader, ImageModelDescriptor
from util import image_inference, get_h_w_c, ModelFile

MAX_FILE_SIZE = 15 * 1024 * 1024  # set the maximum file size

device = torch.device("cuda:0")


class ModelTile:
    """A helper class for upsampling images with RealESRGAN.

    Args:
        scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
        model (nn.Module): The defined network. Default: None.
        tile (int): As too large images result in out-of-GPU-memory issues, this tile
            option will first crop input images into tiles, then process each of them.
            Finally, they will be merged into one image. 0 means do not use tiling.
            Default: 0.
        tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
        pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
        half (bool): Whether to use half precision during inference. Default: False.
    """

    def __init__(self, scale, model=None, tile=0, tile_pad=10, pre_pad=10,
                 device=None, gpu_id=0, half=False):
        # Error: Input type (c10::Half) and bias type (float) should be the same
        self.scale = scale
        self.tile_size = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        self.mod_scale = None
        self.half = half

        # initialize model
        if gpu_id:
            self.device = torch.device(
                f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
        else:
            self.device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
        model.eval()
        self.model = model.to(self.device)
        # if self.half and self.model.supports_half and self.model.supports_bfloat16:
        #     self.model = self.model.half()

    def pre_process(self, img):
        """Pre-process, such as pre-pad and mod pad, so that the images can be divisible."""
        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
        self.img = img.unsqueeze(0).to(self.device)
        # if self.half:
        #     self.img = self.img.half()

        # pre_pad
        if self.pre_pad != 0:
            self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
        # mod pad for divisible borders
        if self.scale == 2:
            self.mod_scale = 2
        elif self.scale == 1:
            self.mod_scale = 4
        if self.mod_scale is not None:
            self.mod_pad_h, self.mod_pad_w = 0, 0
            _, _, h, w = self.img.size()
            if h % self.mod_scale != 0:
                self.mod_pad_h = self.mod_scale - h % self.mod_scale
            if w % self.mod_scale != 0:
                self.mod_pad_w = self.mod_scale - w % self.mod_scale
            self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')

    @torch.no_grad()
    def process(self):
        # model inference
        if self.half:
            # Before passing inputs to the model, convert them to the same data type as the model.
            self.output = self.model.half()(self.img.half())
        else:
            self.output = self.model(self.img)

    def tile_process(self):
        """It will first crop input images to tiles, then process each tile.
        Finally, all the processed tiles are merged into one image.

        Modified from: https://github.com/ata4/esrgan-launcher
        """
        batch, channel, height, width = self.img.shape
        output_height = height * self.scale
        output_width = width * self.scale
        output_shape = (batch, channel, output_height, output_width)

        # start with black image
        self.output = self.img.new_zeros(output_shape)
        tiles_x = math.ceil(width / self.tile_size)
        tiles_y = math.ceil(height / self.tile_size)

        # loop over all tiles
        for y in range(tiles_y):
            for x in range(tiles_x):
                # extract tile from input image
                ofs_x = x * self.tile_size
                ofs_y = y * self.tile_size
                # input tile area on total image
                input_start_x = ofs_x
                input_end_x = min(ofs_x + self.tile_size, width)
                input_start_y = ofs_y
                input_end_y = min(ofs_y + self.tile_size, height)

                # input tile area on total image with padding
                input_start_x_pad = max(input_start_x - self.tile_pad, 0)
                input_end_x_pad = min(input_end_x + self.tile_pad, width)
                input_start_y_pad = max(input_start_y - self.tile_pad, 0)
                input_end_y_pad = min(input_end_y + self.tile_pad, height)

                # input tile dimensions
                input_tile_width = input_end_x - input_start_x
                input_tile_height = input_end_y - input_start_y
                tile_idx = y * tiles_x + x + 1
                input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad,
                                      input_start_x_pad:input_end_x_pad]

                # upscale tile
                try:
                    with torch.no_grad():
                        output_tile = self.model(input_tile)
                except RuntimeError as error:
                    print('Error', error)
                    print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
                    continue  # skip this tile instead of using an undefined output_tile

                # output tile area on total image
                output_start_x = input_start_x * self.scale
                output_end_x = input_end_x * self.scale
                output_start_y = input_start_y * self.scale
                output_end_y = input_end_y * self.scale

                # output tile area without padding
                output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
                output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
                output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
                output_end_y_tile = output_start_y_tile + input_tile_height * self.scale

                # put tile into output image
                self.output[:, :, output_start_y:output_end_y, output_start_x:output_end_x] = \
                    output_tile[:, :, output_start_y_tile:output_end_y_tile,
                                output_start_x_tile:output_end_x_tile]

    def post_process(self):
        # remove extra pad
        if self.mod_scale is not None:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale,
                                      0:w - self.mod_pad_w * self.scale]
        # remove prepad
        if self.pre_pad != 0:
            _, _, h, w = self.output.size()
            self.output = self.output[:, :, 0:h - self.pre_pad * self.scale,
                                      0:w - self.pre_pad * self.scale]
        return self.output

    @torch.no_grad()
    def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
        h_input, w_input = img.shape[0:2]
        # img: numpy
        img = img.astype(np.float32)
        if np.max(img) > 256:  # 16-bit image
            max_range = 65535
            print('\tInput is a 16-bit image')
        else:
            max_range = 255
        img = img / max_range
        if len(img.shape) == 2:  # gray image
            img_mode = 'L'
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        elif img.shape[2] == 4:  # RGBA image with alpha channel
            img_mode = 'RGBA'
            alpha = img[:, :, 3]
            img = img[:, :, 0:3]
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if alpha_upsampler == 'realesrgan':
                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
        else:
            img_mode = 'RGB'
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # ------------------- process image (without the alpha channel) ------------------- #
        self.pre_process(img)
        if self.tile_size > 0:
            self.tile_process()
        else:
            self.process()
        output_img = self.post_process()
        # tensor to image
        # output = tensor_to_img(output_img)
        output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
        if img_mode == 'L':
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)

        # ------------------- process the alpha channel if necessary ------------------- #
        if img_mode == 'RGBA':
            if alpha_upsampler == 'realesrgan':
                self.pre_process(alpha)
                if self.tile_size > 0:
                    self.tile_process()
                else:
                    self.process()
                output_alpha = self.post_process()
                output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
            else:  # use cv2 resize for the alpha channel
                h, w = alpha.shape[0:2]
                output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale),
                                          interpolation=cv2.INTER_LINEAR)

            # merge the alpha channel
            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
            output_img[:, :, 3] = output_alpha

        # ------------------------------ return ------------------------------ #
        if max_range == 65535:  # 16-bit image
            output = (output_img * 65535.0).round().astype(np.uint16)
        else:
            output = (output_img * 255.0).round().astype(np.uint8)

        if outscale is not None and outscale != float(self.scale):
            output = cv2.resize(
                output,
                (int(w_input * outscale), int(h_input * outscale)),
                interpolation=cv2.INTER_LANCZOS4)

        return output, img_mode


def main():
    # loaded_model = ModelLoader(device).load_from_file(r"/Documents/GitHub/esrgan/model/4x-UltraSharp")
    loaded_model = ModelFile("/Documents/GitHub/esrgan/model/net_g_50000.pth").load_model()
    assert isinstance(loaded_model, ImageModelDescriptor)
    # loaded_model = loaded_model.to(device).eval()  # send it to the GPU and put it in inference mode

    # check model info
    print("Model:", loaded_model.model)
    print("Architecture:", loaded_model.architecture)
    print("Purpose:", loaded_model.purpose)
    print("Tags:", loaded_model.tags)
    print("Supports half precision (fp16):", loaded_model.supports_half)
    print("Supports bfloat16 precision:", loaded_model.supports_bfloat16)
    print("Scale:", loaded_model.scale)
    print("Number of input channels:", loaded_model.input_channels)
    print("Number of output channels:", loaded_model.output_channels)
    print("Size requirements:", loaded_model.size_requirements)
    print("Tiling:", loaded_model.tiling)
    print("")

    cv_img = cv2.imread("inputs/8x8.png", cv2.IMREAD_COLOR)
    image_h, image_w, image_c = get_h_w_c(cv_img)
    if loaded_model.input_channels == 1 and image_c == 3:
        cv_img = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY)
        image_c = 1
    assert (
        image_c == loaded_model.input_channels
    ), f"Expected the input image to have {loaded_model.input_channels} channels, but it had {image_c} channels."

    # output, img_mode = bg_upsampler.enhance(cv_img, outscale=loaded_model.scale)
    # RuntimeError: Input and output sizes should be greater than 0, but got input (H: 1, W: 1) output (H: 0, W: 0) for GFPGAN
    # output = tensor_to_img(loaded_model(img_to_tensor(cv_img)))
    output = image_inference(loaded_model, cv_img)
    cv2.imwrite("output.jpg", output)


if __name__ == "__main__":
    main()
```

Just adding a tiling support class like the one above as a reference; the code has already passed testing on Ubuntu 22.04.
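For reference, wiring the class above to a spandrel model could look like this (a sketch only; the model and image paths are placeholders, and it passes the descriptor's underlying torch module since `ModelTile` calls the network directly):

```py
import cv2
from spandrel import ImageModelDescriptor, ModelLoader

descriptor = ModelLoader().load_from_file("path/to/model.pth")  # placeholder path
assert isinstance(descriptor, ImageModelDescriptor)

# 512x512 tiles with tile_pad=10 and pre_pad=10, the values from the question
upsampler = ModelTile(scale=descriptor.scale, model=descriptor.model,
                      tile=512, tile_pad=10, pre_pad=10)

img = cv2.imread("input.png", cv2.IMREAD_COLOR)  # placeholder image
output, img_mode = upsampler.enhance(img, outscale=descriptor.scale)
cv2.imwrite("output.png", output)
```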

joeyballentine commented 5 months ago

Thanks, but we will most likely be taking our tiling code from chaiNNer (of which we are also the main developers) and extracting it into its own package (which spandrel will then use). chaiNNer's code has a lot of nice things in it that have been built up over years, so we'd like to retain all that functionality. This doesn't look like a bad implementation; it's just missing some things.

Sorry for not being clear about that before. It wasn't that we don't have tiling code to use; it's just that we haven't made a package of it yet, as we have been focused on other things. But again, thanks for being willing to share this with us :)

arenasys commented 3 months ago

Would be nice to have this