LTH14 / mage

A PyTorch implementation of MAGE: MAsked Generative Encoder to Unify Representation Learning and Image Synthesis

The src of pre-trained VQGAN tokenizer #2

Closed ChangyaoTian closed 1 year ago

ChangyaoTian commented 1 year ago

Hi, could you provide more info about the source of the VQ-tokenizer used here? It seems different from all the public checkpoints in VQGAN's official repo.

Thanks a lot!

LTH14 commented 1 year ago

Hi, we use the same architecture as in MaskGIT.

ChangyaoTian commented 1 year ago

Hi, thanks for your timely reply! I still have some questions about the implementation details of the VQ-tokenizer.

  1. The checkpoint provided here does not contain the loss part (i.e., the discriminator & perceptual model). Could you share it as well?
  2. In the vqgan config file, the loss config is exactly the same as the one in VQGAN's official repo, but MaskGIT uses a different loss (modified from StyleGAN). Could you elaborate on the detailed training settings of your vq-tokenizer (e.g., loss module, training epochs, lr, etc.)?
  3. For the vq-tokenizer training, what kind of data augmentation & loss weights do you use? Are they the same as in VQGAN's official repo or as in MaskGIT?

Looking forward to your reply~ Thank you!

LTH14 commented 1 year ago

Hi, thanks for your interest!

  1. For the vq-tokenizer checkpoint, we manually converted it from a JAX pre-trained checkpoint to PyTorch. In this repo, we use it only for inference, so we only converted the encoder, decoder, and codebook; the loss part is not needed for inference. I also attach the original JAX checkpoint here. You can refer to it for the loss part. (A rough sketch of this kind of conversion is included after this list.)

  2. I noticed that the current config file causes confusion, so I just updated it to remove the deprecated VQGAN training and loss configs. Our vq-tokenizer training in JAX follows the exact same scheme as the vq-tokenizer training in MaskGIT. Since MaskGIT is implemented in JAX, we adopt the VQGAN PyTorch framework for vq-tokenizer inference in this repo purely for implementation convenience.

  3. We use the exact same training scheme and loss as MaskGIT to train our vq-tokenizer. The only difference between the officially released MaskGIT vq-tokenizer and ours is that we train with stronger augmentation (random resized crop with scale from 0.2 to 1; see the snippet after this list). The reason is discussed in Section 4.2 and Table 10 (Appendix B) of our paper: stronger augmentation favors linear probing, while weaker augmentation favors generation. For the weak-augmentation (w.a.) numbers reported in the paper, we use the exact same vq-tokenizer (same weights) as the MaskGIT vq-tokenizer.
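For reference, here is a rough sketch of the kind of JAX-to-PyTorch conversion mentioned in point 1. It assumes the JAX checkpoint can be restored into a nested dict of numpy arrays; the `name_map` entries are purely hypothetical, since the actual conversion script is not part of this repo:

```python
import numpy as np
import torch
from flax.traverse_util import flatten_dict  # flax is used here only to flatten the param tree

def jax_params_to_torch(jax_params, name_map):
    """Map a nested dict of Flax params (conv kernels in HWIO layout)
    to a PyTorch state dict (conv weights in OIHW layout).

    name_map: e.g. {"encoder/conv_in/kernel": "encoder.conv_in.weight", ...} (hypothetical names)
    """
    flat = {"/".join(k): np.asarray(v) for k, v in flatten_dict(jax_params).items()}
    state_dict = {}
    for jax_name, torch_name in name_map.items():
        w = torch.from_numpy(flat[jax_name])
        if w.ndim == 4:  # conv kernel: HWIO -> OIHW
            w = w.permute(3, 2, 0, 1).contiguous()
        state_dict[torch_name] = w
    return state_dict
```

And a minimal PyTorch-side illustration of the stronger augmentation from point 3 (the tokenizer itself was trained in JAX, so this torchvision version only mirrors the described crop scale; the flip is an assumption):

```python
from torchvision import transforms

strong_aug = transforms.Compose([
    transforms.RandomResizedCrop(256, scale=(0.2, 1.0)),  # crop scale 0.2-1.0 instead of the usual 0.8-1.0
    transforms.RandomHorizontalFlip(),                    # common companion transform, not confirmed by the authors
    transforms.ToTensor(),
])
```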

Hope these help. Please let me know if you have other questions.

ChangyaoTian commented 1 year ago

Thanks again! Your reply helps a lot! I still have some questions regarding the original JAX checkpoint here:

  1. The values of 'generator_state' and 'discriminator_state' are None, but I found the generator's parameters in 'ema_params'. So (i) do you use EMA during training (which does not seem to be applied by MaskGIT)? (ii) Where are the loss-part weights? Are they the same as those in 'd_optimizer'?
  2. Could you also share the full training script of your vq-tokenizer? I cannot find the detailed scheme in MaskGIT's official repo.
  3. The training step count in your checkpoint is 1,720,000, which differs from the original setting of 1,000,000 steps in MaskGIT. I would like to know whether other hyperparameters (e.g., batch size, learning rate) were changed as well.

Looking forward to your reply~

LTH14 commented 1 year ago

  1. The EMA weights are not used during training; they are a trick to stabilize the final checkpoint and are also used in MaskGIT (a minimal sketch follows this list). I'm sure our vq-tokenizer uses the same training settings as the MaskGIT tokenizer (except for the aforementioned augmentation), as they are trained with the exact same training codebase and configurations. The 'target' fields in 'g_optimizer' and 'd_optimizer' contain the exact weights of the tokenizer and the discriminator, respectively.

  2. I wish I could. Unfortunately, I used the vq-tokenizer training script provided by Google during my internship, so it is not my decision whether to release it.

  3. Similar to VQGAN, both our vq-tokenizer and the MaskGIT vq-tokenizer are trained with a constant learning rate, so in practice we just keep training and stop once we see the loss converge. In our case, after 1,000K steps the vq-tokenizer loss shows only minor fluctuations, and this does not affect the second-stage generation/linear probing performance. All other hyperparameters are the same. You can refer here for the hyperparameters.
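For reference, a minimal sketch of the kind of parameter EMA described in point 1 (a generic PyTorch version, not the actual JAX training code; the decay value is an assumption):

```python
import copy
import torch

@torch.no_grad()
def update_ema(ema_model, model, decay=0.999):
    """Exponential moving average of model weights, kept alongside the raw training weights."""
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)

# usage sketch:
#   ema_model = copy.deepcopy(model)             # shadow copy holding the averaged weights
#   ... after every optimizer step: update_ema(ema_model, model)
#   at the end of training, export ema_model's weights as the stabilized checkpoint
```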

Hope these help!

ttccxx commented 1 year ago

Hi! Thanks for your valuable work and the discussion. I have a specific question about training VQGAN: how do you compute the perceptual loss? I see that vanilla VQGAN applies some linear layers on top of the feature differences between input and reconstruction (https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/lpips.py), but I cannot find these parameters in your shared checkpoint. Do you use a different implementation?

LTH14 commented 1 year ago

Hi! Those linear layers in the LPIPS loss are pre-trained and kept fixed during VQ-tokenizer training, so we do not need to store them in the checkpoint parameters. You can get the weights for those 1x1 convs via https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/lpips.py#L28; this function downloads the weights into "taming/modules/autoencoder/lpips".
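For anyone following along, a minimal sketch of how those fixed LPIPS weights get fetched and used with taming-transformers (the [-1, 1] image range follows that codebase's convention; the random tensors are just placeholders):

```python
import torch
from taming.modules.losses.lpips import LPIPS

# instantiating LPIPS downloads the pre-trained 1x1-conv ("linear") weights on first use
perceptual = LPIPS().eval()

x = torch.rand(1, 3, 256, 256) * 2 - 1      # placeholder input image in [-1, 1]
x_rec = torch.rand(1, 3, 256, 256) * 2 - 1  # placeholder reconstruction
p_loss = perceptual(x, x_rec)               # LPIPS perceptual distance between the two
print(p_loss.mean())
```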

ttccxx commented 1 year ago

Thanks for your reply! I hadn't noticed that code before; it is clear to me now.

jackhu-bme commented 1 year ago

Hi! Thanks for your interesting work and the detailed clarification in this issue! I still have some questions about training a tokenizer for medical image generation (I tried directly using the ImageNet-pretrained VQGAN to invert medical images, and many important structures were lost after inversion):

  1. Why does the tokenizer architecture in MAGE follow MaskGIT (some attention blocks and conv biases are removed compared to the official VQGAN)? Is this to reduce computational cost, or for better performance when fine-tuning on downstream tasks? If I can choose, which tokenizer (the MaskGIT implementation or the official one) gives better image generation performance?
  2. If I choose the official VQGAN implementation, is it suitable to train MAGE with that tokenizer (i.e., the masked encoding process)? My focus is inpainting performance; I have no need to fine-tune for downstream tasks like image classification.

Thanks again for your excellent work and patience! Your work really inspired me when I ran into problems in my project.

LTH14 commented 1 year ago

Thanks for your interest! A VQGAN pre-trained on ImageNet is likely to perform relatively poorly on medical data, as it was never trained on such data. For (1), the reason is to save computation and stay consistent with the VQGAN implementation used inside Google; the performance is nearly the same. For (2), you can use any VQGAN as the MAGE tokenizer, as long as it tokenizes 256x256 images into 16x16 tokens (a quick sanity check is sketched below).
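A quick sanity check (sketch) that a candidate VQGAN indeed maps 256x256 images to a 16x16 token grid. It assumes a taming-transformers-style VQModel instance named `vqgan`, loaded from your own config/checkpoint:

```python
import torch

x = torch.randn(1, 3, 256, 256)  # dummy 256x256 image batch
with torch.no_grad():
    quant, emb_loss, info = vqgan.encode(x)  # taming's VQModel.encode returns (quant, emb_loss, info)

print(quant.shape)  # expected: torch.Size([1, embed_dim, 16, 16]), i.e. a 16x16 latent grid
# the last element of `info` holds the codebook indices (256 tokens per 256x256 image)
```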

jackhu-bme commented 1 year ago

Thanks a lot for your detailed reply! This really saves me time comparing different tokenizer implementations.