chaiyujin / glow-pytorch

PyTorch implementation of the OpenAI paper "Glow: Generative Flow with Invertible 1×1 Convolutions"
MIT License

Reconstructed images are not like the input images #20

Open swyoon opened 4 years ago

swyoon commented 4 years ago

Hi,

First of all, thank you for the nice repo.

I am trying to map an image x to a latent representation z and then map back to its reconstruction x_hat with a trained model, but x and x_hat are not similar at all.

I understand that the reconstruction may not be exactly identical because of the Split2d layer, but the difference is far too large to be explained by that.

I ran the training script, and the reconstructed images shown in TensorBoard are very similar to their corresponding inputs.
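For context: every layer of a normalizing flow such as Glow is invertible, so decoding the exact latent produced by the encoder should recover the input up to floating-point round-off. A minimal NumPy sketch of an invertible 1×1 convolution (a per-pixel linear mix of channels, initialized as a random orthogonal matrix in the spirit of the Glow paper) illustrates this. The shapes, the `conv1x1` helper and the QR-based initialization are illustrative assumptions, not this repo's code.

```python
import numpy as np

def conv1x1(x, w):
    """1x1 convolution: mix channels independently at every pixel.

    x has shape (N, C, H, W); w has shape (C, C).
    """
    return np.einsum("ij,njhw->nihw", w, x)

rng = np.random.default_rng(0)
c = 3
# QR of a Gaussian matrix yields a random orthogonal (hence invertible) W.
w, _ = np.linalg.qr(rng.standard_normal((c, c)))
x = rng.standard_normal((2, c, 4, 4))

z = conv1x1(x, w)                      # forward pass (encode)
x_hat = conv1x1(z, np.linalg.inv(w))   # reverse pass (decode) with W^{-1}

print(np.max(np.abs(x - x_hat)))  # tiny: floating-point round-off only
```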

Here are rough snippets that can reproduce my problem.

To define and load the dataset and model:

import torch
from torchvision import transforms
from glow.config import JsonConfig
from glow.builder import build
from glow.trainer import Trainer
from glow.utils import load
import vision

hparams = JsonConfig('hparams/celeba.json')

transform = transforms.Compose([
    transforms.CenterCrop(hparams.Data.center_crop),
    transforms.Resize(hparams.Data.resize),
    transforms.ToTensor()])
dataset = vision.Datasets['celeba']
dataset = dataset('<path-to-CelebA>', transform=transform)

built = build(hparams, True)
load('glow_celeba.ckpt.pkg', built['graph'])
graph = built['graph']
graph.eval()  # switch to inference mode before encoding/decoding

Now I want to encode a batch of images and decode them back.

img_x = torch.stack([dataset[i]['x'] for i in range(12)])
img_x = img_x.cuda()

# encode
z, nll, y_logits = graph(img_x, y_onehot=None)

# decode
x_hat = graph(z, reverse=True)

When I visualized img_x and x_hat, I found a significant discrepancy.

import matplotlib.pyplot as plt

# original images
np_img_x = img_x.permute(0,2,3,1).detach().cpu().numpy()
fig, axs = plt.subplots(ncols=5, figsize=(15,5))
for i in range(5):
    axs[i].imshow(np_img_x[i])

[Figure: the five original input images]

# reconstructed images
np_x_hat = x_hat.permute(0,2,3,1).detach().cpu().numpy()
fig, axs = plt.subplots(ncols=5, figsize=(15,5))
for i in range(5):
    axs[i].imshow(np_x_hat[i])

[Figure: the five reconstructed images]

As you can see, those images are very different.
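The discrepancy can also be measured numerically instead of only by eye. A minimal NumPy sketch, assuming both batches have the same shape and pixel values in `[0, max_val]` (1.0 for `ToTensor()` output); the `reconstruction_metrics` helper is hypothetical, not part of this repo.

```python
import numpy as np

def reconstruction_metrics(x, x_hat, max_val=1.0):
    """Compare two image batches of identical shape.

    Assumes pixel values lie in [0, max_val].
    Returns (max absolute error, mean squared error, PSNR in dB).
    """
    x = np.asarray(x, dtype=np.float64)
    x_hat = np.asarray(x_hat, dtype=np.float64)
    if x.shape != x_hat.shape:
        raise ValueError(f"shape mismatch: {x.shape} vs {x_hat.shape}")
    diff = x - x_hat
    max_abs = float(np.max(np.abs(diff)))
    mse = float(np.mean(diff ** 2))
    # Identical inputs give zero error and an infinite PSNR.
    psnr = float("inf") if mse == 0 else 10.0 * np.log10(max_val ** 2 / mse)
    return max_abs, mse, psnr
```

With the arrays from the snippets above, `reconstruction_metrics(np_img_x, np_x_hat)` would report the error per batch; a faithful reconstruction should give a small error and a high PSNR.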

As a sanity check, I ran unconditional sampling. It produces reasonable images, especially with a suitable choice of eps_std, so the model itself seems to be well trained.

out = graph(z=None, reverse=True, eps_std=0.6)
np_out = out.permute(0,2,3,1).detach().cpu().numpy()
fig, axs = plt.subplots(ncols=5, figsize=(15,5))
for i in range(5):
    axs[i].imshow(np_out[i])

[Figure: five unconditional samples with eps_std=0.6]
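The eps_std argument behaves like the sampling temperature described in the Glow paper: latents are drawn from the prior with their standard deviation scaled down, which usually trades diversity for visual quality. A minimal NumPy sketch of that rule for a diagonal Gaussian parameterized by a mean and a log standard deviation; the `sample_diag_gaussian` name and the `(mean, logs)` parameterization are assumptions for illustration, not necessarily this repo's exact code.

```python
import numpy as np

def sample_diag_gaussian(mean, logs, eps_std=1.0, rng=None):
    """Draw z = mean + exp(logs) * eps, with eps ~ N(0, eps_std**2).

    eps_std < 1 lowers the temperature; eps_std = 0 returns the mean.
    """
    rng = np.random.default_rng() if rng is None else rng
    mean = np.asarray(mean, dtype=np.float64)
    logs = np.asarray(logs, dtype=np.float64)
    eps = rng.standard_normal(mean.shape) * eps_std
    return mean + np.exp(logs) * eps

rng = np.random.default_rng(0)
# Standard normal prior (mean 0, log-std 0) over a batch of 4 latents.
mean = np.zeros((4, 48, 8, 8))
logs = np.zeros_like(mean)
z_full = sample_diag_gaussian(mean, logs, eps_std=1.0, rng=rng)
z_cool = sample_diag_gaussian(mean, logs, eps_std=0.6, rng=rng)
print(z_full.std(), z_cool.std())  # roughly 1.0 and 0.6
```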

turnip001 commented 4 years ago

For a Glow model, x and z should have the same shape. However, part of z is discarded during the split2d operation. That may be why x and x_hat are not similar at all.
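A toy NumPy sketch shows the effect described here: if the reverse pass resamples the factored-out half instead of reusing the value produced during encoding, the reconstruction is only approximate, while reusing it makes the round trip exact. The one-split "flow" below (a fixed invertible channel mix followed by a channel split, with the dropped half resampled from a plain normal with std 0.6) is a hypothetical stand-in, not the repo's Split2d layer.

```python
import numpy as np

rng = np.random.default_rng(0)
C = 4
W, _ = np.linalg.qr(rng.standard_normal((C, C)))  # invertible channel mix
W_inv = np.linalg.inv(W)

def encode(x):
    """Mix channels, then split: z1 is kept, z2 is factored out."""
    h = np.einsum("ij,njhw->nihw", W, x)
    return h[:, : C // 2], h[:, C // 2 :]

def decode(z1, z2):
    """Concatenate both halves and undo the channel mix."""
    h = np.concatenate([z1, z2], axis=1)
    return np.einsum("ij,njhw->nihw", W_inv, h)

x = rng.standard_normal((2, C, 4, 4))
z1, z2 = encode(x)

# Reusing the factored-out half: exact round trip.
x_exact = decode(z1, z2)
# Resampling that half from a prior: only an approximation.
z2_sampled = 0.6 * rng.standard_normal(z2.shape)
x_approx = decode(z1, z2_sampled)

print(np.abs(x - x_exact).max())   # floating-point round-off only
print(np.abs(x - x_approx).max())  # clearly nonzero
```

In Glow the prior for the split-off half is learned and conditioned on the kept half, so a resampled z2 is more plausible than in this toy, but it is still not the value the encoder produced.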

tenpercent commented 3 years ago
# decode
x_hat = graph(z, reverse=True)

If you leave the parameter unnamed, it is taken as x (the first positional argument) and z stays None, so the reverse pass samples a new latent instead of decoding yours. Try:

# decode
x_hat = graph(z=z, reverse=True)
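The pitfall can be reproduced without the real model. The sketch below uses a hypothetical `ToyFlow` class whose call signature mirrors the one described above, `(x=None, y_onehot=None, z=None, eps_std=None, reverse=False)`; its toy encode/decode maps, latent size and default std of 1.0 are assumptions, and the repo's actual model also tracks log-determinants and conditioning.

```python
import random

class ToyFlow:
    """Hypothetical stand-in with a Glow-like call signature."""

    def __init__(self, dim=3):
        self.dim = dim  # toy latent size (assumption)

    def __call__(self, x=None, y_onehot=None, z=None, eps_std=None, reverse=False):
        if not reverse:
            return self.encode(x)
        return self.decode(z, eps_std)

    def encode(self, x):
        return [2.0 * v for v in x]  # toy invertible map

    def decode(self, z, eps_std):
        if z is None:
            # No latent supplied: sample one, as unconditional sampling would.
            std = 1.0 if eps_std is None else eps_std
            z = [random.gauss(0.0, std) for _ in range(self.dim)]
        return [v / 2.0 for v in z]

flow = ToyFlow()
x = [0.1, 0.2, 0.3]
z = flow(x)

x_wrong = flow(z, reverse=True)    # z lands in the x slot; a new latent is sampled
x_right = flow(z=z, reverse=True)  # z is decoded as intended
print(x_right)  # [0.1, 0.2, 0.3]
```

To rule out this mistake in your own wrappers, the parameters can be made keyword-only, e.g. `def __call__(self, *, x=None, z=None, reverse=False)`, so a positional call raises a `TypeError` instead of silently sampling.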