bytedance / 1d-tokenizer

This repo contains the code for a 1D tokenizer and generator.
Apache License 2.0
548 stars, 24 forks

Why do we need two codebooks? Is this actually 1D+2D representation learning? #5

Status: Open. juncongmoo opened this issue 5 months ago.

juncongmoo commented 5 months ago

I was initially very excited about this paper. However, after reading the code, I found that there are actually two codebooks and two representations: one is 1D (K=32) and the other is 2D (16x16). All the other models use one codebook and one representation, so why does this one use two codebooks and two representations? Why not just use the 1D codeword sequence for reconstruction?

cornettoyu commented 5 months ago

Hi,

Thanks for your interest in our work.

First, the need for two codebooks/representations is a misunderstanding. As you can see in our demo Jupyter notebook, each image is tokenized into 32 integers, which can then be de-tokenized back into an image.
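To make "32 integers" concrete, here is a minimal NumPy sketch of the quantize / de-quantize step, assuming standard nearest-neighbor vector quantization. The sizes (4096 codes of width 12) and the function names `quantize` / `dequantize` are hypothetical, not TiTok's actual API, and the encoder and decoder networks are omitted.

```python
import numpy as np

def quantize(latents: np.ndarray, codebook: np.ndarray) -> np.ndarray:
    """Map each latent token (a row of `latents`) to the index of its nearest codebook entry."""
    # Squared Euclidean distance between every latent token and every codebook entry: (K, N)
    dists = ((latents[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)
    return dists.argmin(axis=1)

def dequantize(ids: np.ndarray, codebook: np.ndarray) -> np.ndarray:
    """Recover the quantized latent tokens from their integer ids (a plain table lookup)."""
    return codebook[ids]

rng = np.random.default_rng(0)
K, N, D = 32, 4096, 12  # hypothetical: 32 tokens, 4096 codes, 12-dim codes
codebook = rng.normal(size=(N, D))
latents = rng.normal(size=(K, D))  # stand-in for the encoder's 32 latent tokens

ids = quantize(latents, codebook)         # the image is now just these 32 integers
tokens = dequantize(ids, codebook)        # what a decoder would consume to rebuild the image
print(ids.shape, tokens.shape)  # (32,) (32, 12)
```

Only `ids` needs to be stored or generated; everything else is recomputed from the codebook inside the de-tokenizer.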

I think your confusion may come from the two-stage training, which involves a MaskGIT-VQ. Here are some comments; I hope they address your concerns.

  1. The final model does NOT require two codebooks or two representations. Only the 1D codebook in TiTok is kept. You can think of MaskGIT-VQ's embedding as a linear layer inside our decoder; it does not exist in the latent representation.

  2. Why use MaskGIT-VQ at all? As shown in Tab 3 (c) of our paper (screenshot attached), it is totally fine to train TiTok end-to-end from scratch in a single stage, and it significantly outperforms both the Taming-VQ and TiTok-2D counterparts under the same training setting.

    [Screenshot: Tab 3 (c) from the paper]

However, MaskGIT-VQGAN has a much stronger training recipe (it improves rFID from 7.94 to 2.28 with almost the same architecture) that is NOT publicly available, which puts our model at a disadvantage when pushing for SOTA performance on ImageNet. We therefore use two-stage training so that TiTok can benefit from MaskGIT-VQ and close the performance gap. A detailed discussion is also in Sec 4.3 (Ablation Studies, 3rd paragraph) of the paper.
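Point 1's claim that MaskGIT-VQ's embedding behaves like a linear layer can be checked numerically: looking up rows of a codebook is the same as multiplying one-hot vectors by the codebook matrix, i.e. a bias-free linear layer. A NumPy sketch with hypothetical sizes (1024 codes of width 256), not code from the repo:

```python
import numpy as np

rng = np.random.default_rng(0)
num_codes, dim = 1024, 256  # hypothetical sizes for a MaskGIT-VQ-style codebook
codebook = rng.normal(size=(num_codes, dim))

ids = np.array([5, 17, 1023])       # hypothetical token ids
lookup = codebook[ids]              # ordinary embedding lookup

one_hot = np.eye(num_codes)[ids]    # (3, num_codes) one-hot rows
linear = one_hot @ codebook         # same result, written as a bias-free linear layer

print(np.allclose(lookup, linear))  # True
```

Seen this way, the table is just a weight matrix inside the decoder rather than a second set of tokens.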

If you (or anyone else) know of a better public VQGAN training recipe than Taming-VQ, please let us know; we are more than happy to try it with TiTok!

Feel free to let me know if you have any other questions.

juncongmoo commented 5 months ago

I am not sure I understand it 100% correctly, but I do see that MaskGIT's VQ embedding (pixel-codebook) is used as an nn.Linear via a weighted sum (soft VQ), which is not a normal codebook lookup. The confusing part is its name and class definition: it is indeed a codebook (the pixel-codebook). So do you mean you have the pixel-codebook before training and use it as the weights of a linear layer? This seems very interesting. Did you take the pixel-codebook directly from a pre-trained MaskGIT-VQ model such as open-muse or amused, or train it from scratch? And when you do the warm-up training, what are the loss functions?
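The distinction raised here, a normal codebook lookup versus a weighted sum over the codebook, can be sketched as follows. This assumes "soft VQ" means softmax-weighted averaging of codebook entries; the sizes and the `softmax` helper are hypothetical, not taken from the repo.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
num_codes, dim, positions = 1024, 64, 16 * 16  # hypothetical sizes
codebook = rng.normal(size=(num_codes, dim))
logits = rng.normal(size=(positions, num_codes))  # stand-in for per-position predictions

# Hard VQ: pick one code per position (argmax), then look it up.
hard = codebook[logits.argmax(axis=-1)]
# Soft VQ: weighted sum of all codes, i.e. a linear layer applied to softmax probabilities.
soft = softmax(logits) @ codebook
print(hard.shape, soft.shape)  # (256, 64) (256, 64)
```

When the logits become very peaked, the softmax approaches a one-hot vector and the soft version approaches the hard lookup.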

For Tab 3 (c), did you use the same model architecture (1 encoder + 2 decoders: one TiTok decoder + one pixel decoder)? Can we remove the pixel decoder and use the TiTok decoder to reconstruct the image directly from TiTok space to pixel space?
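Decoding straight to pixels with a ViT decoder typically means a final linear layer that predicts every pixel of each patch, followed by a depth-to-space rearrangement like MAE's unpatchify. A minimal NumPy sketch, assuming a square grid of patches in row-major order; the grid size, patch size, decoder width and random weights are hypothetical:

```python
import numpy as np

def unpatchify(patch_pixels: np.ndarray, patch_size: int, channels: int = 3) -> np.ndarray:
    """Rearrange per-patch pixel predictions (h*w, p*p*C) into an image (h*p, w*p, C)."""
    num_patches, _ = patch_pixels.shape
    h = w = int(round(num_patches ** 0.5))
    assert h * w == num_patches, "expects a square grid of patches"
    x = patch_pixels.reshape(h, w, patch_size, patch_size, channels)
    x = x.transpose(0, 2, 1, 3, 4)  # (h, p, w, p, C): interleave patch rows with pixel rows
    return x.reshape(h * patch_size, w * patch_size, channels)

rng = np.random.default_rng(0)
grid, p, d = 16, 16, 64  # hypothetical: 16x16 decoder tokens, 16-pixel patches, width 64
tokens = rng.normal(size=(grid * grid, d))        # ViT decoder outputs, one per patch
weight = rng.normal(size=(d, p * p * 3)) * 0.02   # hypothetical final linear layer
image = unpatchify(tokens @ weight, patch_size=p)
print(image.shape)  # (256, 256, 3)
```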

You mentioned the gap is because MaskGIT-VQGAN has a much stronger training recipe. Could you clarify whether that is the training recipe of MaskGIT or of VQGAN? If it is MaskGIT's, I know of a recent paper that reproduces it (https://arxiv.org/abs/2310.14400). If it is VQGAN's, their official repo has the training code.

BTW, I just read the code, and it seems open-muse is used, but I found it was removed from Hugging Face (https://huggingface.co/openMUSE). Where did you get the off-the-shelf MaskGIT-VQGAN model, then?

cornettoyu commented 5 months ago


Hi, thanks for the valuable comments.

To begin with, I would like to define some terms that could be confusing in MaskGIT. The MaskGIT framework contains a tokenizer (which we refer to as MaskGIT-VQ or MaskGIT-VQGAN) that shares the same architecture as the original VQGAN (referred to as Taming-VQ), except that the attention layers are removed (so technically MaskGIT-VQ does not have a stronger architecture). However, MaskGIT-VQ significantly outperforms the publicly available Taming-VQ (rFID 2.28 vs. 7.94), presumably because of a stronger training recipe, for which no code or details are revealed in the official code or paper.

We refer to the MaskGIT generator simply as MaskGIT.

  1. In the two-stage training paradigm, we first train TiTok to reconstruct MaskGIT-VQ's codes (with a cross-entropy loss, similar to BEiT), and then we fine-tune the model's decoder towards raw pixels. So you can think of it as combining TiTok's decoder and MaskGIT-VQ's decoder into one larger decoder that serves as our de-tokenizer. Still, we only need the 32 integers from TiTok's tokenizer, as shown in our demo code, so there are not really two embeddings/codebooks. At a higher level, all we care about is having a tokenizer that turns an image into a set of integers and a de-tokenizer that reconstructs the image from those integers. The currently released TiTok does exactly that, with 32 integers per image.

  2. We used the PyTorch reimplementation of MaskGIT-VQ from open-muse, which also provided pre-trained weights ported from the official JAX MaskGIT-VQ. I do not know why they took it down, but we did most of our experiments while it was still available. It is just a PyTorch version of the official JAX codebase with the same weights.

  3. For Tab 3 (c), when trained in the single-stage Taming-VQ setting, only the ViT decoder is used: its last linear layer is followed by pixel shuffling (similar to how MAE predicts pixels).

  4. This reproduction (https://arxiv.org/abs/2310.14400) only reproduces the generator part of MaskGIT, and it uses Taming-VQ instead of MaskGIT-VQ as the tokenizer. As you can see, its generation results only barely match MaskGIT's performance after adding classifier-free guidance (CFG), which MaskGIT did not use. This also shows that MaskGIT-VQ is much stronger than Taming-VQ, yet unfortunately we have no clue how they trained MaskGIT-VQ to such good scores. The official repo you referred to is Taming-VQ; as discussed above, we experimented with it and it is not as good as MaskGIT-VQ, whether for TiTok or for the original VQGAN architecture. That is why we adopt two-stage training for better results on ImageNet, even though it is not necessary for TiTok.
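The stage-1 objective in point 1 (predicting MaskGIT-VQ's codes with cross-entropy, BEiT-style) can be sketched as a per-position classification loss over the proxy codebook. This is an illustration under assumptions, not the repo's training code: the 1024-entry vocabulary, the 16x16 grid of targets, and the `cross_entropy` helper are hypothetical, and random arrays stand in for the real targets and decoder logits.

```python
import numpy as np

def cross_entropy(logits: np.ndarray, targets: np.ndarray) -> float:
    """Mean cross-entropy between per-position logits (P, V) and integer targets (P,)."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return float(-log_probs[np.arange(len(targets)), targets].mean())

rng = np.random.default_rng(0)
vocab = 1024         # assumed MaskGIT-VQ codebook size
positions = 16 * 16  # assumed 16x16 grid of proxy codes per image

# Stand-ins: in stage 1, `targets` would come from the frozen, pre-trained MaskGIT-VQ encoder,
# and `logits` from TiTok's decoder after de-tokenizing the 32 TiTok tokens.
targets = rng.integers(0, vocab, size=positions)
logits = rng.normal(size=(positions, vocab))

loss = cross_entropy(logits, targets)
print(loss)
```

In stage 2, per point 1, the decoder is instead fine-tuned against raw pixels.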

MikeWangWZHL commented 4 months ago

Thank you for this very detailed explanation! Could you elaborate on the losses and hyperparameters you used for the single-stage training experiment? I tried to train the model from scratch on MNIST with a simple MSE loss, but it results in mode collapse (and codebook usage is very low). I wonder whether you used any regularization to enable the training. Thanks in advance!
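For diagnosing the collapse described here, a common generic check is to track what fraction of codebook entries are used and the perplexity of the code-usage distribution (near 1 when collapsed, approaching the codebook size when usage is uniform). A minimal NumPy sketch; `codebook_stats` is a hypothetical helper, not something from the TiTok repo, and the id arrays are synthetic:

```python
import numpy as np

def codebook_stats(ids: np.ndarray, num_codes: int) -> tuple[float, float]:
    """Return (fraction of codebook entries used, perplexity of the code-usage distribution)."""
    counts = np.bincount(ids.ravel(), minlength=num_codes)
    probs = counts / counts.sum()
    used = float((counts > 0).mean())
    nz = probs[probs > 0]  # treat 0 * log 0 as 0
    perplexity = float(np.exp(-(nz * np.log(nz)).sum()))
    return used, perplexity

rng = np.random.default_rng(0)
num_codes = 1024
healthy = rng.integers(0, num_codes, size=(64, 32))  # ids spread over the whole codebook
collapsed = rng.integers(0, 3, size=(64, 32))        # ids stuck on a handful of codes

print(codebook_stats(healthy, num_codes))
print(codebook_stats(collapsed, num_codes))
```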

tau-yihouxiang commented 4 months ago

Thanks to the authors for the explanation. For simplicity, please add the image reconstruction code, covering both image to 32 tokens and 32 tokens to image. Also, what are the loss weights of recon_loss, quantizer_loss, commitment_loss, and codebook_loss during training?
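The thread does not state these weights. For orientation only, in the standard VQ-VAE formulation the codebook and commitment terms are defined as below, with $\mathrm{sg}[\cdot]$ the stop-gradient operator, $z_e$ the encoder output and $e$ its nearest codebook entry. Mapping them onto the names recon_loss, quantizer_loss, commitment_loss and codebook_loss is an assumption, and $\lambda_q$ and $\beta$ are placeholders, not TiTok's actual values:

```latex
\mathcal{L}_{\text{codebook}} = \lVert \mathrm{sg}[z_e] - e \rVert_2^2, \qquad
\mathcal{L}_{\text{commit}} = \lVert z_e - \mathrm{sg}[e] \rVert_2^2,
\\[4pt]
\mathcal{L}_{\text{quantizer}} = \mathcal{L}_{\text{codebook}} + \beta\, \mathcal{L}_{\text{commit}}, \qquad
\mathcal{L} = \mathcal{L}_{\text{recon}} + \lambda_q\, \mathcal{L}_{\text{quantizer}}
```

The two quantizer terms have the same forward value and differ only in which side receives gradients: the codebook term updates the codes, the commitment term pulls the encoder output toward them.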

cornettoyu commented 4 months ago

Both reconstruction (image -> 32 tokens -> image) and generation (class label -> 32 tokens -> image) already exist in the demo Jupyter notebook; we will provide a more explicit example in the README.