hp-l33 / AiM

Official PyTorch Implementation of "Scalable Autoregressive Image Generation with Mamba"
MIT License

about training details #6

Closed · maxin-cn closed this 1 month ago

maxin-cn commented 1 month ago

Thank you so much for open-sourcing your work; it's a fantastic piece of work. I read your paper but did not find the training details. Could you share how many A100 GPUs you used to train the model and what batch size you used? Also, after roughly how many training steps does the model start to converge? Thank you very much!

hp-l33 commented 1 month ago

Hi! Thank you for your interest in our work. We trained the B, L, and XL models using 16 GPUs with a global batch size of 2048. For the largest model, we used 32 GPUs. The models tend to converge quickly during the early stages of training. For reference, here's the loss curve of AiM-B: [W&B chart: AiM-B training loss]
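
For anyone reproducing this setup, the per-device batch size follows from the global batch size, the GPU count, and any gradient accumulation. Below is a minimal sketch of that arithmetic; only the global batch of 2048 and the GPU counts (16 for B/L/XL, 32 for the largest model) come from the reply above, while gradient accumulation of 1 is an assumption.

```python
# Sketch of the global-batch arithmetic described above.
# Only the global batch (2048) and GPU counts (16 / 32) come from this thread;
# grad_accum_steps = 1 is an assumption.
def per_gpu_batch(global_batch: int, num_gpus: int, grad_accum_steps: int = 1) -> int:
    assert global_batch % (num_gpus * grad_accum_steps) == 0
    return global_batch // (num_gpus * grad_accum_steps)

print(per_gpu_batch(2048, 16))  # 128 samples per GPU per optimizer step (B / L / XL)
print(per_gpu_batch(2048, 32))  # 64 samples per GPU per optimizer step (largest model)
```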

maxin-cn commented 1 month ago

Thanks for your reply.

Doctor-James commented 3 weeks ago


Hello, may I ask if there is an unsmoothed version of the loss curve? In my preliminary experiment (batch size 256, 5k steps), the loss hardly decreases in the early stages and stays at 8.3125. Do you have any suggestions?

hp-l33 commented 3 weeks ago


Could you please share the total number of epochs and the learning-rate settings you used? Depending on the lr decay schedule, if you reduce the total number of epochs, it is advisable to increase the lr accordingly.
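
To make that suggestion concrete, here is a hedged sketch of one way to set this up in PyTorch: a linear warmup-then-decay schedule whose length is derived from the number of epochs you actually plan to run, with the base lr bumped up when the run is shortened. The base lr of 8e-4 and the steps-per-epoch figure are read off the log later in this thread; the 1.5x bump and the warmup length are illustrative assumptions, not the authors' recipe.

```python
# Hedged sketch: define the lr decay over the run you actually plan, and
# scale the base lr up when the total number of epochs is reduced.
# base_lr = 8e-4 and steps_per_epoch are inferred from the log below;
# the 1.5x bump and the warmup length are illustrative assumptions.
import torch

def make_linear_schedule(optimizer, total_steps, warmup_steps=0):
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)                      # linear warmup
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return max(0.0, 1.0 - progress)                             # linear decay to 0
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

model = torch.nn.Linear(16, 16)                                     # stand-in for the real model
steps_per_epoch = 626                                               # ~500 steps per 0.8 epoch in the log below
total_epochs = 100                                                  # a shortened run (example value)
base_lr = 8e-4 * 1.5                                                # larger base lr for the shorter run

optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr)
scheduler = make_linear_schedule(optimizer, steps_per_epoch * total_epochs, warmup_steps=1000)
```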

hp-l33 commented 3 weeks ago


This is a snippet of the loss curve and log.

[screenshot: training loss curve]
100%|██████████| 187800/187800 [18:31:54<00:00,  2.81it/s]
{'loss': 9.0234, 'grad_norm': 0.18595093488693237, 'learning_rate': 0.0007978700745473908, 'epoch': 0.8}
{'loss': 8.307, 'grad_norm': 0.2204223871231079, 'learning_rate': 0.0007957401490947817, 'epoch': 1.6}
{'loss': 8.1161, 'grad_norm': 0.21777774393558502, 'learning_rate': 0.0007936102236421725, 'epoch': 2.4}
{'loss': 8.016, 'grad_norm': 0.18479077517986298, 'learning_rate': 0.0007914802981895634, 'epoch': 3.19}
{'loss': 7.9525, 'grad_norm': 0.21088118851184845, 'learning_rate': 0.0007893503727369542, 'epoch': 3.99}
{'loss': 7.9045, 'grad_norm': 0.20796062052249908, 'learning_rate': 0.0007872204472843451, 'epoch': 4.79}
{'loss': 7.8686, 'grad_norm': 0.17890967428684235, 'learning_rate': 0.0007850905218317359, 'epoch': 5.59}
{'loss': 7.8413, 'grad_norm': 0.15948446094989777, 'learning_rate': 0.0007829605963791267, 'epoch': 6.39}
{'loss': 7.8183, 'grad_norm': 0.1650700867176056, 'learning_rate': 0.0007808306709265176, 'epoch': 7.19}
{'loss': 7.7986, 'grad_norm': 0.17891138792037964, 'learning_rate': 0.0007787007454739084, 'epoch': 7.99}
{'loss': 7.78, 'grad_norm': 0.18607960641384125, 'learning_rate': 0.0007765708200212993, 'epoch': 8.79}
{'loss': 7.7647, 'grad_norm': 0.13249890506267548, 'learning_rate': 0.0007744408945686901, 'epoch': 9.58}
{'loss': 7.7535, 'grad_norm': 0.18646657466888428, 'learning_rate': 0.000772310969116081, 'epoch': 10.38}
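
For what it's worth, the learning-rate values in this snippet look consistent with a plain linear decay from a base lr of 8e-4 down to 0 over the full 187800 steps, with one log entry every 500 steps and no warmup visible in this excerpt. That reading is inferred from the numbers above, not confirmed anywhere in the thread; a small sanity check:

```python
# Sanity check (inference, not a confirmed setting): the logged learning rates
# match lr(step) = 8e-4 * (1 - step / 187800) at steps 500, 1000, 1500, ...
BASE_LR, TOTAL_STEPS, LOG_EVERY = 8e-4, 187800, 500

logged = [0.0007978700745473908, 0.0007957401490947817, 0.0007936102236421725]
for i, lr in enumerate(logged, start=1):
    step = i * LOG_EVERY
    predicted = BASE_LR * (1 - step / TOTAL_STEPS)
    print(f"step {step}: logged={lr:.16f} predicted={predicted:.16f} match={abs(lr - predicted) < 1e-12}")
```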