VITA-Group / DeblurGANv2

[ICCV 2019] "DeblurGAN-v2: Deblurring (Orders-of-Magnitude) Faster and Better" by Orest Kupyn, Tetiana Martyniuk, Junru Wu, Zhangyang Wang

predict.py #144

Open CHSLAM opened 2 years ago

CHSLAM commented 2 years ago

I have downloaded fpn_inception.h5. Why does nothing happen when I run python predict.py 0.png?

ulucsahin commented 1 year ago

Because the code is not correct. In predict.py, replace the if __name__ == "__main__" code block with:

if __name__ == "__main__":
    image_path = "path-to-your-image.jpg"
    main(image_path)
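If you'd rather keep passing the image on the command line, as the original command python predict.py 0.png implies, a minimal sketch is to read the path from sys.argv instead of hard-coding it (assuming main from predict.py is otherwise unchanged):

import sys

if __name__ == "__main__":
    # take the image path or glob pattern from the command line,
    # so `python predict.py 0.png` behaves as expected
    if len(sys.argv) < 2:
        sys.exit("usage: python predict.py <image-or-glob>")
    main(sys.argv[1])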
fighting666777 commented 1 year ago

I have downloaded fpn_inception.h5. Why does nothing happen when I run python predict.py 0.png?

I have the same problem, do you know how to solve it?

baselqt commented 9 months ago

Here is my fix: place all the jpg files you want to process in the test_img folder, change path\to\your\test_img below to the actual path, and then run python predict.py with no arguments.

import os
from glob import glob
from typing import Optional

import cv2
import numpy as np
import torch
import yaml
from tqdm import tqdm

from aug import get_normalize
from models.networks import get_generator

class Predictor:
    def __init__(self, weights_path: str, model_name: str = ''):
        with open('config/config.yaml', encoding='utf-8') as cfg:
            config = yaml.load(cfg, Loader=yaml.FullLoader)
        model = get_generator(model_name or config['model'])
        model.load_state_dict(torch.load(weights_path)['model'])
        self.model = model.cuda()
        # inference runs in train mode so the norm layers use the
        # current batch statistics, as in the original repo
        self.model.train(True)
        self.normalize_fn = get_normalize()

    @staticmethod
    def _array_to_batch(x):
        x = np.transpose(x, (2, 0, 1))
        x = np.expand_dims(x, 0)
        return torch.from_numpy(x)

    def _preprocess(self, x: np.ndarray, mask: Optional[np.ndarray]):
        x, _ = self.normalize_fn(x, x)
        if mask is None:
            mask = np.ones_like(x, dtype=np.float32)
        else:
            mask = np.round(mask.astype('float32') / 255)

        h, w, _ = x.shape
        # pad height and width up to the next multiple of 32 so the
        # network's downsampling stages divide evenly
        block_size = 32
        min_height = (h // block_size + 1) * block_size
        min_width = (w // block_size + 1) * block_size

        pad_params = {'mode': 'constant',
                      'constant_values': 0,
                      'pad_width': ((0, min_height - h), (0, min_width - w), (0, 0))
                      }
        x = np.pad(x, **pad_params)
        mask = np.pad(mask, **pad_params)

        return map(self._array_to_batch, (x, mask)), h, w

    @staticmethod
    def _postprocess(x: torch.Tensor) -> np.ndarray:
        x, = x  # unpack the single image from the batch
        x = x.detach().cpu().float().numpy()
        # map the model output from [-1, 1] back to [0, 255]
        x = (np.transpose(x, (1, 2, 0)) + 1) / 2.0 * 255.0
        return x.astype('uint8')

    def __call__(self, img: np.ndarray, mask: Optional[np.ndarray], ignore_mask=True) -> np.ndarray:
        (img, mask), h, w = self._preprocess(img, mask)
        with torch.no_grad():
            inputs = [img.cuda()]
            if not ignore_mask:
                inputs += [mask]
            pred = self.model(*inputs)
        return self._postprocess(pred)[:h, :w, :]
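# Note: the Predictor class can also be used on a single image, e.g.
# (a minimal sketch, assuming fpn_inception.h5 and 0.png sit in the
# working directory):
#
#     predictor = Predictor(weights_path='fpn_inception.h5')
#     img = cv2.cvtColor(cv2.imread('0.png'), cv2.COLOR_BGR2RGB)
#     out = cv2.cvtColor(predictor(img, None), cv2.COLOR_RGB2BGR)
#     cv2.imwrite('0_deblurred.png', out)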

def main(img_pattern: str,
         weights_path='fpn_inception.h5',
         out_dir='submit/',
         side_by_side: bool = False):
    def sorted_glob(pattern):
        return sorted(glob(pattern))

    imgs = sorted_glob(img_pattern)
    names = [os.path.basename(x) for x in imgs]
    predictor = Predictor(weights_path=weights_path)

    os.makedirs(out_dir, exist_ok=True)
    print(f"Total images to process: {len(names)}")

    for name, img_path in tqdm(zip(names, imgs), total=len(names)):
        print(f"Processing: {name}")
        img = cv2.imread(img_path)
        if img is None:  # skip files OpenCV cannot read
            print(f"Skipping unreadable file: {img_path}")
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        pred = predictor(img, None)
        if side_by_side:
            pred = np.hstack((img, pred))
        pred = cv2.cvtColor(pred, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(out_dir, name), pred)

if __name__ == '__main__':
    # main() globs this pattern, so a bare directory matches nothing;
    # include the wildcard (a raw string keeps backslashes literal)
    image_pattern = r"PATH\TO\TEST_IMG\*.jpg"
    main(img_pattern=image_pattern)
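One caveat with the script above: it calls .cuda() unconditionally and loads the checkpoint without map_location, so it crashes on machines without an NVIDIA GPU. A minimal sketch of a CPU fallback (pick_device is a hypothetical helper, not part of the repo; the .cuda() calls inside Predictor would need the same treatment):

import torch

def pick_device() -> torch.device:
    # fall back to CPU when CUDA is unavailable; inference will be
    # much slower but at least runs
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

device = pick_device()
# map_location avoids the "Attempting to deserialize object on a CUDA
# device" error when loading a GPU checkpoint on a CPU-only machine
state = torch.load('fpn_inception.h5', map_location=device)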