lucidrains / magvit2-pytorch

Implementation of MagViT2 Tokenizer in Pytorch

Has anyone successfully trained this model? #34

Open Jihun999 opened 5 months ago

Jihun999 commented 5 months ago

I have been trying to train this model for a few days, but the reconstruction results are always abnormal. If anyone has trained this model successfully, could you share some tips?

Jihun999 commented 5 months ago

The reconstructed images come out as a solid color.

RobertLuo1 commented 5 months ago

Can you show the reconstructed images after training?

Jihun999 commented 5 months ago
[example reconstruction image attached]

It always looks like this image.

RobertLuo1 commented 5 months ago

@bridenmj How many epochs did you use? Are you working on ImageNet pretraining?

Jihun999 commented 5 months ago

Yes, I'm working on ImageNet pretraining; it has passed 12000 steps. The output image always looks the same. So I tried LFQ in my own autoencoder, and that training works well. It looks like there is something wrong in the magvit2 model architecture.
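
For comparison, here is a minimal sketch of the kind of standalone LFQ autoencoder test described above, using the LFQ module from lucidrains' vector-quantize-pytorch (which this repo builds on); the toy encoder/decoder and all hyperparameters are illustrative, not taken from any commenter's actual setup:

```python
import torch
from torch import nn
import torch.nn.functional as F
from vector_quantize_pytorch import LFQ

# toy encoder/decoder, just to exercise the quantizer
encoder = nn.Conv2d(3, 16, 4, stride = 2, padding = 1)
decoder = nn.ConvTranspose2d(16, 3, 4, stride = 2, padding = 1)

quantizer = LFQ(
    codebook_size = 2 ** 16,      # lookup-free: codebook size must be 2 ** dim
    dim = 16,                     # matches the encoder's output channels
    entropy_loss_weight = 0.1,
    diversity_gamma = 1.0
)

images = torch.randn(4, 3, 64, 64)

z = encoder(images)                                  # (4, 16, 32, 32)
quantized, indices, entropy_aux_loss = quantizer(z)  # LFQ accepts (b, d, h, w)
recon = decoder(quantized)

loss = F.mse_loss(recon, images) + entropy_aux_loss
loss.backward()
```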

RobertLuo1 commented 5 months ago

Actually, I reimplemented the model structure to align with the magvit2 paper. But I find that the LFQ loss is negative, and the recon loss converges easily with or without the GAN. The reconstructed images are blurry, but not a solid color. What about you? @Jihun999
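
For context, a negative LFQ loss is not necessarily a bug. In the MAGVIT-v2 paper the entropy penalty takes the form

$$\mathcal{L}_{\text{entropy}} = \mathbb{E}\big[H(q(z))\big] - \gamma \, H\big[\mathbb{E}(q(z))\big]$$

where the first term rewards confident per-sample code assignments and the second rewards spreading usage across the whole codebook (as far as I can tell, the weight $\gamma$ corresponds to the `diversity_gamma` argument in lucidrains' LFQ implementation). Since a well-utilized codebook makes the subtracted term large, the penalty can legitimately dip below zero as training progresses.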

Jihun999 commented 5 months ago

Ok, I will reimplement the model first. Thank you for your comment.

Jason3900 commented 4 months ago

> Actually, I reimplemented the model structure to align with the magvit2 paper. But I find that the LFQ loss is negative, and the recon loss converges easily with or without the GAN. The reconstructed images are blurry, but not a solid color. What about you? @Jihun999

Hey, would it be possible to share the code modifications for the model architecture alignment? Thanks a lot!

lucidrains commented 3 months ago

someone I know has trained it successfully.

Jiushanhuadao commented 3 months ago

Wow, could I know who did it?

StarCycle commented 2 months ago

@RobertLuo1 @Jihun999 @lucidrains If you successfully trained this model, would you like to share the pretrained weights and the modified model code?

vinyesm commented 2 months ago

Hello there, thanks @lucidrains for your work! I have had successful trainings on toy data (tried it on images and video) with the code in this fork https://github.com/vinyesm/magvit2-pytorch/blob/trainvideo/examples/train-on-video.py and with this video data https://huggingface.co/datasets/mavi88/phys101_frames/tree/main. What seemed to fix the issue was to stop using accelerate (I only train on one GPU).

I tried with only MSE, then also with the other losses, and with/without attend_space layers. All of them work, but I did not try to tune hyperparameters.
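
For anyone trying to reproduce this, a minimal single-GPU training sketch along the lines of the repo's README; the layer schedule is truncated and the dataset path is a placeholder, so treat the hyperparameters as illustrative (note the stock VideoTokenizerTrainer still drives training through accelerate internally, which is what the fork above works around):

```python
from magvit2_pytorch import VideoTokenizer, VideoTokenizerTrainer

tokenizer = VideoTokenizer(
    image_size = 128,
    init_dim = 64,
    max_dim = 512,
    codebook_size = 1024,
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',   # reported above to work with or without these layers
    )
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder = '/path/to/videos',   # placeholder path
    dataset_type = 'videos',              # or 'images'
    batch_size = 4,
    grad_accum_every = 8,
    learning_rate = 2e-5,
    num_train_steps = 10_000
)

trainer.train()
```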

[screenshots of reconstruction results attached]

lucidrains commented 2 months ago

thank you for sharing this Marina! I'll see if I can find the bug, and if worse comes to worst, I can always rewrite the training code in pytorch lightning

RobertLuo1 commented 1 month ago

Hi, recently we have devoted a lot of effort to training the tokenizer in Magvit2, and we have now open-sourced a tokenizer trained on ImageNet. Feel free to use it. The project page is https://github.com/TencentARC/Open-MAGVIT2. Thanks @lucidrains so much for your reference code and the discussions!

ashah01 commented 4 days ago

Hey @lucidrains, I trained a MAGVIT2 tokenizer without modifying your implementation of the accelerate framework. As others have experienced, I initially saw just a solid block in the results/sampled.x.gif files. However, upon loading the model weights from my most recent checkpoint, I was able to get pretty good reconstructions from a sample script I wrote that performs inference without using the accelerate framework. Additionally, the reconstruction MSE scores were consistent with the ones observed in your training script. This means that whatever bug others are experiencing is not the result of flawed model training, but rather of something going wrong with the gif rendering.

[two attachments: the saved results gif and an inference reconstruction]

Note: the first file is the saved gif in the results folder. The ground-truth frames have a weird colour scheme because I normalized the frame pixels to be in [-1, 1]. The second file is a reconstructed frame from my inference script. MSE was ~0.011 after training on a V100 for 5 hours.
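
For anyone who wants to run the same sanity check, here is a rough sketch of accelerate-free inference from a checkpoint. The config, checkpoint path, and the exact `encode`/`decode` signatures are assumptions: rebuild the tokenizer with whatever config you trained with, and verify the method names against your installed version of magvit2-pytorch:

```python
import torch
import torch.nn.functional as F
from magvit2_pytorch import VideoTokenizer

# rebuild the tokenizer with the same config used during training
# (layer schedule omitted here for brevity)
tokenizer = VideoTokenizer(image_size = 128)

# hypothetical checkpoint path; depending on how it was saved, the
# state dict may be nested under a 'model' key
ckpt = torch.load('results/checkpoint.pt', map_location = 'cpu')
tokenizer.load_state_dict(ckpt.get('model', ckpt))
tokenizer.eval()

# (batch, channels, frames, height, width), pixels normalized to [-1, 1]
video = torch.rand(1, 3, 17, 128, 128) * 2 - 1

with torch.no_grad():
    # encode/decode round trip -- double-check these method names and
    # flags against the VideoTokenizer source before relying on them
    quantized = tokenizer.encode(video, quantize = True)
    recon = tokenizer.decode(quantized)

print('reconstruction MSE:', F.mse_loss(recon, video).item())
```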