XingangPan / GAN2Shape

Code for GAN2Shape (ICLR2021 oral)
https://arxiv.org/abs/2011.00844
MIT License

Rendered image differs from the input one #5

Closed: jackytu256 closed this issue 3 years ago

jackytu256 commented 3 years ago

Hi, thanks for releasing this repo.

I tried to run inference to render CelebA faces from different viewpoints, using the outputs of netView and the texture as in the forward_step1 function. The problem is that the generated image differs from the input one. Do I need to preprocess the input image (e.g. a specific alignment), or does my code need to be modified? Any suggestions to help me fix it would be appreciated. Thanks in advance.

import torch 
from gan2shape import networks
from torchvision import transforms
from PIL import Image
from gan2shape.renderer.renderer_infer import Renderer
from torchvision.utils import save_image
import math
import torch.nn as nn
import torch.nn.functional as F
from dlib_align import align
import cv2

torch.manual_seed(0)
pretrain_pth_path = "/workspace/GAN2Shape/checkpoints/gan2shape/celeba_pretrain.pth"
image_path = "/workspace/GAN2Shape/data/celeba/000000.png"
image_name = image_path.split("/")[-1]
image_size = 128
ckpt = torch.load(pretrain_pth_path, map_location="cpu")
max_depth = 1.1
min_depth = 0.9
gn_base = 8 if image_size >= 128 else 16
nf = max(4096 // image_size, 16)
flip1 = [False, True, True, True]
flip3 = [True, True, True, True]
mode = 'step3' 
depth_rescaler = lambda d: (1+d)/2 *max_depth + (1-d)/2 *min_depth
add_mean_V = True
xyz_rotation_range =60
xy_translation_range = 0.1
z_translation_range = 0
add_mean_L = True
use_mask = False
border_depth = (0.7*max_depth + 0.3*min_depth)

def init_VL_sampler():
    from torch.distributions.multivariate_normal import MultivariateNormal as MVN
    view_mvn_path = 'checkpoints/view_light/celeba_view_mvn.pth'
    light_mvn_path = 'checkpoints/view_light/celeba_light_mvn.pth'
    view_mvn = torch.load(view_mvn_path)
    light_mvn = torch.load(light_mvn_path)
    return (view_mvn['mean'].cuda(), light_mvn['mean'].cuda(),
            MVN(view_mvn['mean'].cuda(), view_mvn['cov'].cuda()),
            MVN(light_mvn['mean'].cuda(), light_mvn['cov'].cuda()))

netDepth = networks.EDDeconv(cin=3, cout=1, size=image_size, nf=nf, gn_base=gn_base, zdim=256, activation=None)
netDepth.load_state_dict(ckpt["netD"])
netDepth.cuda()
netDepth.eval()
netalbedo = networks.EDDeconv(cin=3, cout=3, size=image_size, nf=nf, gn_base=gn_base, zdim=256)
netalbedo.load_state_dict(ckpt["netA"])
netalbedo.cuda()
netalbedo.eval()
netView = networks.Encoder(cin=3, cout=6, size=image_size, nf=nf)
netView.load_state_dict(ckpt["netV"])
netView.cuda()
netView.eval()
netLight = networks.Encoder(cin=3, cout=4, size=image_size, nf=nf)
netLight.load_state_dict(ckpt["netL"])
netLight.cuda()
netLight.eval()
# renderer = renderer_infer
renderer = Renderer()
view_mean,light_mean,view_mvn,light_mvn = init_VL_sampler()

transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
])
image = Image.open(image_path)
image = transform(image).unsqueeze(0).cuda()
image = image * 2 - 1
image.cuda()

b = 1
h, w = image_size, image_size
## predict depth
depth_raw = netDepth(image).squeeze(1)  # 1xHxW
depth = depth_raw - depth_raw.view(1,-1).mean(1).view(1,1,1)
depth = depth.tanh()
depth = depth_rescaler(depth)
depth_border = torch.zeros(1,h,w-4).cuda()
depth_border = nn.functional.pad(depth_border, (2,2), mode='constant', value=1.02)
depth = depth*(1-depth_border) + depth_border *border_depth
if (flip3 and mode == 'step3') or flip1:
    depth = torch.cat([depth, depth.flip(2)], 0)

## predict viewpoint transformation
view = netView(image)
if add_mean_V:
    view = view + view_mean.unsqueeze(0)

view_trans = torch.cat([
    view[:,:3] *math.pi/180 * xyz_rotation_range,
    view[:,3:5] * xy_translation_range,
    view[:,5:] *z_translation_range], 1)

if flip3 and mode == 'step3':
    view_trans = view.repeat(2,1)
renderer.set_transform_matrices(view_trans)

## predict albedo
albedo = netalbedo(image)  # 1x3xHxW
if (flip3 and mode == 'step3') or flip1:
    albedo = torch.cat([albedo, albedo.flip(3)], 0)  # flip

## predict lighting
light = netLight(image)  # Bx4
if add_mean_L:
    light = light + light_mean.unsqueeze(0)
if (flip3 and mode == 'step3') or flip1:
    light = light.repeat(2,1)  # Bx4
light_a = light[:,:1] /2+0.5  # ambience term
light_b = light[:,1:2] /2+0.5  # diffuse term
light_dxy = light[:,2:]
light_d = torch.cat([light_dxy, torch.ones(light.size(0),1).cuda()], 1)
light_d = light_d / ((light_d**2).sum(1, keepdim=True))**0.5  # diffuse light direction

## shading
normal = renderer.get_normal_from_depth(depth)
diffuse_shading = (normal * light_d.view(-1,1,1,3)).sum(3).clamp(min=0).unsqueeze(1)
shading = light_a.view(-1,1,1,1) + light_b.view(-1,1,1,1)*diffuse_shading
texture = (albedo/2+0.5) * shading *2-1

recon_depth = renderer.warp_canon_depth(depth)
recon_normal = renderer.get_normal_from_depth(recon_depth)
save_image(recon_depth, f'./recon_depth.png')

grid_2d_from_canon = renderer.get_inv_warped_2d_grid(recon_depth)
margin = (max_depth - min_depth) /2
recon_im_mask = (recon_depth < max_depth+margin).float()  # invalid border pixels have been clamped at max_depth+margin
if (flip3 and mode == 'step3') or flip1:
    recon_im_mask = recon_im_mask[:b] * recon_im_mask[b:]
    recon_im_mask = recon_im_mask.repeat(2,1,1)
recon_im_mask = recon_im_mask.unsqueeze(1).detach()
recon_im = nn.functional.grid_sample(texture, grid_2d_from_canon, mode='bilinear').clamp(min=-1, max=1)
print("recon_im:",recon_im.shape)
save_image(recon_im[0], f'./recon_img.png', nrow=1)
save_image(recon_im[1], f'./recon_img1.png', nrow=1)
with torch.no_grad():
    _depth, _texture, _view = depth[0,None], texture[0,None], view_trans[0,None]
    num_p, num_y = 5, 9  # number of pitch and yaw angles to sample
    max_y = 70
    maxr = [20, max_y]
    # sample viewpoints
    im_rotate = renderer.render_view(_texture, _depth, maxr=maxr, nsample=[num_p,num_y])[0]
    im_rotate = im_rotate/2+0.5
    for i in range(im_rotate.size(0)):
        save_image(im_rotate[i,None], f'./{image_name}_im_rotate_stage3_{i:03}.png', nrow=1)
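The depth, lighting and shading lines in the script above reduce to simple per-pixel arithmetic. Below is a minimal plain-Python sketch (no torch, one pixel at a time) that reproduces the script's formulas, assuming raw network outputs in [-1, 1] and a unit surface normal. The helper names are hypothetical and not part of GAN2Shape.

```python
import math

MAX_DEPTH, MIN_DEPTH = 1.1, 0.9  # same values as max_depth / min_depth in the script


def rescale_depth(d):
    """Map a tanh-squashed depth d in [-1, 1] to [MIN_DEPTH, MAX_DEPTH] (as depth_rescaler does)."""
    return (1 + d) / 2 * MAX_DEPTH + (1 - d) / 2 * MIN_DEPTH


def light_direction(dx, dy):
    """Append z = 1 to the predicted (dx, dy) and normalise to a unit vector (as light_d)."""
    norm = math.sqrt(dx * dx + dy * dy + 1.0)
    return (dx / norm, dy / norm, 1.0 / norm)


def shade_pixel(albedo, normal, light):
    """Lambertian shading of one pixel, following the script.

    albedo: albedo value in [-1, 1] (network output range)
    normal: unit surface normal (nx, ny, nz)
    light:  raw 4-vector from netLight, [a, b, dx, dy]
    Returns the shaded texture value. Like the script, it is not clamped,
    so it can fall outside [-1, 1].
    """
    light_a = light[0] / 2 + 0.5  # ambience term
    light_b = light[1] / 2 + 0.5  # diffuse term
    l = light_direction(light[2], light[3])
    diffuse = max(0.0, sum(n * c for n, c in zip(normal, l)))  # clamp(min=0)
    shading = light_a + light_b * diffuse
    return (albedo / 2 + 0.5) * shading * 2 - 1


print(rescale_depth(0.0))
print(shade_pixel(0.5, (0.0, 0.0, 1.0), [0.0, 0.0, 0.2, -0.1]))
```

The lack of clamping is deliberate in this sketch: in the script, out-of-range values are only clipped later, when recon_im is clamped to [-1, 1] after grid_sample.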
XingangPan commented 3 years ago

@jackytu256
(1) The save_image function would modify recon_depth, which would cause problems, so you may use save_image(recon_depth.clone(), f'./recon_depth.png').
(2) Use view_trans = view_trans.repeat(2,1) instead of view_trans = view.repeat(2,1).
(3) Normalize recon_im to [0, 1] before saving it, e.g. save_image(recon_im[0]/2+0.5, f'./recon_img.png', nrow=1).
I believe these revisions should solve the problem.
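Fixes (2) and (3) can be checked with plain numbers. Below is a minimal plain-Python sketch, assuming the script's settings (xyz_rotation_range = 60, xy_translation_range = 0.1, z_translation_range = 0) and a made-up raw netView output. The helper names are hypothetical. Repeating view instead of view_trans passes the renderer unscaled values, e.g. an xy translation of 0.5 instead of 0.05 and a non-zero z translation.

```python
import math

XYZ_ROTATION_RANGE = 60  # degrees, as xyz_rotation_range in the script
XY_TRANSLATION_RANGE = 0.1
Z_TRANSLATION_RANGE = 0


def scale_view(view):
    """Turn a raw 6-d netView output into the view_trans the renderer expects.

    view[0:3] -> rotations in radians, view[3:5] -> xy translation, view[5:] -> z translation.
    """
    rot = [v * math.pi / 180 * XYZ_ROTATION_RANGE for v in view[:3]]
    xy = [v * XY_TRANSLATION_RANGE for v in view[3:5]]
    z = [v * Z_TRANSLATION_RANGE for v in view[5:]]
    return rot + xy + z


def repeat_for_flip(params):
    """Duplicate one parameter vector for the [original, flipped] batch, like .repeat(2, 1)."""
    return [list(params), list(params)]


def to_unit_range(x):
    """Map a value from the network's [-1, 1] range to [0, 1] before saving (fix 3)."""
    return x / 2 + 0.5


raw_view = [0.1, -0.2, 0.0, 0.5, -0.5, 0.3]    # hypothetical netView output
wrong = repeat_for_flip(raw_view)               # view.repeat(2, 1): unscaled values
right = repeat_for_flip(scale_view(raw_view))   # view_trans.repeat(2, 1)
print(wrong[0])
print([round(v, 4) for v in right[0]])
print(to_unit_range(-1.0), to_unit_range(1.0))
```

Fix (1) is torch-specific and is not covered by this sketch. Passing recon_depth.clone() simply means save_image works on a copy, so the tensor used later in the script is left untouched.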

jackytu256 commented 3 years ago

@XingangPan Thanks for your reply. (1) The save_image function I use is torchvision.utils.save_image (just for your reference). (2) Done. (3) Done.

It seems that the problem still remains. Do I need to provide any extra information?

Really appreciate your help :)

Thanks

XingangPan commented 3 years ago

@jackytu256 I have run your code and it works well with the revisions. In what way is your result wrong? Could you provide an example? Here are my results:
[Input: 000000.png]
[Results: 000000.png_im_rotate_stage3_002, 000000.png_im_rotate_stage3_004, 000000.png_im_rotate_stage3_006]

jackytu256 commented 3 years ago

Here is the example.

I reckon this problem may occur with unseen images (taken from FFHQ)?

Please let me know if I need to provide more information. Thanks.

[Image: example]

Here are the aligned images:
[Images: 00995 png-1_align_image, 00996 png-1_align_image, 00997 png-1_align_image, 00998 png-1_align_image, 00999 png-1_align_image]

XingangPan commented 3 years ago

@jackytu256 I see. This is not a bug. I think there are two main reasons.
(1) Note that the pre-trained model is not meant to predict accurate results for unseen samples, but to provide a rough initialization (it is pre-trained on only 200 images). You still need to run instance-specific training as in scripts/run_celeba.sh.
(2) There is a domain gap between FFHQ (testing) and CelebA (training): faces are aligned differently in the two datasets. I think the issue would be alleviated if you test with CelebA images cropped via https://github.com/elliottwu/unsup3d/tree/master/data.
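To illustrate point (2): two datasets can frame the same face very differently depending on how the crop box is derived from facial landmarks. Below is a minimal plain-Python sketch of a landmark-based square crop. The function, its scale and y_shift parameters, and the landmark coordinates are all hypothetical and are NOT the unsup3d or FFHQ alignment procedure; for real CelebA-style crops, use the linked unsup3d data scripts.

```python
import math


def square_crop_box(left_eye, right_eye, mouth, scale=2.0, y_shift=0.1):
    """Return (left, top, right, bottom) of a square crop around a face.

    HYPOTHETICAL parameters: the box side is `scale` times the distance from
    the eye midpoint to the mouth, and the centre is moved down by `y_shift`
    times the side. These are not the values used by unsup3d or FFHQ.
    """
    eye_mid = ((left_eye[0] + right_eye[0]) / 2, (left_eye[1] + right_eye[1]) / 2)
    side = scale * math.dist(eye_mid, mouth)  # math.dist needs Python 3.8+
    cx = (eye_mid[0] + mouth[0]) / 2
    cy = (eye_mid[1] + mouth[1]) / 2 + y_shift * side
    half = side / 2
    return (cx - half, cy - half, cx + half, cy + half)


landmarks = dict(left_eye=(40, 50), right_eye=(80, 50), mouth=(60, 90))
print(square_crop_box(**landmarks))             # looser crop
print(square_crop_box(**landmarks, scale=1.5))  # tighter crop
```

With the same landmarks, scale=2.0 gives an 80-pixel box while scale=1.5 gives a 60-pixel box, so the face fills a noticeably different fraction of the image. A model pre-trained on one framing then sees out-of-distribution inputs under the other.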

jackytu256 commented 3 years ago

Thanks @XingangPan

kashishnaqvi101 commented 12 months ago

> thanks @XingangPan

Hey, were you able to find good results?