Status: Open — opened by lizhu8132 3 years ago.
Could you release the code as a separate Python (.py) file? Thanks! @helindemeng
How did you connect this to `main()`?
Just create a `.py` file under the `face_parsing` directory and copy the code below into it:
# https://github.com/switchablenorms/CelebAMask-HQ/issues/68
import os
import re
import torch
import torch.nn as nn
from torchvision.utils import save_image
from torchvision import transforms
import cv2
import PIL
from tqdm import tqdm
from unet import unet
from utils import *
# NOTE: the imports below intentionally follow the star import so they win
# over any same-named bindings that ``utils`` might export.
from PIL import Image
import glob
def transformer(resize, totensor, normalize, centercrop, imsize):
    """Compose the preprocessing steps selected by the boolean flags.

    Enabled steps are applied in this fixed order:
      1. ``centercrop`` -- centre crop to 160x160 (fixed size, ignores ``imsize``);
      2. ``resize``     -- resize to ``imsize`` x ``imsize`` with nearest-neighbour;
      3. ``totensor``   -- convert the PIL image to a tensor;
      4. ``normalize``  -- normalise 3 channels with mean 0.5 and std 0.5.

    Returns:
        A ``transforms.Compose`` holding only the enabled steps (possibly none).
    """
    # Pair each flag with a factory so that only enabled steps are constructed.
    pipeline = (
        (centercrop, lambda: transforms.CenterCrop(160)),
        (resize, lambda: transforms.Resize((imsize, imsize),
                                           interpolation=PIL.Image.NEAREST)),
        (totensor, lambda: transforms.ToTensor()),
        (normalize, lambda: transforms.Normalize((0.5, 0.5, 0.5),
                                                 (0.5, 0.5, 0.5))),
    )
    return transforms.Compose([build() for enabled, build in pipeline if enabled])
def create_dir(filepath):
    """Ensure the parent directory of ``filepath`` exists.

    Args:
        filepath: path of a file about to be written; only its directory
            component is created (intermediate directories included).

    A bare filename has no directory component, so nothing needs creating;
    that case is a no-op instead of ``os.makedirs('')`` raising
    ``FileNotFoundError``.
    """
    directory = os.path.dirname(filepath)
    if directory:
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(directory, exist_ok=True)
def trans_square(image, fill=(127, 127, 127)):
    """Pad a PIL image to a square RGB canvas with the original centred.

    The image is converted to RGB and pasted onto a new ``max(w, h)`` square
    canvas filled with ``fill`` (mid-grey by default). It is offset by
    ``abs(w - h) // 2`` along its shorter axis, so it sits centred (to within
    one pixel when the difference is odd).

    Args:
        image: a PIL image of any mode; it is converted to RGB.
        fill: RGB colour used for the padding.

    Returns:
        A new square RGB PIL image.
    """
    image = image.convert('RGB')
    w, h = image.size
    side = max(w, h)
    background = Image.new('RGB', size=(side, side), color=fill)
    offset = abs(w - h) // 2
    # Shift horizontally for portrait images, vertically otherwise.
    box = (offset, 0) if w < h else (0, offset)
    background.paste(image, box)
    return background
