switchablenorms / CelebAMask-HQ

A large-scale face dataset for face parsing, recognition, generation and editing.
2.05k stars 343 forks source link

You can input a single image and generate a new mask. The code is based on the original 'tester.py' code. #68

Open lizhu8132 opened 3 years ago

lizhu8132 commented 3 years ago

import torch.nn as nn from torchvision.utils import save_image from torchvision import transforms import cv2 import PIL from My_Detector.unet import unet from My_Detector.utils import * from PIL import Image import glob

def transformer(resize, totensor, normalize, centercrop, imsize):
    """Build a torchvision preprocessing pipeline from on/off flags.

    Enabled stages are applied in this order: ``CenterCrop(160)``;
    ``Resize`` to ``(imsize, imsize)`` with nearest-neighbour interpolation;
    ``ToTensor()``; ``Normalize`` with mean 0.5 and std 0.5 on each of three
    channels.

    Returns:
        A ``transforms.Compose`` of the enabled stages.
    """
    options = []
    if centercrop:
        options.append(transforms.CenterCrop(160))
    if resize:
        options.append(transforms.Resize((imsize, imsize), interpolation=PIL.Image.NEAREST))
    if totensor:
        options.append(transforms.ToTensor())
    if normalize:
        options.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    transform = transforms.Compose(options)

    return transform

