yu45020 / Waifu2x

PyTorch on Super Resolution
GNU General Public License v3.0

Waifu2x

Re-implementation of the original waifu2x in PyTorch, with additional super-resolution models. This repo is mainly used to explore interesting super-resolution models. User-friendly tools may not be available for now ><.

Dependencies

Optional: Nvidia GPU. Model inference (fp32 only) can run on CPU alone.

What's New

How to Use

Compare the input image and upscaled image

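# explicit imports for the names used below (the star imports may already provide them)
import torch
from torch import nn
from PIL import Image
from torchvision.transforms.functional import to_tensor
from torchvision.utils import save_image

from utils.prepare_images import *
from Models import *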
model_cran_v2 = CARN_V2(color_channels=3, mid_channels=64, conv=nn.Conv2d,
                        single_conv_size=3, single_conv_group=1,
                        scale=2, activation=nn.LeakyReLU(0.1),
                        SEBlock=True, repeat_blocks=3, atrous=(1, 1, 1))

model_cran_v2 = network_to_half(model_cran_v2)
checkpoint = "model_check_points/CRAN_V2/CARN_model_checkpoint.pt"
model_cran_v2.load_state_dict(torch.load(checkpoint, 'cpu'))
# On a GPU, comment out the next line so the model stays in fp16.
model_cran_v2 = model_cran_v2.float()

demo_img = "input_image.png"
img = Image.open(demo_img).convert("RGB")

# original image, kept as the reference
img_t = to_tensor(img).unsqueeze(0)

# downscale by 2; this is the model input, and its upscaled output is compared with the original
img = img.resize((img.size[0] // 2, img.size[1] // 2), Image.BICUBIC)

# overlapping split
# if the input image is too large, split it into overlapping patches
# details: https://github.com/nagadomi/waifu2x/issues/238
img_splitter = ImageSplitter(seg_size=64, scale_factor=2, boarder_pad_size=3)
img_patches = img_splitter.split_img_tensor(img, scale_method=None, img_pad=0)
with torch.no_grad():
    out = [model_cran_v2(i) for i in img_patches]
img_upscale = img_splitter.merge_img_tensor(out)

final = torch.cat([img_t, img_upscale])
save_image(final, 'out.png', nrow=2)
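
The geometry behind the overlapping split can be pictured with a small, dependency-free sketch. This is not the repo's ImageSplitter: the names `overlapped_boxes` and `crop_in_upscaled_patch` are hypothetical, and the sketch clamps patches at the image edge instead of padding the image. Each patch covers a `seg_size` core plus `pad` extra pixels per side; after upscaling, the padded border is cut away so that only the cores are stitched together.

```python
def overlapped_boxes(height, width, seg_size=64, pad=3):
    """Yield (core, padded) boxes as (top, left, bottom, right) tuples.

    `core` boxes tile the image without overlap; `padded` grows each core
    by `pad` pixels per side, clamped to the image, so neighbours overlap.
    """
    for top in range(0, height, seg_size):
        for left in range(0, width, seg_size):
            bottom = min(top + seg_size, height)
            right = min(left + seg_size, width)
            core = (top, left, bottom, right)
            padded = (max(top - pad, 0), max(left - pad, 0),
                      min(bottom + pad, height), min(right + pad, width))
            yield core, padded


def crop_in_upscaled_patch(core, padded, scale=2):
    """Return the box of the core inside the upscaled padded patch."""
    top = (core[0] - padded[0]) * scale
    left = (core[1] - padded[1]) * scale
    bottom = top + (core[2] - core[0]) * scale
    right = left + (core[3] - core[1]) * scale
    return top, left, bottom, right


# hypothetical 100x130 input, split with the settings used above
boxes = list(overlapped_boxes(100, 130, seg_size=64, pad=3))
```

The overlap gives every convolution near a patch edge some real context, which is why the trimmed cores line up without visible seams.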

Training

If possible, fp16 training is preferred because it is much faster with minimal quality decrease.

A sample training script is available in train.py, but you may need to change some lines.

Image Processing

Original images are all at least 3K x 3K. I downsample them with LANCZOS so that one side is at most 2048, then randomly cut them into 256x256 patches as targets and use 128x128 patches with JPEG noise as inputs. All input patches are at least 14 KB, and they are stored in SQLite as BLOBs. SQLite seems to perform better than the file system for small objects. The H5 file format may not be optimal because of its larger size.
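
Storing patches as BLOBs can be sketched with Python's built-in sqlite3 module. The table layout and the helper names (`open_patch_db`, `add_patch`, `read_patch`) are assumptions for illustration, not the schema used in this repo, and the 14 KB threshold is read here as 14 * 1024 bytes.

```python
import sqlite3

MIN_PATCH_BYTES = 14 * 1024  # size threshold from the text, assumed to mean KiB


def open_patch_db(path=":memory:"):
    """Create (if needed) a table that holds encoded patches as BLOBs."""
    conn = sqlite3.connect(path)
    conn.execute(
        "CREATE TABLE IF NOT EXISTS patches ("
        "id INTEGER PRIMARY KEY, data BLOB NOT NULL)"
    )
    return conn


def add_patch(conn, encoded):
    """Store one encoded patch; return False if it is below the threshold."""
    if len(encoded) < MIN_PATCH_BYTES:
        return False
    conn.execute("INSERT INTO patches (data) VALUES (?)", (encoded,))
    return True


def read_patch(conn, patch_id):
    """Return the stored bytes for one patch id, or None if it is missing."""
    row = conn.execute(
        "SELECT data FROM patches WHERE id = ?", (patch_id,)
    ).fetchone()
    return None if row is None else bytes(row[0])


conn = open_patch_db()
kept = add_patch(conn, b"\x89PNG" + bytes(20_000))  # stand-in for encoded patch bytes
skipped = add_patch(conn, b"\x89PNG" + bytes(100))  # below the threshold, dropped
conn.commit()
```

Reading a patch back gives the raw encoded bytes, which can then be decoded with an image library of choice.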

Although convolutions can take images of any size, the content of the image matters. For real-life images, small patches may still retain variation in color, brightness, etc. within small regions, but in digitally drawn images colors are applied in flat blocks. A small patch may end up showing a single color, and the model has little to learn from it.

For example, the following two plots come from CARN runs with the same settings, including initial parameters. Both training loss and SSIM are lower for 64x64 patches, but that model performs worse at test time than the one trained on 128x128 patches.

[Figures: loss and SSIM plots]

Downsampling methods are chosen uniformly among [PIL.Image.BILINEAR, PIL.Image.BICUBIC, PIL.Image.LANCZOS], so different patches from the same image might be downscaled in different ways.

Image noise comes from JPEG compression only. It is added by re-encoding PNG images as JPEG data with PIL at various quality settings. Noise level 1 means the quality is drawn uniformly from [75, 95]; level 2 means it is drawn uniformly from [50, 75].
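
The degradation described above can be sketched with PIL as follows. The function name `degrade`, the integer sampling of the quality value, and the inclusive range ends are assumptions; the training code in this repo may differ in those details.

```python
import io
import random

from PIL import Image

RESAMPLE_METHODS = [Image.BILINEAR, Image.BICUBIC, Image.LANCZOS]
QUALITY_RANGES = {1: (75, 95), 2: (50, 75)}  # noise level -> JPEG quality range


def degrade(target, noise_level, scale=2, rng=random):
    """Downscale `target` with a randomly chosen filter, then add JPEG noise."""
    method = rng.choice(RESAMPLE_METHODS)
    small = target.resize((target.width // scale, target.height // scale), method)
    low, high = QUALITY_RANGES[noise_level]
    quality = rng.randint(low, high)  # inclusive on both ends (assumption)
    buffer = io.BytesIO()
    small.save(buffer, format="JPEG", quality=quality)  # re-encode in memory
    buffer.seek(0)
    noisy = Image.open(buffer)
    noisy.load()  # decode now so the buffer is no longer needed
    return noisy, quality


target = Image.new("RGB", (256, 256), (200, 120, 40))  # placeholder 256x256 target
noisy, quality = degrade(target, noise_level=1)
```

Because the filter is drawn per call, two patches cut from the same image can end up with different downscaling methods, as noted above.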

Models

Models are tuned and modified with extra features.


Scores

The list will be updated after I add more models.

Images are Twitter icons (PNG) from Key: サマボケ(Summer Pocket). They are cropped into non-overlapping 96x96 patches and downscaled by 2. The downscaled patches are then re-encoded as JPEG with quality from [75, 95]. Scores are PSNR, with MS-SSIM in parentheses.

| Model | Total Parameters | BICUBIC | Random* |
| --- | --- | --- | --- |
| CRAN V2 | 2,149,607 | 34.0985 (0.9924) | 34.0509 (0.9922) |
| DCSCN 12 | 1,889,974 | 31.5358 (0.9851) | 31.1457 (0.9834) |
| Upconv 7 | 552,480 | 31.4566 (0.9788) | 30.9492 (0.9772) |

*uniformly select down scale methods from Image.BICUBIC, Image.BILINEAR, Image.LANCZOS.
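
For reference, PSNR for 8-bit images is 10 * log10(255^2 / MSE). The sketch below computes it over flat pixel sequences; it is a generic definition, not the scoring script used for the table, which may differ in details such as color space or border handling. MS-SSIM is not covered here.

```python
import math


def psnr(reference, estimate, max_value=255.0):
    """Peak signal-to-noise ratio in dB between two equal-length pixel sequences."""
    if len(reference) != len(estimate) or not reference:
        raise ValueError("inputs must be non-empty and of equal length")
    mse = sum((r - e) ** 2 for r, e in zip(reference, estimate)) / len(reference)
    if mse == 0:
        return math.inf  # identical inputs
    return 10.0 * math.log10(max_value ** 2 / mse)


reference = [0, 64, 128, 255]  # made-up pixel values for illustration
estimate = [0, 66, 126, 255]
score = psnr(reference, estimate)
```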

DCSCN

Fast and Accurate Image Super Resolution by Deep CNN with Skip Connection and Network in Network

DCSCN is very interesting because it has relatively quick forward computation, and both the shallow model (layer 8) and the deep model (layer 12) are quick to train. The settings are different from the paper.

A more effective variant adds the upscaled input image to the output of the final convolution. A simple bilinear upscale of the input seems sufficient.
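
The idea of adding an upscaled input to the network output can be sketched in PyTorch as below. This is a toy module, not the DCSCN implementation in this repo: the class name `UpscaleWithBilinearSkip`, the two-layer body, and the use of nn.PixelShuffle for upsampling are all placeholders chosen for brevity.

```python
import torch
from torch import nn
import torch.nn.functional as F


class UpscaleWithBilinearSkip(nn.Module):
    """Toy x2 network: a learned residual plus a bilinear-upscaled input."""

    def __init__(self, channels=3, features=16, scale=2):
        super().__init__()
        self.scale = scale
        self.body = nn.Sequential(
            nn.Conv2d(channels, features, kernel_size=3, padding=1),
            nn.LeakyReLU(0.1),
            # emit scale^2 sub-pixel maps per output channel
            nn.Conv2d(features, channels * scale ** 2, kernel_size=3, padding=1),
            nn.PixelShuffle(scale),
        )

    def forward(self, x):
        residual = self.body(x)
        base = F.interpolate(x, scale_factor=self.scale, mode="bilinear",
                             align_corners=False)
        # the convolutions only need to learn what bilinear upscaling misses
        return residual + base


model = UpscaleWithBilinearSkip()
with torch.no_grad():
    out = model(torch.rand(1, 3, 32, 32))  # random stand-in for an input patch
```

With this layout the network starts from a reasonable bilinear guess and the learned layers only have to supply the missing detail.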

More examples of model configurations can be found in the docs/CARN folder.


Waifu2x Original Models

Models can load waifu2x's pre-trained weights. The function forward_checkpoint sets nn.LeakyReLU to compute in place.

Upconv_7

The original waifu2x model. The PyTorch implementation takes around 5 times longer on large images when run on CPU only. The output images have PSNR and SSIM scores very close to those of images generated by the caffe version, though they are not identical.

Vgg_7

Not tested yet, but it is ready to use.