class My_Tester(object):
    """Run the face-parsing U-Net (on CUDA) over a set of JPEG images.

    For each input ``<stem>.<ext>`` three files are written next to it:
    ``<stem>_resized.jpg`` (square-padded input resized to 1024x1024),
    ``<stem>_mask.png`` (plain label map, written with ``cv2.imwrite``) and
    ``<stem>_mask_color.png`` (colour label map, written with ``save_image``).
    """

    def __init__(self):
        self.parallel = False  # wrap the network in nn.DataParallel when True
        self.test_label_path = './test_results'
        self.test_color_label_path = './test_color_visualize'
        self.build_model()

    def test(self, image_glob=r'xxx/*.jpg',
             model_path='./models/parsenet/model.pth'):
        """Load the checkpoint and write parsing masks for every matched image.

        Args:
            image_glob: glob pattern selecting input images. The default
                ``'xxx/*.jpg'`` is a placeholder for your image folder.
            model_path: path of the saved ``state_dict`` loaded into ``self.G``.
        """
        image_size = 512
        transform = transformer(True, True, True, False, image_size)
        # These folders are created, but the loop below writes its outputs
        # next to each input image instead.
        make_folder(self.test_label_path, '')
        make_folder(self.test_color_label_path, '')
        make_folder('./samples', '')
        self.G.load_state_dict(torch.load(model_path))
        self.G.eval()
        paths = glob.glob(image_glob)
        # Pure inference: without no_grad autograd would record a graph for
        # every forward pass and keep its activations alive.
        with torch.no_grad():
            for path in tqdm(paths):
                # Close the file handle promptly; trans_square returns a new
                # (converted) image that does not depend on it.
                with Image.open(path) as raw:
                    img = trans_square(raw)
                new_img = img.resize((1024, 1024), PIL.Image.BILINEAR)
                img = transform(img).unsqueeze(0).cuda()
                labels_predict = self.G(img)
                # Helpers come from ``utils``; presumably they turn the network
                # output into label maps at ``image_size`` -- verify there.
                labels_predict_plain = generate_label_plain(labels_predict, image_size)
                labels_predict_color = generate_label(labels_predict, image_size)

                # Strip only the real extension. A plain str.replace(".jpg", ...)
                # would also rewrite ".jpg" inside directory names and, for
                # non-.jpg inputs, leave the path unchanged and overwrite the input.
                stem = os.path.splitext(path)[0]
                # original image
                new_img.save(stem + '_resized.jpg')
                # mask
                save_path_mask = stem + '_mask.png'
                create_dir(save_path_mask)
                cv2.imwrite(save_path_mask, labels_predict_plain[0])
                # mask color
                save_path_mask_color = stem + '_mask_color.png'
                create_dir(save_path_mask_color)
                save_image(labels_predict_color[0], save_path_mask_color)

    def build_model(self):
        """Create the U-Net on CUDA, optionally wrapped in nn.DataParallel."""
        self.G = unet().cuda()
        if self.parallel:
            self.G = nn.DataParallel(self.G)
if __name__ == '__main__':
    # Script entry point: build the tester (which builds the model) and run it.
    My_Tester().test()
import torch.nn as nn from torchvision.utils import save_image from torchvision import transforms import cv2 import PIL from My_Detector.unet import unet from My_Detector.utils import * from PIL import Image import glob
def transformer(resize, totensor, normalize, centercrop, imsize):
    """Build a ``transforms.Compose`` from the enabled preprocessing steps.

    Enabled steps are appended in this order: centre crop to 160x160, resize
    to ``imsize`` x ``imsize`` with nearest-neighbour sampling, conversion to
    a tensor, then 3-channel normalisation with mean 0.5 and std 0.5.

    Returns:
        The composed transform (possibly containing no steps).
    """
    options = []
    if centercrop:
        options.append(transforms.CenterCrop(160))
    if resize:
        options.append(transforms.Resize((imsize, imsize), interpolation=PIL.Image.NEAREST))
    if totensor:
        options.append(transforms.ToTensor())
    if normalize:
        options.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    transform = transforms.Compose(options)
    # Without this return the function yields None and callers cannot apply it.
    return transform
def trans_square(image, fill=(127, 127, 127)):
    """Pad a PIL image to a square RGB canvas with the original centred.

    The image is converted to RGB and pasted onto a new ``max(w, h)`` square
    canvas filled with ``fill`` (mid-grey by default). It is offset by
    ``abs(w - h) // 2`` along its shorter axis, so it sits centred (to within
    one pixel when the difference is odd).

    Args:
        image: a PIL image of any mode; it is converted to RGB.
        fill: RGB colour used for the padding.

    Returns:
        A new square RGB PIL image.
    """
    image = image.convert('RGB')
    w, h = image.size
    side = max(w, h)
    background = Image.new('RGB', size=(side, side), color=fill)
    offset = abs(w - h) // 2
    # Shift horizontally for portrait images, vertically otherwise.
    box = (offset, 0) if w < h else (0, offset)
    background.paste(image, box)
    return background
class My_Tester(object):
    """Holds the configuration used by the face-parsing tester.

    Only the constructor is visible in this copy of the snippet; it records
    whether to use DataParallel and where result folders live.
    """

    def __init__(self):
        # The pasted snippet read ``def init``: Markdown stripped the double
        # underscores, so the constructor never ran and no attributes were set.
        self.parallel = False
        self.test_label_path = './test_results'
        self.test_color_label_path = './test_color_visualize'

    # Keep the old (mangled) method name callable for backward compatibility.
    init = __init__
# Script entry point. The pasted text read ``if name == 'main'`` (Markdown ate
# the double underscores), which raises NameError as soon as the module loads.
# NOTE(review): requires My_Tester to provide a ``test`` method.
if __name__ == '__main__':
    t = My_Tester()
    t.test()