def trans_square(image):
    """Pad a PIL image to a square, centring it on a mid-grey canvas.

    The image is converted to RGB and pasted onto a
    ``max(w, h) x max(w, h)`` canvas filled with (127, 127, 127). It is
    centred along its shorter side.

    Returns:
        The new square RGB image.
    """
    image = image.convert('RGB')
    w, h = image.size
    background = Image.new('RGB', size=(max(w, h), max(w, h)), color=(127, 127, 127))
    # Offset that centres the shorter dimension on the square canvas.
    length = int(abs(w - h) // 2)
    box = (length, 0) if w < h else (0, length)
    background.paste(image, box)
    return background

class My_Tester(object):
    """Segment a single face image with the pretrained U-Net and save the masks.

    ``test`` reads the lexicographically first ``.jpg`` in ``./test_img`` and
    writes:

    - ``./test_results/test.png`` and ``./samples/Mask.png``: the first item
      of ``generate_label_plain``'s output.
    - ``./test_color_visualize/test.png``: the first item of
      ``generate_label``'s output.
    - ``./samples/Image.jpg``: the square-padded input resized to 1024x1024.
    """

    def __init__(self):
        # Must be ``__init__`` (not ``init``) so Python actually runs it.
        self.parallel = False
        self.test_label_path = './test_results'
        self.test_color_label_path = './test_color_visualize'
        self.build_model()

    @staticmethod
    def _find_input_image(pattern):
        """Return the lexicographically first file matching glob *pattern*.

        Raises:
            FileNotFoundError: if nothing matches. The original code raised a
                bare ``IndexError`` from ``[0]`` in that case.
        """
        # glob order is filesystem-dependent; sort so the choice is reproducible.
        matches = sorted(glob.glob(pattern))
        if not matches:
            raise FileNotFoundError('no input image matches %r' % pattern)
        return matches[0]

    def test(self):
        """Load the weights, segment one image and write the output files."""
        # Imported explicitly: previously these were only reachable through
        # ``from My_Detector.utils import *``.
        import os
        import torch

        image_size = 512
        transform = transformer(True, True, True, False, image_size)
        make_folder(self.test_label_path, '')
        make_folder(self.test_color_label_path, '')
        make_folder('./samples', '')
        self.G.load_state_dict(torch.load('./models/model.pth'))
        self.G.eval()

        path = self._find_input_image(r'./test_img/*.jpg')
        # trans_square returns a new image, so the source file can be closed now.
        with Image.open(path) as raw:
            img = trans_square(raw)
        new_img = img.resize((1024, 1024), PIL.Image.BILINEAR)

        batch = transform(img).unsqueeze(0).cuda()

        # Inference only: skip autograd bookkeeping to save GPU memory.
        with torch.no_grad():
            labels_predict = self.G(batch)
            labels_predict_plain = generate_label_plain(labels_predict, image_size)
            labels_predict_color = generate_label(labels_predict, image_size)

        cv2.imwrite(os.path.join(self.test_label_path, 'test.png'), labels_predict_plain[0])
        save_image(labels_predict_color[0], os.path.join(self.test_color_label_path, 'test.png'))

        new_img.save('./samples/Image.jpg')
        cv2.imwrite('./samples/Mask.png', labels_predict_plain[0])

    def build_model(self):
        """Create the U-Net on the GPU, optionally wrapped in ``nn.DataParallel``."""
        self.G = unet().cuda()
        if self.parallel:
            self.G = nn.DataParallel(self.G)

if __name__ == '__main__':
    # Only run when executed as a script, not when imported.
    t = My_Tester()
    t.test()

opptimus commented 3 years ago

Can you release the code as a separate Python (.py) file? Thanks! @helindemeng

jakeyahn commented 2 years ago

How did you connect this to main()?

liudan193 commented 6 months ago

Just create a .py file under "face_parsing" and copy the code below:

# https://github.com/switchablenorms/CelebAMask-HQ/issues/68

import os
import re
# torch is used directly (torch.load / torch.no_grad); import it explicitly
# instead of relying on `from utils import *` to re-export it.
import torch
import torch.nn as nn
from torchvision.utils import save_image
from torchvision import transforms
import cv2
import PIL
from tqdm import tqdm
from unet import unet
# Order below is kept as-is: names imported after the star import
# (PIL's Image, glob) must keep taking precedence over anything utils exports.
from utils import *
from PIL import Image
import glob

def transformer(resize, totensor, normalize, centercrop, imsize):
    """Assemble a ``transforms.Compose`` pipeline from on/off flags.

    Enabled stages are applied in this fixed order:

    1. ``centercrop``: ``CenterCrop(160)``.
    2. ``resize``: ``Resize((imsize, imsize))`` with nearest-neighbour sampling.
    3. ``totensor``: ``ToTensor()``.
    4. ``normalize``: ``Normalize`` with mean 0.5 and std 0.5 on each of three
       channels.

    Stages whose flag is falsy are skipped; their transform objects are never
    constructed.
    """
    half = (0.5, 0.5, 0.5)
    # (flag, factory) pairs; factories are called only for enabled stages.
    stages = (
        (centercrop, lambda: transforms.CenterCrop(160)),
        (resize, lambda: transforms.Resize((imsize, imsize), interpolation=PIL.Image.NEAREST)),
        (totensor, transforms.ToTensor),
        (normalize, lambda: transforms.Normalize(half, half)),
    )
    return transforms.Compose([build() for enabled, build in stages if enabled])

def create_dir(filepath):
    """Ensure the parent directory of *filepath* exists.

    Does nothing when *filepath* has no directory component (e.g.
    ``'mask.png'``). The original called ``os.makedirs('')`` in that case,
    which raises ``FileNotFoundError``. Safe to call repeatedly.

    Args:
        filepath: path to a file; only its directory part is created.
    """
    directory = os.path.dirname(filepath)
    if directory:
        # exist_ok avoids the exists()/makedirs() race when another process
        # creates the directory in between.
        os.makedirs(directory, exist_ok=True)

def trans_square(image):
    """Pad a PIL image to a square, centring it on a mid-grey canvas.

    The input is converted to RGB and pasted onto a new
    ``max(w, h) x max(w, h)`` RGB canvas filled with (127, 127, 127). It is
    offset so that it is centred along its shorter side.

    Returns:
        The new square RGB image.
    """
    rgb = image.convert('RGB')
    width, height = rgb.size
    side = max(width, height)
    offset = int(abs(width - height) // 2)
    if width < height:
        corner = (offset, 0)  # tall image: shift right
    else:
        corner = (0, offset)  # wide (or square) image: shift down
    canvas = Image.new('RGB', size=(side, side), color=(127, 127, 127))
    canvas.paste(rgb, corner)
    return canvas

class My_Tester(object):
    """Run the pretrained face-parsing U-Net over a folder of ``.jpg`` images.

    For every input ``<stem>.jpg`` three files are written next to it:

    - ``<stem>_resized.jpg``: the square-padded input resized to 1024x1024.
    - ``<stem>_mask.png``: the first item of ``generate_label_plain``'s
      output, written with ``cv2.imwrite``.
    - ``<stem>_mask_color.png``: the first item of ``generate_label``'s
      output, written with ``save_image``.
    """

    # Suffixes that replace an input's extension to name its outputs, in the
    # order (resized image, mask, colour mask).
    _OUTPUT_SUFFIXES = ('_resized.jpg', '_mask.png', '_mask_color.png')

    def __init__(self):
        self.parallel = False
        self.test_label_path = './test_results'
        self.test_color_label_path = './test_color_visualize'
        self.build_model()

    @classmethod
    def _output_paths(cls, path):
        """Return the (resized image, mask, colour mask) paths for input *path*.

        Uses ``os.path.splitext`` rather than ``str.replace('.jpg', ...)``.
        That way a ``.jpg`` inside a directory name is left alone. It also
        means an upper-case ``.JPG`` (which ``glob`` matches on Windows) can
        never map an output onto the input file itself.
        """
        stem = os.path.splitext(path)[0]
        return tuple(stem + suffix for suffix in cls._OUTPUT_SUFFIXES)

    def test(self, image_dir='xxx'):
        """Segment every ``*.jpg`` directly inside *image_dir*.

        Args:
            image_dir: folder to scan (non-recursively). Defaults to the
                ``'xxx'`` placeholder of the original script; pass your own
                image folder.
        """
        image_size = 512
        transform = transformer(True, True, True, False, image_size)
        # Kept from the original tester.py; the loop below does not write here.
        make_folder(self.test_label_path, '')
        make_folder(self.test_color_label_path, '')
        make_folder('./samples', '')
        self.G.load_state_dict(torch.load('./models/parsenet/model.pth'))
        self.G.eval()

        # Escape the folder so glob metacharacters in it are taken literally.
        pattern = os.path.join(glob.escape(image_dir), '*.jpg')
        resized_suffix = self._OUTPUT_SUFFIXES[0]
        # Outputs are written beside the inputs, so on a re-run the previous
        # run's *_resized.jpg files would otherwise be processed as new inputs.
        # Sorting makes the processing order reproducible.
        paths = sorted(p for p in glob.glob(pattern) if not p.endswith(resized_suffix))

        for path in tqdm(paths):
            # trans_square returns a new image, so the source file can be closed now.
            with Image.open(path) as raw:
                img = trans_square(raw)
            new_img = img.resize((1024, 1024), PIL.Image.BILINEAR)
            batch = transform(img).unsqueeze(0).cuda()

            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                labels_predict = self.G(batch)
                labels_predict_plain = generate_label_plain(labels_predict, image_size)
                labels_predict_color = generate_label(labels_predict, image_size)

            # Outputs go into the input's own directory, which exists because
            # glob just listed it, so no directory creation is needed.
            save_path_image, save_path_mask, save_path_mask_color = self._output_paths(path)
            new_img.save(save_path_image)
            # cv2.imwrite signals failure through its return value, not an exception.
            if not cv2.imwrite(save_path_mask, labels_predict_plain[0]):
                tqdm.write('warning: cv2.imwrite failed for %s' % save_path_mask)
            save_image(labels_predict_color[0], save_path_mask_color)

    def build_model(self):
        """Create the U-Net on the GPU, optionally wrapped in ``nn.DataParallel``."""
        self.G = unet().cuda()
        if self.parallel:
            self.G = nn.DataParallel(self.G)

if __name__ == '__main__':
    # Script entry point: construct the tester (which builds the network) and run it.
    My_Tester().test()