satoshiiizuka / siggraphasia2019_remastering

Code for the paper "DeepRemaster: Temporal Source-Reference Attention Networks for Comprehensive Video Enhancement". http://iizuka.cs.tsukuba.ac.jp/projects/remastering/

Can it work with a single image? #17

Open aligoglos opened 3 years ago

aligoglos commented 3 years ago

I wrote a simple script to run the model on a single image, but the result is still gray! Minimal demo:

import glob
import os

import cv2
import numpy as np
import torch
from PIL import Image

import utils  # utils.py from the DeepRemaster repository
from model.remasternet import NetworkC, NetworkR


def main():
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    disable_colorization = False

    # Load the restoration network; map_location lets a GPU checkpoint load on CPU
    state_dict = torch.load('remasternet.pth', map_location=device)
    modelR = NetworkR()
    modelR.load_state_dict(state_dict['modelR'])
    modelR = modelR.to(device).eval()
    if not disable_colorization:
        # Load the colorization network
        modelC = NetworkC()
        modelC.load_state_dict(state_dict['modelC'])
        modelC = modelC.to(device).eval()

    os.makedirs('./results', exist_ok=True)
    paths = sorted(glob.glob('./inputs/*'))
    for path in paths:
        image = cv2.imread(path)  # BGR uint8, or None if the file is unreadable
        if image is None:  # `~(image is None)` is always truthy (~True == -2, ~False == -1)
            continue
        name = os.path.basename(path)  # portable, unlike path.split('\\')
        print(name)
        refimg = cv2.imread(os.path.join('./references', name))  # None if missing
        with torch.no_grad():
            # Luminance input: [H, W] -> [B=1, C=1, T=1, H, W], scaled to [0, 1]
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)  # cv2 loads BGR, not RGB
            frame_l = torch.from_numpy(gray).float() / 255.
            input = frame_l.view(1, 1, 1, gray.shape[0], gray.shape[1]).to(device)
            output_l = modelR(input)  # restored luminance, [B, C, T, H, W]
            if disable_colorization:
                # Save the restored luminance only
                out = output_l[0, 0, 0].clamp(0, 1).cpu().numpy()
                Image.fromarray(np.uint8(out * 255)).save(os.path.join('./results', name))
                continue
            if refimg is None:
                output_ab = modelC(output_l)
            else:
                # Reference: BGR [H, W, 3] -> RGB [B=1, N=1, C=3, H, W], scaled to [0, 1]
                refimg = cv2.cvtColor(refimg, cv2.COLOR_BGR2RGB)
                refimgs = torch.from_numpy(refimg).permute(2, 0, 1).float() / 255.
                refimgs = refimgs.unsqueeze(0).unsqueeze(0).to(device)
                output_ab = modelC(output_l, refimgs)

            # Merge L and ab of the first (only) frame, then convert Lab -> RGB
            out_l = output_l[0, :, 0, :, :].cpu()
            out_c = output_ab[0, :, 0, :, :].cpu()
            output = torch.cat((out_l, out_c), dim=0).numpy().transpose((1, 2, 0))
            output = Image.fromarray(np.uint8(utils.convertLAB2RGB(output) * 255))
            output.save(os.path.join('./results', name))


if __name__ == "__main__":
    main()

Input image: [image]

Output: [image]

Note: the reference image is the same as the input.
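The two array layouts the demo builds are easy to get wrong, so here is a framework-free sketch of them. `to_model_layouts` is a hypothetical helper, not part of the DeepRemaster repository. It assumes the layouts the script arranges by hand: [B, C, T, H, W] for the luminance input and [B, N, C, H, W] for the references. Because `cv2.imread` returns BGR, the reference needs a BGR to RGB swap before it reaches the network, and the gray conversion should use BGR-ordered weights.

```python
import numpy as np

# ITU-R BT.601 luma weights (the ones OpenCV uses for COLOR_BGR2GRAY), in B, G, R order
_BT601_BGR = np.array([0.114, 0.587, 0.299], dtype=np.float32)


def to_model_layouts(bgr_image, bgr_reference=None):
    """Hypothetical helper: build the arrays the demo feeds DeepRemaster.

    Returns (gray, ref):
      gray: float32 [B=1, C=1, T=1, H, W] luminance in [0, 1]
      ref:  float32 [B=1, N=1, C=3, H, W] RGB in [0, 1], or None
    Wrap each with torch.from_numpy(...) before calling the networks.
    """
    img = bgr_image.astype(np.float32) / 255.0
    gray = img @ _BT601_BGR                           # [H, W, 3] -> [H, W]
    gray = gray[np.newaxis, np.newaxis, np.newaxis]   # -> [1, 1, 1, H, W]

    ref = None
    if bgr_reference is not None:
        rgb = bgr_reference[..., ::-1].astype(np.float32) / 255.0  # BGR -> RGB
        # HWC -> CHW, add B and N axes; make contiguous (torch.from_numpy rejects negative strides)
        ref = np.ascontiguousarray(rgb.transpose(2, 0, 1)[np.newaxis, np.newaxis])
    return gray, ref
```

The gray values match `cv2.cvtColor(..., cv2.COLOR_BGR2GRAY) / 255` up to OpenCV's rounding to uint8.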

zhaoyuzhi commented 3 years ago

I also encountered the same issue (the output is grayscale when a single frame is used as input). Have you solved it?

hermosayhl commented 3 years ago

I'm encountering this issue too. Has anyone managed to make it work?

hermosayhl commented 3 years ago

Dr. Zhao (replying to zhaoyuzhi's comment above), have you overcome this problem?

Dawars commented 8 months ago

You have to emulate multiple frames by duplicating the image to make the temporal convolutions work:

input = torch.tile(input, (1, 1, 5, 1, 1))

The network still isn't able to use colors from the reference images if they are significantly different from the gray image.
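Dawars' workaround can be checked without a GPU or checkpoint. This NumPy sketch only verifies the shapes; `tile_single_frame` and `pick_frame` are hypothetical helper names. `np.tile` takes the same per-dimension repetition tuple as `torch.tile`, so in the demo script the equivalent is to tile `input` right before `modelR(input)`; the indexing `output_l[0, :, 0]` and `output_ab[0, :, 0]` then selects the first of the identical frames. The frame count of 5 follows the comment; whether fewer frames suffice is not tested here.

```python
import numpy as np


def tile_single_frame(frame_l, num_frames=5):
    """Repeat one frame along the time axis, like torch.tile(input, (1, 1, 5, 1, 1)).

    frame_l: [B, C, 1, H, W] array; returns [B, C, num_frames, H, W].
    """
    if frame_l.ndim != 5 or frame_l.shape[2] != 1:
        raise ValueError("expected a [B, C, 1, H, W] array")
    return np.tile(frame_l, (1, 1, num_frames, 1, 1))


def pick_frame(output, index=None):
    """Select one frame [B, C, H, W] from a [B, C, T, H, W] network output.

    Assumption: default to the middle frame, which has duplicated neighbours
    on both sides, so temporal padding at the clip boundaries affects it least.
    """
    if index is None:
        index = output.shape[2] // 2
    return output[:, :, index]
```

Since all tiled frames are identical inputs, any output frame is a reasonable pick; the middle one is just a cautious default.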