yuval-alaluf / stylegan3-editing

Official Implementation of "Third Time's the Charm? Image and Video Editing with StyleGAN3" (AIM ECCVW 2022) https://arxiv.org/abs/2201.13433
https://yuval-alaluf.github.io/stylegan3-editing/
MIT License

image -> latent vector -> back to original image #52


Evgeny-Ru commented 4 months ago

Hi, thanks for your amazing work!

I'm trying to project an existing image into latent space, apply a small change to the latent vector (e.g., some linear transformation), and then decode the edited latent back to an image.
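For concreteness, the kind of edit I have in mind looks roughly like this (edit_direction and alpha are hypothetical placeholders, and I'm assuming the latent is a W+ code of shape [n_latents, 512]):

import numpy as np

n_latents, latent_dim = 16, 512                                   # assumed W+ shape for StyleGAN3
latent = np.zeros((n_latents, latent_dim), dtype=np.float32)      # stand-in for an encoder output
edit_direction = np.random.randn(latent_dim).astype(np.float32)   # hypothetical edit direction
alpha = 2.0                                                       # edit strength (placeholder)
edited_latent = latent + alpha * edit_direction                   # broadcast over all style layers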

But first I tried to reproduce the simplified case: image -> latent vector -> back to the original image.

I used your code from inference_playground.ipynb with the same image:

experiment_type = 'restyle_pSp_ffhq' 
...
...
with torch.no_grad():
    tic = time.time()
    result_batch, result_latents = run_on_batch(inputs=transformed_image.unsqueeze(0).cuda().float(),
                                                net=net,
                                                opts=opts,
                                                avg_image=avg_image,
                                                landmarks_transform=torch.from_numpy(landmarks_transform).cuda().float())
    toc = time.time()
    print('Inference took {:.4f} seconds.'.format(toc - tic))
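If I understand the notebook correctly (this is my assumption, I haven't verified it in the code), result_latents maps each batch index to one latent per ReStyle refinement step, so the final code for the first image would be:

# assumed output structure:
# result_latents[i]      -> list of latents for image i, one per ReStyle step
# result_latents[0][-1]  -> final latent of the first image, a numpy array
final_latent = result_latents[0][-1]
print(type(final_latent), final_latent.shape)   # I'd expect something like (16, 512)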

After I got the latent vector from result_latents, I tried to map it back to an image, like this:

import pickle

import PIL.Image
import torch

# note: unpickling the official checkpoint requires the stylegan3 repo
# (dnnlib / torch_utils) to be importable on the Python path
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

with open('stylegan3-r-ffhqu-256x256.pkl', 'rb') as f:
    G = pickle.load(f)['G_ema'].to(device)  # torch.nn.Module
# z = torch.randn([1, G.z_dim], device=device)  # random latent codes, for reference
latent = result_latents[0][-1]                  # final latent from the encoder
latent_tensor = torch.from_numpy(latent).to(device)

z = latent_tensor
c = None                                # class labels (unconditional model, so unused)
img = G(z, c)                           # NCHW, float32, dynamic range [-1, +1], no truncation

img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')

But instead of the original image, I get an image of a different person (a smiling woman instead of a smiling man, etc.). I've tried other networks (stylegan3-r-metfaces-1024x1024.pkl etc.), but I still can't reproduce the original image.

Could you please suggest what I'm doing wrong?
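One thing I suspect (but haven't been able to confirm) is that result_latents holds W+ codes rather than z vectors, so maybe they should bypass the mapping network and go through G.synthesis directly, using the same generator checkpoint the encoder was trained with. Is the right approach something like this sketch, assuming the latent has shape [n_latents, 512] matching G.num_ws?

# hypothetical fix: treat the encoder output as a W+ code and feed the
# synthesis network directly, skipping the Z -> W mapping
w = latent_tensor.unsqueeze(0)   # [1, n_latents, 512]; n_latents assumed to equal G.num_ws
img = G.synthesis(w)             # NCHW, float32, dynamic range [-1, +1]
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save('reconstruction.png')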