LTH14 / mage

A PyTorch implementation of MAGE: MAsked Generative Encoder to Unify Representation Learning and Image Synthesis

A question about vocab_size in token embedding #53

Open tanbuzheng opened 2 months ago

tanbuzheng commented 2 months ago

Hi! When I read your source code, I found that you set vocab_size = self.codebook_size + 1000 + 1 in the token embedding stage. Why not directly set vocab_size = self.codebook_size? What do the extra 1001 embeddings mean? Are they the embeddings of the class labels and the mask token? Can I understand it this way: when there is no class condition, vocab_size should be set to self.codebook_size + 1?

Looking forward to your reply!

LTH14 commented 2 months ago

Hi, thanks for your interest! Yes, the vocab_size should be self.codebook_size + 1 when there is no class condition. We set it to self.codebook_size + 1000 + 1 just because our pre-trained checkpoint from JAX uses this redundant codebook size, and to load the pre-trained weights we need to keep it that way.

tanbuzheng commented 2 months ago

Thanks for your reply! It was very kind of you! But if so, how should I set self.fake_class_label? Is it reasonable to set it to any value within 0-1024?

LTH14 commented 2 months ago

You can actually set it to any value larger than or equal to 1024, and smaller than 1024+1000+1 -- but the pre-trained model set it to 1100 (again, a legacy issue).
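For reference, here is a minimal sketch of the token-id layout implied by the numbers in this thread (codebook size 1024, 1000 class-label slots, one mask token). Apart from codebook_size and fake_class_label, the constant names are illustrative rather than the repository's actual identifiers.

```python
# Token-id layout implied by this thread; names other than codebook_size and
# fake_class_label are illustrative, not the repo's actual identifiers.
codebook_size = 1024                                # ids 0..1023: VQGAN codebook tokens
num_class_labels = 1000                             # ids 1024..2023: class-label tokens
vocab_size = codebook_size + num_class_labels + 1   # 2025; the +1 is assumed to be the mask token
mask_token_id = vocab_size - 1                      # 2024 (assumption: mask token takes the last id)
fake_class_label = 1100                             # legacy value used by the pre-trained checkpoint;
                                                    # per the author, any id >= 1024 and < 1024+1000+1 works
```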

tanbuzheng commented 2 months ago

OK! Thanks a lot!

tanbuzheng commented 2 months ago

By the way, the current code does not seem to contain the contrastive loss part. Do you have any plans to release the complete training code, including this part?

LTH14 commented 2 months ago

Unfortunately I no longer have access to the original JAX code, so there is no plan to release the contrastive training part. However, that part is quite straightforward if you want to re-implement it -- it is simply a SimCLR-based contrastive loss similar to this one.
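For readers who want to re-implement it, below is a minimal SimCLR-style (InfoNCE) loss sketch in PyTorch. This is not the original MAGE implementation (the JAX code is unavailable); the temperature value and the choice of feature fed into the loss (e.g. an averaged or [CLS] encoder output) are assumptions.

```python
import torch
import torch.nn.functional as F

def simclr_loss(z1, z2, temperature=0.2):
    """SimCLR/InfoNCE loss for two views z1, z2 of shape [N, D]."""
    z1 = F.normalize(z1, dim=-1)
    z2 = F.normalize(z2, dim=-1)
    z = torch.cat([z1, z2], dim=0)              # [2N, D]
    sim = z @ z.t() / temperature               # pairwise cosine similarities
    sim.fill_diagonal_(float("-inf"))           # exclude self-similarity
    n = z1.size(0)
    # The positive for sample i (first view) is i + n (second view), and vice versa.
    targets = torch.cat([torch.arange(n) + n, torch.arange(n)]).to(z.device)
    return F.cross_entropy(sim, targets)
```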

tanbuzheng commented 1 month ago

Hello, author! Did you re-train the VQGAN yourself? It seems different from the pre-trained model released by the VQGAN authors. If I want to apply MAGE to other datasets, what should I pay attention to when training the VQGAN?

LTH14 commented 1 month ago

Yes, the VQGAN was trained by ourselves. From my experience, a VQGAN is much harder to train than a continuous VAE; the GAN loss and perceptual loss parts are especially important -- the warmup epochs, the GAN/perceptual weights, the discriminator you use, etc. One important thing is to always visualize the reconstruction results -- you may get a low reconstruction error, but the visual quality could still be quite blurry.
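As a rough illustration of the knobs mentioned above (perceptual/GAN weights and a discriminator warmup), a typical VQGAN generator objective might look like the sketch below. The exact weights and schedule used for the MAGE tokenizer are not given in this thread, so all values here are placeholders.

```python
def adopt_weight(weight, global_step, warmup_steps):
    """Keep the adversarial term switched off until the warmup is done."""
    return weight if global_step >= warmup_steps else 0.0

def vqgan_generator_loss(recon_loss, perceptual_loss, codebook_loss, g_loss,
                         global_step, perceptual_weight=1.0, gan_weight=0.1,
                         warmup_steps=10_000):
    # reconstruction + perceptual + codebook commitment + (warmed-up) adversarial loss
    disc_factor = adopt_weight(gan_weight, global_step, warmup_steps)
    return (recon_loss
            + perceptual_weight * perceptual_loss
            + codebook_loss
            + disc_factor * g_loss)
```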

tanbuzheng commented 1 month ago

Thanks a lot! But I would like to ask: why not just use the pre-trained model provided by the VQGAN authors? Is there any problem with it?

LTH14 commented 1 month ago

The reconstruction FID of the original VQGAN is too poor (7.94), which bounds the generation performance. We follow many practices from ViT-VQGAN in our training, including a larger batch size (256) and a StyleGAN discriminator (instead of PatchGAN). The reconstruction FID of the tokenizer in MAGE is 2-3, which is significantly better than the original VQGAN.

tanbuzheng commented 1 month ago

I got it! Thank you so much!

tanbuzheng commented 1 month ago

Sorry to bother you again. My computing resources are limited. If I don't use the contrastive loss, or use MoCo v2 instead, can I set the batch size to a smaller value such as 256? The MAGE model uses ViT-B.

LTH14 commented 1 month ago

I haven't tested small batch sizes. For a SimCLR-based contrastive loss (as we used), a large batch size is typically needed. If you don't use the contrastive loss, a smaller batch size might also be fine (although MAE and MAGE both use large batch sizes).

LTH14 commented 1 month ago

I would say a batch size of 256 will not give you too bad a performance if you don't use the contrastive loss -- just maybe slightly worse than a large batch size.

tanbuzheng commented 1 month ago

Ok, thanks again!

tanbuzheng commented 3 weeks ago

Hello, author! Sorry to bother you again. According to my understanding, MAGE can be trained on a V100 with batch size 64. Recently I made a preliminary attempt to train MAGE on a 3090 with batch size 64, but I ran out of memory. Do you have any experience solving this problem?

LTH14 commented 3 weeks ago

The MAGE-B model can be trained with batch size 64; MAGE-L can be trained with batch size 32. I have never used a 3090 -- the V100s we use have 32GB of memory.

tanbuzheng commented 3 weeks ago

Thank you very much! I just found out that I was training MAGE-L. I should try to train MAGE-B instead.

tanbuzheng commented 3 weeks ago

Another question: I accidentally noticed this when I used your VQGAN before -- the ResnetBlock in the VQGAN you use does not contain a real shortcut. Is this a special design?

LTH14 commented 3 weeks ago

Good catch -- it is just a bug in Google's internal implementation, but the performance is actually fine.

tanbuzheng commented 3 weeks ago

Yes, no worries -- it performs very well! Thanks again!

tanbuzheng commented 3 weeks ago

Dear author, I want to ask another question. During the training of the MAGE encoder, why do you keep some masked tokens from being dropped? Why not adopt a dynamic mask dropping ratio instead of setting it to a fixed value such as 0.5?

LTH14 commented 2 weeks ago

The dropping ratio is set to the minimum masking ratio, which is 0.5. The actual masking ratio (which determines the number of masked tokens) is sampled from [0.5, 1.0].
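To make the distinction concrete, here is a minimal sketch of this masking/dropping scheme. It assumes a uniform masking-ratio distribution and an illustrative mask-token id; the actual repository may sample the ratio differently and uses its own tensor layout.

```python
import torch

def mask_and_drop(tokens, min_mask_ratio=0.5, max_mask_ratio=1.0, mask_token_id=2024):
    """tokens: [N, L] discrete token ids from the tokenizer."""
    N, L = tokens.shape
    # The masking ratio varies per iteration and decides how many tokens are masked.
    mask_ratio = torch.empty(1).uniform_(min_mask_ratio, max_mask_ratio).item()
    num_masked = int(L * mask_ratio)
    # The dropping ratio is fixed to the minimum masking ratio (0.5), so the encoder
    # always sees the same sequence length; masked-but-not-dropped tokens stay in the
    # encoder input as mask tokens.
    num_dropped = int(L * min_mask_ratio)

    ids_shuffle = torch.argsort(torch.rand(N, L), dim=1)   # random permutation per sample
    masked_ids = ids_shuffle[:, :num_masked]               # positions to mask
    kept_ids = ids_shuffle[:, num_dropped:]                # positions kept in the encoder input

    masked_tokens = tokens.clone()
    masked_tokens.scatter_(1, masked_ids, mask_token_id)   # replace masked positions with the mask token
    encoder_input = torch.gather(masked_tokens, 1, kept_ids)
    return encoder_input, masked_tokens, masked_ids
```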