rosinality / vq-vae-2-pytorch

Implementation of Generating Diverse High-Fidelity Images with VQ-VAE-2 in PyTorch

Reconstructions on LFW #48

Closed · SURABHI-GUPTA closed this 4 years ago

SURABHI-GUPTA commented 4 years ago

Hi @rosinality, thanks for the code. I want to train the model from scratch on the LFW dataset and focus on reconstructions only. Will this code do that, and what are the steps and hyperparameters for getting good reconstructions?

rosinality commented 4 years ago

I think you don't need any special adjustments. On the hyperparameter side, a learning rate schedule could be worth trying, but I think the default settings will work.
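
If you do want to try a learning rate schedule, a minimal sketch using a standard PyTorch scheduler is shown below; the optimizer, learning rate, and scheduler choice here are illustrative assumptions, not a built-in option of this repository.

from torch import optim

from vqvae import VQVAE

# Illustrative setup: model and optimizer for training.
model = VQVAE()
optimizer = optim.Adam(model.parameters(), lr=3e-4)

# Assumption: a cosine-annealing schedule over the planned number of epochs;
# any torch.optim.lr_scheduler class could be swapped in here.
epochs = 100
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

for epoch in range(epochs):
    # ... one training epoch over the LFW data loader goes here ...
    scheduler.step()  # advance the learning rate once per epoch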

SURABHI-GUPTA commented 4 years ago

@rosinality okay, thanks. I am only interested in reconstructions, so which files should I run to get good reconstructions? How did you get stage1_sample.png as a reconstruction result?

rosinality commented 4 years ago

I didn't make a script for reconstruction. You only need to pass images through the VQ-VAE model to get reconstructions.

SURABHI-GUPTA commented 4 years ago

@rosinality I am getting an error while loading the model back. I ran for 100 epochs and it gives this error. It would be really helpful if you could share that script. Also, did you get reconstructions after running train_vqvae.py, or do I need to run other steps too?


rosinality commented 4 years ago

You can do it like this.

import torch
from torchvision import transforms
from PIL import Image

from vqvae import VQVAE

size = 256
ckpt_path = 'checkpoint/vqvae_100.pt'
img_path = 'IMAGE_PATH'

# Same preprocessing as training: resize/crop to 256 and normalize to (-1, 1).
transform = transforms.Compose(
    [
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

vqvae = VQVAE()

# Load the checkpoint on CPU and strip the 'module.' prefix added by DataParallel.
ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
ckpt_raw = {}

for k, v in ckpt.items():
    ckpt_raw[k.replace('module.', '', 1)] = v

vqvae.load_state_dict(ckpt_raw)
vqvae.eval()

img = Image.open(img_path).convert('RGB')

# Add a batch dimension: the model expects (N, C, H, W).
img_t = transform(img).unsqueeze(0)

with torch.no_grad():
    # The first output of VQVAE is the reconstruction; map it back from (-1, 1) to (0, 1).
    recon = vqvae(img_t)[0].squeeze(0).add(1).div(2)

train_vqvae.py will train the VQ-VAE model, which can do reconstruction. The other model (PixelSNAIL) is for sampling.
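
If you also want to write the reconstruction to disk, a minimal sketch continuing from the snippet above is shown here; the filenames and the side-by-side concatenation are illustrative assumptions.

import torch
from torchvision import utils

# recon is already in (0, 1) after .add(1).div(2), so it can be saved directly.
utils.save_image(recon, 'recon.png')

# Or save the input and the reconstruction side by side for comparison
# (img_t is in (-1, 1), so un-normalize it first).
utils.save_image(
    torch.cat([img_t.squeeze(0).add(1).div(2), recon], -1),
    'compare.png',
)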

SURABHI-GUPTA commented 4 years ago

@rosinality Thanks. I ran this and the reconstructions came out blank. The LFW dataset directory structure is different: it has approximately 6k folders, one per person name, and inside each folder are 1-2 images of that person. Each image is 250 x 250. How should I load my dataset, and what should the size argument for the transforms be?

rosinality commented 4 years ago

I don't think you need any adjustment. Since 250 vs 256 is not a big difference, using the default size of 256 would be convenient. Anyway, during training, reconstruction samples are saved in the sample directory. You may want to check them.
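
For loading, torchvision's ImageFolder handles the one-subfolder-per-person layout directly; a sketch along these lines should work, with the dataset path and batch size as placeholders.

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

size = 256  # 250x250 LFW images are resized/cropped to the default 256

transform = transforms.Compose(
    [
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

# ImageFolder treats each person-named subfolder as a class and loads the images inside.
dataset = datasets.ImageFolder('PATH_TO_LFW', transform=transform)
loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)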

SURABHI-GUPTA commented 4 years ago

@rosinality yes, I checked that. At the end of training the reconstructions are good, but why are they blank on test images?

rosinality commented 4 years ago

Then you need to check your checkpoint.

SURABHI-GUPTA commented 4 years ago

@rosinality There was an error in saving the image. The result is like this (recon attached). I don't know why there is whiteness in the reconstructed image. Solved by using "recon = vqvae(img_t)[0].squeeze(0)" instead of "recon = vqvae(img_t)[0].squeeze(0).add(1).div(2)".

rosinality commented 4 years ago

How did you save the image? Images are normalized to values in (-1, 1) during training, so you have to invert that; this is why .add(1).div(2) is used.

SURABHI-GUPTA commented 4 years ago

@rosinality I did this: "utils.save_image(torch.cat([img1, recon], -1), "recon1.png", normalize=True, range=(-1, 1))" and the result is shown here (recon1 attached).
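
For reference, the un-normalization should happen exactly once; a sketch of the two consistent combinations, continuing from the reconstruction snippet above (note that newer torchvision releases rename the range argument to value_range):

import torch
from torchvision import utils

# Continuing from the snippet above (vqvae and img_t already defined).
with torch.no_grad():
    out = vqvae(img_t)[0].squeeze(0)  # values in (-1, 1)

# Option 1: keep the (-1, 1) output and let save_image rescale it once.
utils.save_image(out, 'recon.png', normalize=True, range=(-1, 1))

# Option 2: un-normalize to (0, 1) yourself and save directly.
utils.save_image(out.add(1).div(2), 'recon.png')

# Doing both (applying .add(1).div(2) and also passing range=(-1, 1)) rescales twice,
# which pushes values into (0.5, 1) and washes the image out toward white.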

SURABHI-GUPTA commented 4 years ago

@rosinality does your code reconstruct using latent maps as described in the paper? If yes, how many latent codes are there in your code, and can I extract the output of each latent map to visualise what details are added by each map?

rosinality commented 4 years ago

In the current implementation, VQ-VAE uses 2 hierarchical latent codes (top and bottom). If you have extracted the latent codes, you can decode with VQVAE.decode_code(top, bottom), and if you want to check the details added by each code, you can simply replace one of the codes with a non-informative code, such as an all-zero tensor.
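
A minimal sketch of that experiment, continuing from the reconstruction snippet above; the exact return signature of encode is an assumption here, only decode_code(top, bottom) is stated above.

import torch

# Continuing from the snippet above (vqvae and img_t already defined).
with torch.no_grad():
    # Assumption: encode returns (quant_t, quant_b, diff, id_t, id_b) as in vqvae.py.
    _, _, _, id_t, id_b = vqvae.encode(img_t)

    # Full reconstruction from both hierarchical codes.
    full = vqvae.decode_code(id_t, id_b)

    # Zero the bottom code to see what the top code alone contributes.
    top_only = vqvae.decode_code(id_t, torch.zeros_like(id_b))

    # Zero the top code to see what the bottom code alone contributes.
    bottom_only = vqvae.decode_code(torch.zeros_like(id_t), id_b)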