dome272 / MaskGIT-pytorch

Pytorch implementation of MaskGIT: Masked Generative Image Transformer (https://arxiv.org/pdf/2202.04200.pdf)
MIT License

vq_gan reconstruction results blurry using default code #3

Open Torment123 opened 2 years ago

Torment123 commented 2 years ago

Hi, thank you for this very interesting work! I'm currently trying to train the VQGAN part on my few-shot dataset (~300 dog or cat images) at 256x256 resolution. However, using the default settings in the code, the reconstructions still look somewhat blurry after 200 epochs of training, as shown below (first row: real images, second row: reconstructions after training). [reconstruction samples: 100_20, 199_60, 199_10]

After comparing the code with the setup in the paper, I have found two differences so far:

  1. the default embedding dimension in the code is 256, whereas the paper uses 768
  2. the non-local block uses single-head attention, whereas the paper uses 8-head attention

I'm not sure whether differences of this kind can cause blurriness to this extent, or whether there are other factors I need to pay attention to. Thanks!
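
For point 2, this is roughly what I had in mind, i.e. swapping the single-head non-local block for multi-head attention. This is only a sketch of mine, not code from this repo; the class name and the use of nn.MultiheadAttention are my own choices:

```python
# Sketch only: a multi-head variant of a VQGAN-style non-local block.
# Assumes the channel count is divisible by both the GroupNorm groups (32)
# and the number of heads (e.g. 512 channels, 8 heads).
import torch
import torch.nn as nn

class MultiHeadNonLocalBlock(nn.Module):
    def __init__(self, channels: int, num_heads: int = 8):
        super().__init__()
        self.norm = nn.GroupNorm(32, channels)
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        h_ = self.norm(x).flatten(2).transpose(1, 2)  # (B, H*W, C): one token per spatial position
        out, _ = self.attn(h_, h_, h_)                # multi-head self-attention over positions
        out = out.transpose(1, 2).view(b, c, h, w)
        return x + out                                # residual connection, as in the single-head block
```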

dome272 commented 2 years ago

Hello, the first stage of the VQGAN implementation might not be perfect. The embedding dimension would be a worthwhile thing to try; I used 256 because my machine couldn't handle much more. Also check out the original repo configs (https://github.com/CompVis/taming-transformers/tree/master/configs) and see if you can change some hyperparameters. I also believe they used a latent dim of 256; where did you find the 768? And thanks for pointing out the missing heads, but I actually think the authors used single-head attention: looking at the original VQGAN repo (https://github.com/CompVis/taming-transformers/blob/master/taming/modules/diffusionmodules/model.py#L376), I see no sign of multiple heads.
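
For comparison, the first-stage hyperparameters those configs expose look roughly like the sketch below. The exact values here are from memory, so double-check them against the actual config files before relying on any of them:

```python
# Rough sketch (values from memory, not copied from a specific config file)
# of the first-stage VQGAN hyperparameters in the taming-transformers configs.
vqgan_first_stage = {
    "embed_dim": 256,               # latent / codebook embedding dimension
    "n_embed": 1024,                # number of codebook vectors
    "ddconfig": {                   # encoder / decoder
        "z_channels": 256,
        "resolution": 256,
        "in_channels": 3,
        "out_ch": 3,
        "ch": 128,
        "ch_mult": [1, 1, 2, 2, 4], # 4 downsamplings -> 16x16 latents for 256x256 inputs
        "num_res_blocks": 2,
        "attn_resolutions": [16],
        "dropout": 0.0,
    },
    "lossconfig": {                 # VQLPIPSWithDiscriminator
        "disc_start": 10000,        # step at which the discriminator kicks in (varies per config)
        "disc_weight": 0.8,
        "codebook_weight": 1.0,
    },
}
```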

I don't know of anything else that could differ from the original repo. The encoder and decoder are exactly the same in terms of architecture and parameter count, the codebook is the same as theirs, and the discriminator is taken from the PatchGAN paper, just as the authors did.

So I believe it can only be the hyperparameters that need to be changed.

Have you tried your dataset with the original VQGAN repo? How did it perform there? As expected?

Torment123 commented 2 years ago

Thank you for the prompt and detailed reply! I haven't tried the images on VQGAN yet; I'm currently setting it up. As for the settings in the paper, they are in section 4.1 (Experimental Setup), second paragraph, which says all models have 24 layers, 8 attention heads, 768 embedding dimensions and 3072 hidden dimensions. Is this the correct place to look? Also, what does 'hidden dimension' refer to?

dome272 commented 2 years ago

where it says all models have 24 layers, 8 attention heads, 768 embedding dimensions and 3072 hidden dimensions

This refers to the second stage of the VQGAN setup, i.e. the transformer part. It has nothing to do with the reconstruction part, which is learned first. For the first stage, the correct place to look is the original VQGAN paper: https://arxiv.org/pdf/2012.09841.pdf. MaskGIT builds on top of it and only replaces the second stage, leaving the reconstruction learning untouched.
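
To make the distinction concrete, those numbers describe a generic transformer stack roughly like the sketch below (using nn.TransformerEncoderLayer as a stand-in, not this repo's transformer). The 3072 "hidden dimension" is the width of the feed-forward MLP inside each block, i.e. 4 x 768:

```python
# Sketch only: where the 24 / 8 / 768 / 3072 numbers from the MaskGIT paper
# would go in a generic (bidirectional) transformer stack.
import torch.nn as nn

layer = nn.TransformerEncoderLayer(
    d_model=768,           # token embedding dimension
    nhead=8,               # attention heads
    dim_feedforward=3072,  # "hidden dimension": width of the feed-forward MLP
    batch_first=True,
)
second_stage_transformer = nn.TransformerEncoder(layer, num_layers=24)
```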

Torment123 commented 2 years ago

Thanks for your clarification, this helps me a lot.

dome272 commented 2 years ago

No problem. Also, just drop the adjustments you made here if you find something that works for your case; maybe others can benefit from it too.

Torment123 commented 2 years ago

Sure. Also, could you share a pretrained MaskGIT model so that we can play with it a bit? Thanks!

Blackkinggg commented 1 year ago

I have met the same problem with the reconstruction: once the discriminator starts training, the images pick up patches of noise.

zhuqiangLu commented 1 year ago

I ran into a similar problem with the Taming-Transformers VQGAN implementation. I think a quick fix is to let the generator train for a couple of epochs before starting the discriminator, and to use a smaller weight (0.2 in my case) on the discriminator loss for generated images. See https://github.com/CompVis/taming-transformers/issues/93.
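
Roughly what I mean, as a sketch rather than a patch against either repo (names like disc_start_epoch below are made up):

```python
# Sketch of the idea: keep the discriminator term switched off for the first
# couple of epochs, then blend it in with a small weight (0.2 worked for me).
def adopt_disc_weight(weight: float, epoch: int, disc_start_epoch: int = 2) -> float:
    """Return 0 before disc_start_epoch, otherwise the (small) weight."""
    return 0.0 if epoch < disc_start_epoch else weight

# inside the generator's training step:
# disc_factor = adopt_disc_weight(0.2, epoch)            # 0.0 during warm-up epochs
# g_loss = rec_loss + perceptual_loss + disc_factor * g_adv_loss
```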