hp-l33 / AiM

Official PyTorch Implementation of "Scalable Autoregressive Image Generation with Mamba"
MIT License

How to load pre-trained model mamba_370m #8

Closed Doctor-James closed 2 hours ago

Doctor-James commented 3 hours ago

Thank you for your excellent work. I would like to try training AiM from scratch. For example, when training AiM-L, I ran into many size and key mismatches while loading the pre-trained mamba_370m weights. Could you offer some suggestions on how to resolve this? How did you handle model loading?

Doctor-James commented 3 hours ago

It might be related to the version of mamba_ssm. Could you share which version of mamba_ssm you're using and which pre-trained model you loaded? For reference, I'm using mamba-370m and mamba_ssm==2.2.2.

hp-l33 commented 3 hours ago

Hi! Unfortunately, due to several modifications we made to the model architecture—such as incorporating class embeddings and AdaLN—the weights of AiM do not precisely match those of the original mamba model. In fact, when training AiM, we started entirely from scratch and did not load any pre-trained weights.
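
If it helps, one quick way to see exactly which parameter names differ is to diff the two state dicts. This is only an illustrative sketch: the checkpoint path is a placeholder and is assumed to store a plain state dict.

import torch
from torch import nn

def diff_param_names(model: nn.Module, ckpt_path: str) -> None:
    # Compare parameter names between an instantiated AiM model and a
    # pretrained mamba checkpoint; keys unique to AiM correspond to the
    # newly added modules (class embedding, adaLN, ...).
    pretrained_keys = set(torch.load(ckpt_path, map_location="cpu").keys())
    own_keys = set(model.state_dict().keys())
    print("only in AiM:", sorted(own_keys - pretrained_keys))
    print("only in the checkpoint:", sorted(pretrained_keys - own_keys))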

hp-l33 commented 3 hours ago

The configuration of training script parameters is crucial. Here are the specific parameters we used for training AiM-L, for your reference:

accelerate launch --num_processes=16 --num_machines=2 --main_process_ip=... --main_process_port=... --machine_rank=... train_stage2.py --aim-model AiM-L --dataset /your/data/path/ --vq-ckpt /your/ckpt/path/vq_f16.pt --batch-size 128 --lr 8e-4 --epochs 300

Please note that if the global batch size changes, the learning rate should be adjusted accordingly. We follow a linear scaling rule of 1e-4 per 256 samples of global batch size. In the script above, we used 16 GPUs with a per-GPU batch size of 128, giving a global batch size of 2048, so the learning rate is 2048 / 256 * 1e-4 = 8e-4.
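
As a quick sanity check of that arithmetic, here is a small illustrative helper (not part of the training script):

def scaled_lr(batch_size_per_gpu: int, num_gpus: int,
              base_lr: float = 1e-4, base_batch: int = 256) -> float:
    # Linear scaling: base_lr per base_batch samples of global batch size.
    return base_lr * (batch_size_per_gpu * num_gpus) / base_batch

print(scaled_lr(128, 16))  # 16 GPUs x 128 per GPU -> global batch 2048 -> 0.0008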

Doctor-James commented 2 hours ago

Thank you for your suggestion. I made a mistake earlier; I now realize you're using the mamba2 model. After switching to mamba2_370m, I was able to load most of the weights. The main missing ones are the adaln and position_embeddings parameters. I believe partially loading the weights might still be effective. Do you have any recommendations?
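
For reference, this is roughly the partial-loading sketch I have in mind, using PyTorch's load_state_dict(strict=False). The checkpoint path is a placeholder, and I'm assuming it stores a plain state dict:

import torch
from torch import nn

def load_partial(model: nn.Module, ckpt_path: str) -> None:
    # Keep only pretrained tensors whose names and shapes match the model;
    # new modules (adaln, position_embeddings) stay randomly initialized.
    pretrained = torch.load(ckpt_path, map_location="cpu")
    own_state = model.state_dict()
    filtered = {k: v for k, v in pretrained.items()
                if k in own_state and v.shape == own_state[k].shape}
    result = model.load_state_dict(filtered, strict=False)
    print(f"loaded {len(filtered)} tensors, "
          f"{len(result.missing_keys)} keys left at random init")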

hp-l33 commented 2 hours ago

I'm glad to see that you've resolved the issue. I'm also curious whether loading the pre-trained weights trained on language datasets would be beneficial. Looking forward to your experiments, and best wishes!