yuval-alaluf / restyle-encoder

Official Implementation for "ReStyle: A Residual-Based StyleGAN Encoder via Iterative Refinement" (ICCV 2021) https://arxiv.org/abs/2104.02699
https://yuval-alaluf.github.io/restyle-encoder/
MIT License

Head cropping #23

Closed · qo4on closed this issue 3 years ago

qo4on commented 3 years ago

Hi! Is it possible to increase the size of the cropped region in order to avoid the head being cropped?

[image: Untitled-2]

yuval-alaluf commented 3 years ago

The cropping is part of the alignment step, which is done to match the input expected by StyleGAN. You can try expanding the cropped area in the alignment step, but I am unsure how this will affect the results if the cropped area becomes too large compared to the data that StyleGAN was trained on.
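If you want to experiment with this, one rough way (just a sketch, not something I have tested) would be to scale the oriented crop rectangle computed by the FFHQ alignment (the quad array) about its center before the crop is applied, and scale qsize by the same factor:

import numpy as np

def expand_quad(quad, factor=1.2):
    # quad: 4x2 array of crop-corner coordinates from the FFHQ-style alignment.
    # factor > 1 enlarges the crop; this is a hypothetical knob, not part of
    # the original script, and large factors will move you further from the
    # FFHQ-style crops the models were trained on.
    center = quad.mean(axis=0)
    return center + (quad - center) * factor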

qo4on commented 3 years ago

Do you know what the idea behind their cropping code is? It looks simple, but I can't figure out why they use flipud and hypot.

# Choose oriented crop rectangle.
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
x /= np.hypot(*x)
x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
y = np.flipud(x) * [-1, 1]
c = eye_avg + eye_to_mouth * 0.1
quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
qsize = np.hypot(*x) * 2
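Mechanically the individual calls are easy to check (np.flipud just swaps the two components of a 2-vector, np.hypot(*x) is its Euclidean length, so np.flipud(v) * [-1, 1] is v rotated by 90 degrees and x /= np.hypot(*x) makes x a unit vector):

import numpy as np

v = np.array([3.0, 4.0])
print(np.flipud(v))            # [4. 3.]   components swapped
print(np.flipud(v) * [-1, 1])  # [-4.  3.] v rotated 90 degrees counter-clockwise
print(np.hypot(*v))            # 5.0       Euclidean length of v

What I don't follow is the intent behind building the rectangle this way.
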
yuval-alaluf commented 3 years ago

Sorry, I am a bit unfamiliar with the technical implementation details of the alignment. The alignment code was taken from the official FFHQ alignment script.

qo4on commented 3 years ago

Ok, thanks. There is another issue I found with restyle_psp. If you look at the first row of images in the post above, you may notice that the face is a bit reddish. These images were generated by feeding the first-iteration latents of restyle_psp to the toonify model. I tried opts1.n_iters_per_batch = 2, 3, 4, and 5, and the toonified images became more and more reddish. Why does this happen? The code I used:

import time
import torch

def get_avg_image(net):
    avg_image = net(net.latent_avg.unsqueeze(0),
                    input_code=True,
                    randomize_noise=False,
                    return_latents=False,
                    average_code=True)[0]
    avg_image = avg_image.to('cuda').float().detach()
    return avg_image

def run_on_batch(inputs, net1, net2, opts1, opts2, avg_image):
    y_hat, latent = None, None
    results_batch = {idx: [] for idx in range(inputs.shape[0])}
    results_latent = {idx: [] for idx in range(inputs.shape[0])}
    for iter in range(opts1.n_iters_per_batch):
        if iter == 0:
            avg_image_for_batch = avg_image.unsqueeze(0).repeat(inputs.shape[0], 1, 1, 1)
            x_input = torch.cat([inputs, avg_image_for_batch], dim=1)
        else:
            x_input = torch.cat([inputs, y_hat], dim=1)

        y_hat, latent = net1.forward(x_input,
                                     latent=latent,
                                     randomize_noise=False,
                                     return_latents=True,
                                     resize=opts1.resize_outputs)

        # store intermediate outputs
        if iter == opts1.n_iters_per_batch - 1:
            for idx in range(inputs.shape[0]):
                results_batch[idx].append(y_hat[idx])

        # resize input to 256 before feeding into next iteration
        y_hat = net1.face_pool(y_hat)

    # iteratively translate using the resulting latent and generated image
    for iter in range(opts2.n_iters_per_batch):
        x_input = torch.cat([inputs, y_hat], dim=1)
        y_hat, latent = net2.forward(x_input,
                                     latent=latent,
                                     randomize_noise=False,
                                     return_latents=True,
                                     resize=opts2.resize_outputs)
        for idx in range(inputs.shape[0]):
            results_batch[idx].append(y_hat[idx])
        y_hat = net1.face_pool(y_hat)

    return results_batch

img_transforms = EXPERIMENT_DATA_ARGS['restyle_psp_toonify']['transform']
transformed_image = img_transforms(input_image)

opts1.n_iters_per_batch = 1
opts2.n_iters_per_batch = 5
opts1.resize_outputs = opts2.resize_outputs = False    # generate outputs at full resolution

with torch.no_grad():
    avg_image = get_avg_image(net1)
    tic = time.time()
    result_batch = run_on_batch(transformed_image.unsqueeze(0).cuda(),
                                net1, net2, opts1, opts2, avg_image)
    toc = time.time()
    print('Inference took {:.4f} seconds.'.format(toc - tic))
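To view the results I convert the tensors in result_batch back to images. Assuming the generator outputs are in the usual [-1, 1] range, I use something like:

from PIL import Image

def to_pil(tensor):
    # CHW tensor in [-1, 1] -> PIL image (assumes that output range)
    arr = ((tensor.clamp(-1, 1) + 1) / 2 * 255).permute(1, 2, 0).cpu().numpy().astype('uint8')
    return Image.fromarray(arr)

to_pil(result_batch[0][-1]).save('toonified.png')  # final toonified output
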
yuval-alaluf commented 3 years ago

This reddish effect could happen because the inversion is moving away from the "good" regions of the latent space that StyleGAN was trained on. As we move away from these good regions, we may get artifacts like the reddish tint. This is not something I have examined thoroughly, just a hunch.
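If you want to check this, a rough diagnostic (just a sketch; it assumes the latents returned by the forward calls have shape [N, n_styles, 512] and that latent_avg broadcasts against them) is to log how far the predicted codes drift from the average latent after each iteration in run_on_batch:

import torch

def latent_drift(latent, latent_avg):
    # Mean L2 distance of each predicted style vector from the average code.
    # Growing values across iterations would be consistent with the inversion
    # moving away from the well-behaved region around the average latent.
    return torch.norm(latent - latent_avg, dim=-1).mean().item()

If this keeps increasing with opts2.n_iters_per_batch while the outputs get more reddish, that would support the explanation above.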