rosinality / vq-vae-2-pytorch

Implementation of Generating Diverse High-Fidelity Images with VQ-VAE-2 in PyTorch
Other
1.59k stars 271 forks

A question about image crispness with VQ-VAE2 #43

Open francoisruty opened 4 years ago

francoisruty commented 4 years ago

Hello, first of all, thanks for this interesting implementation of the VQ-VAE-2 paper.

I can train this network on a dataset of my own; however, the reconstructed images are a little blurry. Overall quality is good, but the crispness is nowhere near what is published in the paper.

My understanding of image blurriness with VAEs in general is that it is caused by information loss (due to the information bottleneck) and by the use of an MSE loss (which averages the error over the image).
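A toy illustration of the MSE point (a hypothetical 1-D example, not taken from this repo): if the same latent code could correspond to two equally likely sharp targets, the prediction that minimizes expected MSE is their pixel-wise mean, which is a blurred mix of both.

```python
# Two equally likely "sharp" targets for the same input (hypothetical 1-D images):
# an edge starting at index 2 and an edge starting at index 3.
target_a = [0.0, 0.0, 1.0, 1.0, 1.0]
target_b = [0.0, 0.0, 0.0, 1.0, 1.0]


def expected_mse(pred, targets):
    """Mean squared error of `pred`, averaged over equally likely targets."""
    total = 0.0
    for t in targets:
        total += sum((p - v) ** 2 for p, v in zip(pred, t)) / len(t)
    return total / len(targets)


targets = [target_a, target_b]
# Pixel-wise mean of the targets: the MSE-optimal prediction.
mean_pred = [(a + b) / 2 for a, b in zip(target_a, target_b)]

for name, pred in [("sharp A", target_a), ("sharp B", target_b), ("mean", mean_pred)]:
    print(f"{name:8s} {pred} expected MSE = {expected_mse(pred, targets):.3f}")
```

The mean prediction scores 0.050 against 0.100 for either sharp edge, yet it contains a half-intensity pixel at the edge, which is exactly the kind of smear seen in blurry reconstructions.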

As I see it, compared to a classic VAE, the VQ-VAE-2 paper brings two new concepts: hierarchical latent maps and quantization. However, I don't see why those two innovations would solve the classic VAE reconstruction blurriness problem.
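For readers less familiar with the quantization step, here is a minimal NumPy sketch of the nearest-codebook lookup (the forward pass only; the straight-through gradient used in training is omitted). In the hierarchical model this lookup is applied to each latent map separately. The codebook size of 512, the code dimension of 64, and the 32x32 map are assumed values for illustration, not read from this repo's code.

```python
import numpy as np


def quantize(z_e, codebook):
    """Replace each encoder vector with its nearest codebook entry.

    z_e:      (H, W, D) continuous encoder output
    codebook: (K, D) learned embedding vectors
    Returns the (H, W) integer code map and the (H, W, D) quantized tensor.
    """
    flat = z_e.reshape(-1, z_e.shape[-1])  # (H*W, D)
    # Squared Euclidean distance from every vector to every code: (H*W, K)
    dist = (
        (flat ** 2).sum(axis=1, keepdims=True)
        - 2 * flat @ codebook.T
        + (codebook ** 2).sum(axis=1)
    )
    codes = dist.argmin(axis=1)  # index of the nearest code per position
    z_q = codebook[codes].reshape(z_e.shape)
    return codes.reshape(z_e.shape[:-1]), z_q


rng = np.random.default_rng(0)
codebook = rng.normal(size=(512, 64))  # K = 512 codes of dimension 64 (assumed sizes)
z_e = rng.normal(size=(32, 32, 64))    # e.g. a 32x32 top-level latent map
codes, z_q = quantize(z_e, codebook)
print(codes.shape, z_q.shape)  # (32, 32) (32, 32, 64)
```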

This implementation performs just as I expected, which is good, but the reconstructed images are not as sharp as the ones in the paper. I don't think this is a problem with the implementation. The original VQ-VAE-2 paper does not explain why its architecture would yield sharp reconstructions. I've read the paper, and the explanation just isn't there.

Or are sharp images only possible by sampling from the PixelSNAIL prior trained in stage 2? To me, increased sharpness cannot come from the VAE alone.

What do you guys think? Am I missing something?

rosinality commented 4 years ago

Hmm, are the images in the paper much better? I thought they were quite similar.

In my opinion, VQ-VAE reduces the blurriness of a vanilla VAE by relaxing the bottleneck on the latent codes: it uses much larger, deterministic latent codes without fairly hard constraints like KL regularization. Actually, using prior models like PixelSNAIL will not help reconstruction quality; the prior is for sampling, and VQ-VAE reconstruction quality will be the upper bound on sample quality.
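For reference, the objectives make this contrast concrete. A standard VAE minimizes a reconstruction term plus a KL penalty on a stochastic posterior. The original VQ-VAE paper (van den Oord et al., 2017) instead picks the nearest codebook vector deterministically and replaces the KL penalty with codebook and commitment terms, where sg is the stop-gradient operator and β is a weight (0.25 in that paper). With MSE as the reconstruction term, as in this repo's discussion, the objectives read as below. With a uniform prior over K codes, the KL term becomes the constant log K, so it does not constrain training.

```latex
% Standard VAE (negative ELBO): reconstruction + KL regularization
\mathcal{L}_{\text{VAE}} = -\mathbb{E}_{q(z \mid x)}\big[\log p(x \mid z)\big]
  + D_{\mathrm{KL}}\big(q(z \mid x) \,\|\, p(z)\big)

% VQ-VAE: deterministic nearest-neighbour quantization, no KL term
e = e_k, \quad k = \arg\min_j \lVert z_e(x) - e_j \rVert_2

\mathcal{L}_{\text{VQ-VAE}} = \lVert x - D(e) \rVert_2^2
  + \lVert \operatorname{sg}[z_e(x)] - e \rVert_2^2
  + \beta \, \lVert z_e(x) - \operatorname{sg}[e] \rVert_2^2
```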

francoisruty commented 4 years ago

OK, so we agree that reconstruction quality bounds sample quality (that makes sense, but I wasn't sure).

It's true that their VAE uses residual blocks with skip connections, which reduce information loss. You're also right that dropping the KL regularization constraint must help.
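A generic residual block of the kind mentioned, as a PyTorch sketch. The ReLU, 3x3 conv, ReLU, 1x1 conv layout and the channel sizes are assumptions for illustration, not copied from this repo's `vqvae.py`. The identity skip path means the block only has to learn a correction to its input, so information can pass through unchanged.

```python
import torch
from torch import nn


class ResBlock(nn.Module):
    """Residual block: output = input + F(input), with F a small conv stack."""

    def __init__(self, channels: int, hidden: int):
        super().__init__()
        self.body = nn.Sequential(
            nn.ReLU(),  # not in-place, so the input tensor is left intact for the skip
            nn.Conv2d(channels, hidden, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden, channels, kernel_size=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Skip connection: the input is added back unchanged.
        return x + self.body(x)


block = ResBlock(channels=128, hidden=32)
x = torch.randn(2, 128, 64, 64)
print(block(x).shape)  # torch.Size([2, 128, 64, 64])
```

If the last convolution's weights and bias are zero, the block is exactly the identity, which is why stacking such blocks does not by itself destroy information.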

I've tested a perceptual loss based on a VGG16 mid-layer. It improves things a bit, but I still can't match the reconstruction quality in the paper.
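A minimal sketch of a pixel-plus-perceptual reconstruction loss like the one described, in PyTorch. The comment above does not say which VGG16 layer or weighting was used, so the `features[:16]` cut (which ends at relu3_3), the MSE-plus-weighted-feature-MSE combination and `perceptual_weight=0.1` are assumptions. The feature extractor is passed in so the function can be tried without downloading weights.

```python
import torch
from torch import nn
import torch.nn.functional as F


def reconstruction_loss(recon, target, feature_net, perceptual_weight=0.1):
    """Pixel MSE plus MSE between mid-level features of a frozen CNN.

    feature_net should be frozen (requires_grad=False, eval mode). With torchvision
    one might use torchvision.models.vgg16(weights="IMAGENET1K_V1").features[:16].
    Inputs are assumed to be normalized the way feature_net expects.
    """
    pixel = F.mse_loss(recon, target)
    with torch.no_grad():
        target_feat = feature_net(target)  # no gradient needed for the target
    recon_feat = feature_net(recon)        # gradient flows back into recon
    perceptual = F.mse_loss(recon_feat, target_feat)
    return pixel + perceptual_weight * perceptual


# Hypothetical untrained stand-in feature extractor so the example runs offline.
feature_net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU()).eval()
for p in feature_net.parameters():
    p.requires_grad_(False)

target = torch.rand(2, 3, 64, 64)
recon = torch.rand(2, 3, 64, 64, requires_grad=True)
loss = reconstruction_loss(recon, target, feature_net)
loss.backward()
print(loss.item(), recon.grad.shape)
```

Matching feature maps rather than only pixels penalizes missing texture and edges more strongly than plain MSE does, though, as noted above, it only helps so much.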

Which dataset did you train on? I'd be interested in training on the same data as you (with this repo's code as is) to see whether I can achieve the advertised reconstruction crispness.

rosinality commented 4 years ago

I have only tried FFHQ. I think details like learning rate schedules could affect the results. (Unfortunately, many training details are not specified in the paper.)
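As one example of the kind of schedule detail that could matter, here is a linear warmup followed by cosine decay. This is a common choice, not something the paper or this thread specifies; `base_lr`, `warmup_steps` and `total_steps` are hypothetical values.

```python
import math


def lr_at(step, base_lr=3e-4, warmup_steps=500, total_steps=10_000, min_lr=0.0):
    """Learning rate at `step`: linear warmup, then cosine decay to `min_lr`."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)  # hold at min_lr after total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


for s in (0, 250, 499, 500, 5_000, 10_000):
    print(s, f"{lr_at(s):.2e}")
```

In PyTorch this could be wired in with `torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s: lr_at(s) / 3e-4)`, since `LambdaLR` multiplies the optimizer's initial learning rate (assumed here to be 3e-4) by the returned factor.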

drtonyr commented 4 years ago

My opinion is that VQ-VAE-2 achieves its results with PixelSNAIL. The main function of VQ-VAE-2 seems to be to reduce the number of things PixelSNAIL has to model. Fundamentally, all autoencoders have a blurriness problem if the bottleneck is small enough to generate samples from. So the bottleneck space has to be larger to be crisp, and you need PixelSNAIL to constrain it to faces. Another way of thinking about it is that all the GPU time and parameters go into training PixelSNAIL (compared with VQ-VAE-2), so that's where all the modelling happens. I do thank you for publishing this work; I learned an enormous amount from it.
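To make this division of labour concrete: the prior models the discrete code map autoregressively, p(z) = ∏ p(z_i | z_<i), and new images come from decoding sampled codes. Below is a toy raster-order sampler in NumPy. `toy_conditional` is a hypothetical stand-in for a trained PixelSNAIL (it only favours repeating the code to its left), and the decoder step is omitted.

```python
import numpy as np


def toy_conditional(codes, row, col, n_codes):
    """Hypothetical stand-in for a trained prior: p(z[row, col] | codes so far).

    It only looks at the left neighbour and puts extra mass on repeating it.
    A real PixelSNAIL conditions on all previously sampled positions.
    """
    probs = np.full(n_codes, 1.0)
    if col > 0:
        probs[codes[row, col - 1]] += 4.0 * n_codes
    return probs / probs.sum()


def sample_code_map(height, width, n_codes, conditional, seed=0):
    """Sample a (height, width) code map one position at a time in raster order."""
    rng = np.random.default_rng(seed)
    codes = np.zeros((height, width), dtype=np.int64)
    for row in range(height):
        for col in range(width):
            p = conditional(codes, row, col, n_codes)
            codes[row, col] = rng.choice(n_codes, p=p)
    return codes


codes = sample_code_map(8, 8, n_codes=512, conditional=toy_conditional)
print(codes)  # a decoder would then map these codes back to pixels
```

Because the sampled codes still pass through the same decoder, samples cannot be sharper than what the decoder can reconstruct, which matches the point above that reconstruction quality bounds sample quality.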

Fanbenchao commented 4 years ago

@francoisruty I think you could try the three-scale VQ-VAE described in their paper (top, middle, and bottom). It may work for your needs.

francoisruty commented 4 years ago

@Fanbenchao thanks for your suggestion. I think 3 scales are needed when you work at 512 or 1024 resolution, but if I use 3 latent maps (32, 64 and 128) with 256-resolution images, the task becomes a bit too easy and the VQ-VAE-2 network is not that interesting anymore, IMO.

Fanbenchao commented 4 years ago

[Image: two-scale VQ-VAE, reconstruction after 20 epochs of training]
[Image: three-scale VQ-VAE, reconstruction after 20 epochs of training]

Fanbenchao commented 4 years ago

@francoisruty maybe the 8x downsampling causes the blur, so we could add more detail by using 2x-downsampled features during reconstruction.

francoisruty commented 4 years ago

@Fanbenchao thanks for your illustrations. The reconstruction with 3 latent maps is good, but don't you think it's disappointing to have to use 3 latent maps? It means you barely compress the data, so I'm not sure useful abstractions are being learned here.
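A rough way to quantify the compression concern: count the bits in the discrete latent maps (log2 K bits per code position, ignoring any entropy coding) against a raw 8-bit RGB image. The sizes below assume 256x256 inputs, 32x32 and 64x64 maps for two levels plus a 128x128 map for three levels (the sizes mentioned earlier in this thread), and a 512-entry codebook, which is an assumption.

```python
import math


def latent_bits(map_sizes, codebook_size=512):
    """Bits needed to store the discrete code maps: log2(K) bits per position."""
    return sum(h * w for h, w in map_sizes) * math.log2(codebook_size)


raw_bits = 256 * 256 * 3 * 8  # 8-bit RGB image

two_level = latent_bits([(32, 32), (64, 64)])
three_level = latent_bits([(32, 32), (64, 64), (128, 128)])

for name, bits in [("two levels", two_level), ("three levels", three_level)]:
    print(f"{name}: {bits:,.0f} bits, {raw_bits / bits:.1f}x smaller than raw")
```

Under these assumptions the two-level codes are about 34x smaller than the raw image, while adding the 128x128 map brings that down to about 8x, which supports the "barely compress" observation.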

Fanbenchao commented 4 years ago

@francoisruty Yes, you're right. Maybe we should explore other ways to improve it. I tried an autoencoder with 4x downsampling, and the reconstruction quality is acceptable. Apart from the VQ operation, the difference is that this AE has more residual blocks after the 4x downsampling.

ekyy2 commented 1 year ago

@Fanbenchao By any chance would you be able to share how you modified the code to go from two to three levels? Thank you so much!
