Hi @rosinality, thanks for the code. I want to train the model from scratch on the LFW dataset and focus on reconstructions only. Will this code do that, and what are the steps and hyperparameters for getting good reconstructions?
I think you don't need special adjustments. On the hyperparameter side, a learning rate schedule could be worth trying, but I think the default settings will work.
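If you do want to try a schedule, a minimal sketch using a stock PyTorch scheduler (the learning rate and epoch count here are placeholders, not this repo's defaults):

import torch
from vqvae import VQVAE

model = VQVAE()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
# Cosine decay of the learning rate over 100 epochs; purely illustrative values.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

for epoch in range(100):
    # ... run one training epoch here ...
    scheduler.step()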
@rosinality okay, thanks. I am only interested in reconstructions, so which files should I run to get good reconstructions? How did you get stage1_sample.png as the reconstruction result?
I didn't make a script for reconstruction. You only need to feed images through the VQ-VAE model to reconstruct them.
@rosinality I am getting an error while loading the model back. I ran for 100 epochs and it gives this error. It would be really helpful if you could share that script. Also, did you get reconstructions after running train_vqvae.py, or do I need to run other steps for reconstruction?
You can do it like this.

import torch
from torchvision import transforms
from PIL import Image

from vqvae import VQVAE

size = 256
ckpt_path = 'checkpoint/vqvae_100.pt'
img_path = 'IMAGE_PATH'

# Same preprocessing as during training: resize, center crop,
# and normalize to (-1, 1).
transform = transforms.Compose(
    [
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

vqvae = VQVAE()

# Load the checkpoint on CPU and strip the 'module.' prefix that the
# DataParallel / DistributedDataParallel wrapper adds to parameter names.
ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
ckpt_raw = {}
for k, v in ckpt.items():
    ckpt_raw[k.replace('module.', '', 1)] = v

vqvae.load_state_dict(ckpt_raw)
vqvae.eval()

img = Image.open(img_path).convert('RGB')
img_t = transform(img).unsqueeze(0)

with torch.no_grad():
    # The first output is the reconstruction; map it from (-1, 1) back to (0, 1).
    recon = vqvae(img_t)[0].squeeze(0).add(1).div(2)
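If you also want to write the result to disk, a short follow-up (the output filename is just an example):

from torchvision import utils

# recon was already mapped back to (0, 1) by .add(1).div(2) above,
# so it can be saved directly without further rescaling.
utils.save_image(recon, 'recon.png')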
train_vqvae.py will train the VQ-VAE model, which can do reconstruction. The other model (PixelSNAIL) is for sampling.
@rosinality Thanks. I ran this and the reconstructions came out blank. The LFW dataset directory structure is different: it has roughly 6k folders, one per person, and each folder contains 1-2 images of that person. Each image is 250 x 250. How should I load my dataset, and what should the size argument for the transforms be?
I don't think you need any adjustment. Since 250 vs. 256 is not a big difference, using the default size of 256 would be convenient. In any case, reconstruction samples are saved in the sample directory during training, so you may want to check them.
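For loading, LFW's one-folder-per-person layout already matches what torchvision's ImageFolder expects, so a sketch like this should work (the 'lfw/' path and loader settings are just examples):

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

size = 256
transform = transforms.Compose(
    [
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

# ImageFolder treats each person directory as a class label, which is harmless
# here since only the images themselves are used for reconstruction.
dataset = datasets.ImageFolder('lfw/', transform=transform)
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=2)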
@rosinality yes, I checked that. At the end of training the reconstructions are good, but why are they blank on test images?
Then you need to check your checkpoint.
@rosinality There was an error in saving the image, but the result is like this. I don't know why there is whiteness in the reconstructed image. Solved by using "recon = vqvae(img_t)[0].squeeze(0)" instead of "recon = vqvae(img_t)[0].squeeze(0).add(1).div(2)".
How did you save the image? Images are normalized to values in (-1, 1) during training, so you should invert that normalization; that is why .add(1).div(2) is used.
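To make it concrete, a small sketch of the two options (reusing vqvae and img_t from the snippet earlier in this thread); use one or the other, not both:

import torch
from torchvision import utils

with torch.no_grad():
    # Option 1: keep the output in (-1, 1) and let save_image rescale it.
    recon = vqvae(img_t)[0].squeeze(0)
    utils.save_image(recon, 'recon1.png', normalize=True, range=(-1, 1))
    # (newer torchvision versions call this keyword value_range)

    # Option 2: invert the normalization yourself and save values in (0, 1).
    recon = vqvae(img_t)[0].squeeze(0).add(1).div(2)
    utils.save_image(recon, 'recon2.png')

# Doing both .add(1).div(2) and range=(-1, 1) shifts everything toward white,
# which would explain the washed-out result mentioned above.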
@rosinality I did this: "utils.save_image(torch.cat([img1, recon], -1), "recon1.png", normalize=True, range=(-1, 1))" and the result is shown here.
@rosinality does your code reconstruct using latent maps as mentioned in the paper, and if yes, how many latent codes are there in your code? Can I extract the output of each latent map to visualise what details are being added by each map?
In the current implementation, the VQ-VAE uses 2 hierarchical latent codes. If you have extracted the latent codes, you can decode them with VQVAE.decode_code(top, bottom), and if you want to check the details added by each code, you can replace one of the codes with non-informative codes such as an all-zero tensor.
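A small sketch of that, assuming vqvae.encode returns the two code index maps alongside the quantized tensors (the exact return values may differ, so check vqvae.py):

import torch

with torch.no_grad():
    # Hypothetical unpacking: the last two outputs are assumed to be the
    # top and bottom code index maps.
    _, _, _, id_t, id_b = vqvae.encode(img_t)

    # Full reconstruction from both code maps.
    full = vqvae.decode_code(id_t, id_b)

    # Zero out one level to see what the other contributes on its own.
    top_only = vqvae.decode_code(id_t, torch.zeros_like(id_b))
    bottom_only = vqvae.decode_code(torch.zeros_like(id_t), id_b)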