xinntao / ESRGAN

ECCV18 Workshops - Enhanced SRGAN. Champion of the PIRM Challenge on Perceptual Super-Resolution. The training code is in BasicSR.
https://github.com/xinntao/BasicSR
Apache License 2.0

ESRGAN + VapourSynth? #61

Closed: rlaphoenix closed this issue 5 years ago

rlaphoenix commented 5 years ago
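# --------------------------------------------------
# ESRGAN
# --------------------------------------------------
import cv2
import numpy as np
import torch
import RRDBNet_arch as arch  # RRDBNet_arch.py from the ESRGAN repository

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# RRDBNet(in_nc, out_nc, nf, nb, gc): 3-channel in/out, 64 features, 23 RRDB blocks
model = arch.RRDBNet(3, 3, 64, 23, gc=32)
model.load_state_dict(torch.load(r"H:\2; Models\ad 160k_tf.pth", map_location=device), strict=True)
model.eval()
model = model.to(device)

with torch.no_grad():
    # frame_path is a hypothetical placeholder: how to supply the current frame is the open question below
    f = cv2.imread(frame_path)  # (H, W, 3) uint8, BGR channel order
    # BGR -> RGB, scale to [0, 1], (H, W, C) -> (C, H, W), then add a batch dimension
    img = torch.from_numpy(np.transpose((f * 1.0 / 255)[:, :, [2, 1, 0]], (2, 0, 1))).float()
    img = img.unsqueeze(0).to(device)
    c = model(img).squeeze().float().cpu().clamp_(0, 1).numpy()
# RGB -> BGR, (C, H, W) -> (H, W, C), back to 0..255 as 8-bit
c = (np.transpose(c[[2, 1, 0], :, :], (1, 2, 0)) * 255.0).round().astype(np.uint8)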

This is my current attempt. The cv2.imread() call is the part I'm confused about: torch clearly expects the image in a particular format, and I don't know how to produce that format from a VapourSynth frame for ESRGAN. I'm also unsure how to get the specific frame currently being processed, instead of accidentally running the model on every frame each time a single frame is requested.
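The long `model(...)` line packs several layout conversions together. A minimal NumPy-only sketch of the same round trip is below; the helper names `bgr_hwc_to_nchw` and `nchw_to_bgr_hwc` are hypothetical, and the model step itself is left out so the conversions can be checked on their own. It assumes an OpenCV-style `(H, W, 3)` uint8 BGR image and a model that wants `(1, 3, H, W)` float32 RGB in `[0, 1]`, which is what the code above builds.

```python
import numpy as np


def bgr_hwc_to_nchw(frame: np.ndarray) -> np.ndarray:
    """OpenCV image (H, W, 3) uint8 BGR -> model input (1, 3, H, W) float32 RGB in [0, 1]."""
    rgb = frame[:, :, [2, 1, 0]]                   # reorder channels BGR -> RGB
    scaled = rgb.astype(np.float32) / 255.0        # uint8 0..255 -> float 0..1
    chw = np.transpose(scaled, (2, 0, 1))          # (H, W, C) -> (C, H, W)
    return np.ascontiguousarray(chw[np.newaxis])   # add batch dim -> (1, C, H, W)


def nchw_to_bgr_hwc(output: np.ndarray) -> np.ndarray:
    """Model output (1, 3, H, W) or (3, H, W) float RGB -> (H, W, 3) uint8 BGR."""
    chw = np.clip(output.reshape(output.shape[-3:]), 0.0, 1.0)  # drop batch dim, clamp to [0, 1]
    hwc = np.transpose(chw[[2, 1, 0]], (1, 2, 0))               # RGB -> BGR, (C, H, W) -> (H, W, C)
    return np.round(hwc * 255.0).astype(np.uint8)


# Round trip on a tiny synthetic 2x3 image, with no model in between
frame = np.arange(2 * 3 * 3, dtype=np.uint8).reshape(2, 3, 3)
x = bgr_hwc_to_nchw(frame)
print(x.shape, x.dtype)                  # (1, 3, 2, 3) float32
restored = nchw_to_bgr_hwc(x)
print(np.array_equal(restored, frame))   # True
```

With PyTorch, the input side would then be `torch.from_numpy(bgr_hwc_to_nchw(f)).to(device)`, and the output side `nchw_to_bgr_hwc(model(...).cpu().numpy())`.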

rlaphoenix commented 5 years ago

Managed to create support! 👍 https://github.com/imPRAGMA/VSGAN
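VSGAN is the actual solution here. The sketch below is not its code, only one way the two open questions could be approached. VapourSynth frames are planar, so the three planes of an RGB24 frame stack directly into channel-first `(3, H, W)` order with no transpose and no BGR swap. Wrapping the model in `std.ModifyFrame` means the callback receives only the frame number `n` and frame currently being requested, so the model runs once per requested frame rather than over the whole clip. Assumptions: the VapourSynth R55+ Python API (where `frame[plane]` exposes a plane through the buffer protocol), a clip already converted to RGB24 (for example with `core.resize`), and a hypothetical `run_model` callable that maps a `(1, 3, H, W)` float32 array to its upscaled counterpart.

```python
import numpy as np


def planes_to_nchw(planes) -> np.ndarray:
    """Three planar uint8 arrays (R, G, B), each (H, W) -> (1, 3, H, W) float32 in [0, 1].

    Planar RGB is already channel-first, so no transpose or BGR swap is needed.
    """
    chw = np.stack([np.asarray(p) for p in planes]).astype(np.float32) / 255.0
    return chw[np.newaxis]


def nchw_to_planes(output: np.ndarray) -> np.ndarray:
    """(1, 3, H, W) or (3, H, W) float RGB in [0, 1] -> (3, H, W) uint8 planes."""
    chw = np.clip(output.reshape(output.shape[-3:]), 0.0, 1.0)
    return np.round(chw * 255.0).astype(np.uint8)


def upscale_clip(clip, run_model, scale=4):
    """Hypothetical wiring: apply run_model to each requested frame of an RGB24 clip.

    Assumes the VapourSynth R55+ Python API. run_model is a hypothetical callable
    mapping (1, 3, H, W) float32 to (1, 3, H*scale, W*scale) float values in [0, 1].
    """
    import vapoursynth as vs  # imported lazily so the helpers above work without it

    # Template clip at the upscaled size; ModifyFrame takes its output format from `clip`.
    template = vs.core.std.BlankClip(clip, width=clip.width * scale, height=clip.height * scale)

    def selector(n, f):
        # Called once per requested frame n; f is [source frame, template frame].
        src, dst = f
        out = nchw_to_planes(run_model(planes_to_nchw([src[p] for p in range(3)])))
        fout = dst.copy()  # writable copy of the upscaled template frame
        for p in range(3):
            np.copyto(np.asarray(fout[p]), out[p])
        return fout

    return vs.core.std.ModifyFrame(clip=template, clips=[clip, template], selector=selector)
```

Using a blank template clip keeps the output dimensions consistent with what the callback returns; `run_model` could wrap the ESRGAN model from the first comment inside `torch.no_grad()`.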