avalonstrel / GatedConvolution_pytorch

A modified reimplementation, in PyTorch, of the inpainting model from Free-Form Image Inpainting with Gated Convolution [http://jiahuiyu.com/deepfill2/]
Other
434 stars, 77 forks
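The paper named above is built around the gated convolution: two convolutions run over the same input, one producing features and one producing a soft gate, and the output is φ(W_f * x) ⊙ σ(W_g * x), where σ is the sigmoid. Below is a minimal single-channel NumPy sketch of that formula, assuming stride 1, no padding and LeakyReLU as φ (all function names are hypothetical). It is not this repository's implementation, which uses PyTorch layers.

```python
import numpy as np


def conv2d_valid(x, k, b=0.0):
    """2-D cross-correlation (what deep-learning 'conv' layers compute), stride 1, no padding."""
    kh, kw = k.shape
    oh, ow = x.shape[0] - kh + 1, x.shape[1] - kw + 1
    out = np.empty((oh, ow), dtype=np.float64)
    for i in range(oh):
        for j in range(ow):
            out[i, j] = np.sum(x[i:i + kh, j:j + kw] * k) + b
    return out


def leaky_relu(z, slope=0.2):
    # Assumed choice of phi; the paper allows any activation here.
    return np.where(z > 0, z, slope * z)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def gated_conv2d(x, k_feat, k_gate, b_feat=0.0, b_gate=0.0):
    """Gated convolution: phi(W_f * x + b_f) * sigmoid(W_g * x + b_g)."""
    features = leaky_relu(conv2d_valid(x, k_feat, b_feat))
    gate = sigmoid(conv2d_valid(x, k_gate, b_gate))  # learned soft mask in (0, 1) per pixel
    return features * gate


x = np.arange(16, dtype=np.float64).reshape(4, 4)
k_feat = np.ones((3, 3))   # feature kernel: sums each 3x3 window
k_gate = np.zeros((3, 3))  # gating kernel: all zeros, so every gate is sigmoid(0) = 0.5
y = gated_conv2d(x, k_feat, k_gate)
```

With the all-zero gating kernel each output is half of its 3×3 window sum; the top-left window sums to 45, so `y[0, 0]` is 22.5. Because the gate is learned rather than a fixed binary mask, the network can decide per pixel how much of each feature to pass through.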

Result wrong #30

Closed (sjf18 closed this issue 4 years ago)

sjf18 commented 4 years ago

Hi, thanks for your great work. I wrote a demo that uses your model to predict images, but something seems wrong with the result, as shown below: why is the refined output gray?

[Image: 1803151818-00000003]

Here is my demo code. Could you please help me?

```python
import os

import cv2
import numpy as np
import torch
from tqdm import tqdm

# InpaintSANet, load_consistent_state_dict and random_ff_mask are defined in the
# GatedConvolution_pytorch repository; import them from your checkout.

device = torch.device('cpu')  # run inference on the CPU

model_path = './model_logs/offical/latest_ckpt.pth.tar'
nets = torch.load(model_path, map_location=device)  # load GPU-saved tensors onto the CPU
netG_state_dict = nets['netG_state_dict']  # only the generator is needed for inference
netG = InpaintSANet()
load_consistent_state_dict(netG_state_dict, netG)
netG.to(device)
netG.eval()
torch.set_grad_enabled(False)

save_img_dir = 'results/'
os.makedirs(save_img_dir, exist_ok=True)
test_img_dir = 'testdata/'
imgs_list = os.listdir(test_img_dir)
input_shape = (256, 256)


def to_uint8(t):
    """Map a 1x3xHxW tensor in [-1, 1] to an HxWx3 uint8 image, clipping overshoot."""
    img = (127.5 * (t + 1)).permute(0, 2, 3, 1)[0].cpu().numpy()
    return np.clip(img, 0, 255).round().astype(np.uint8)


for imgname in tqdm(imgs_list):
    img = cv2.imread(os.path.join(test_img_dir, imgname))
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_resize = cv2.resize(img_rgb, input_shape)
    mask = random_ff_mask(input_shape)  # HxWx1, 1 marks the hole
    img_tensor = torch.from_numpy(img_resize.astype(np.float32)[np.newaxis]).permute(0, 3, 1, 2)
    mask_tensor = torch.from_numpy(mask.astype(np.float32)[np.newaxis]).permute(0, 3, 1, 2)
    used_img, used_mask = img_tensor.to(device), mask_tensor.to(device)
    used_img = used_img / 127.5 - 1  # scale pixels to [-1, 1]
    corse_img, refine_img = netG(used_img, used_mask)
    ## network output
    cor_img_np = to_uint8(corse_img)
    ref_img_np = to_uint8(refine_img)
    ## complete output: generated pixels inside the hole, original pixels elsewhere
    cor_complete_img_np = to_uint8(corse_img * used_mask + used_img * (1 - used_mask))
    ref_complete_img_np = to_uint8(refine_img * used_mask + used_img * (1 - used_mask))
    ## save images (all panels are uint8 so OpenCV converts and writes them correctly)
    mask_rgb = (255 * np.concatenate((mask,) * 3, -1)).astype(np.uint8)
    first = np.concatenate((img_resize, mask_rgb), 0)
    second = np.concatenate((ref_img_np, cor_img_np), 0)
    third = np.concatenate((ref_complete_img_np, cor_complete_img_np), 0)
    out_img = np.concatenate((first, second, third), 1)
    out_img = cv2.cvtColor(out_img, cv2.COLOR_RGB2BGR)
    cv2.imwrite(os.path.join(save_img_dir, imgname), out_img)
```
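The demo's post-processing does two things: it maps the network output from [-1, 1] back to pixel values, and it pastes generated pixels only into the hole. A minimal NumPy sketch of both steps, assuming HxWx3 float arrays already in [-1, 1] and an HxWx1 mask where 1 marks the hole (the function names `denormalize` and `composite` are hypothetical):

```python
import numpy as np


def denormalize(x):
    """Map values in [-1, 1] back to uint8 pixels in [0, 255], clipping any overshoot."""
    return np.clip(127.5 * (x + 1.0), 0, 255).round().astype(np.uint8)


def composite(generated, original, mask):
    """Keep original pixels where mask == 0 and generated pixels where mask == 1 (the hole)."""
    return generated * mask + original * (1.0 - mask)


original = np.full((2, 2, 3), -1.0)   # a black image in [-1, 1]
generated = np.full((2, 2, 3), 1.0)   # a network output that is white everywhere
mask = np.zeros((2, 2, 1))
mask[0, 0] = 1.0                      # only the top-left pixel is a hole

completed = denormalize(composite(generated, original, mask))
```

Here only `completed[0, 0]` becomes white and every other pixel stays black, because the HxWx1 mask broadcasts across the three color channels. Clipping before the `uint8` cast keeps values slightly outside [-1, 1] from wrapping around.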