tongxuluo / prts

https://llm-stacking.github.io/
Apache License 2.0

Questions about model size and hyper-parameters #3

Open knight-fzq opened 5 days ago

knight-fzq commented 5 days ago

Thanks for the great work!

1) My current goal is to reproduce the results presented in Figure 3 of your promising paper, "Stacking Your Transformers: A Closer Look at Model Growth for Efficient LLM Pre-Training." Could you provide the hyperparameter settings for base_model.sh and g_stack.sh needed to replicate the results in Figure 3?

2) When I use the same model configuration in both the TinyLlama source code and this repository, the two report different model sizes, and the size reported by this repository is unexpected. For example, I tried the following config:

```python
dict(
    org="StatNLP-research",
    name="tiny_LLaMA_1b",
    block_size=2048,
    vocab_size=32000,
    padding_multiple=64,
    n_layer=22,
    n_head=32,
    n_embd=2048,
    rotary_percentage=1.0,
    parallel_residual=False,
    bias=False,
    _norm_class="FusedRMSNorm",
    norm_eps=1e-5,  # Llama 2 uses 1e-5; Llama 1 uses 1e-6
    _mlp_class="LLaMAMLP",
    intermediate_size=5632,
    n_query_groups=4,
),
```

TinyLlama reports "1.1B params", but this repository reports "Total parameters 137,506,048" on a machine with 8x RTX 4090 GPUs. The command I ran is:

```shell
python pretrain/run_pretrain.py \
    --model_name=tiny_LLaMA_1.1B \
    --name=tiny_LLaMA_1.1B \
    --method=scratch \
    --out_dir=placeholder \
    --train_data_dir=placeholder \
    --devices=8 \
    --global_batch_size=1024 \
    --learning_rate=1e-3 \
    --min_lr=1e-4 \
    --micro_batch_size=2 \
    --max_step=5000 \
    --warmup_steps=500 \
    --log_step_interval=1 \
    --eval_iters=10000 \
    --save_step_interval=1000 \
    --eval_step_interval=1000 \
    --weight_decay=1e-1 \
    --beta1=0.9 \
    --beta2=0.95 \
    --grad_clip=1.0 \
    --decay_lr=True
```
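For reference, a rough sketch of the training quantities this command implies. It assumes the script derives gradient accumulation as global_batch_size / (devices × micro_batch_size), as TinyLlama-style pretraining scripts do, and that every sequence is block_size = 2048 tokens long; neither assumption has been checked against pretrain/run_pretrain.py.

```python
# Values taken from the command and config above.
devices = 8
global_batch_size = 1024  # assumed: sequences per optimizer step across all GPUs
micro_batch_size = 2      # sequences per forward/backward pass on one GPU
max_step = 5000
warmup_steps = 500
block_size = 2048         # assumed: training sequence length equals block_size

# Assumption: accumulation steps = global batch / (devices * micro batch).
grad_accum_steps = global_batch_size // (devices * micro_batch_size)
tokens_per_step = global_batch_size * block_size
total_tokens = tokens_per_step * max_step

print(grad_accum_steps)                     # 64
print(f"{tokens_per_step:,}")               # 2,097,152
print(f"{total_tokens / 1e9:.2f}B tokens")  # 10.49B tokens
print(f"warmup = {warmup_steps / max_step:.0%} of training")  # warmup = 10% of training
```

Under these assumptions, the run covers roughly 10.5B tokens, which may be useful when comparing against the token budgets in Figure 3.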

Similarly, I get "Total parameters 54,922,496" with 6L2048H (6 layers, hidden size 2048) and "Total parameters 170,537,216" with 24L2048H.
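The reported counts are consistent with rank 0 printing only its 1/8 share of the parameters. The sketch below counts parameters analytically, assuming a lit-gpt-style layout (fused QKV projection, no biases, weight-only RMSNorm, untied embedding and output head, vocabulary padded up to padding_multiple). For 6L2048H and 24L2048H it additionally assumes full multi-head attention (n_query_groups=32) and intermediate_size=5632, because those configs are not shown in this post. The helper names are hypothetical, not functions from either repository.

```python
def find_multiple(n: int, k: int) -> int:
    """Round n up to the next multiple of k (used for the padded vocab size)."""
    return n if n % k == 0 else n + k - (n % k)


def llama_param_count(n_layer: int, n_embd: int, n_head: int, n_query_groups: int,
                      intermediate_size: int, vocab_size: int = 32000,
                      padding_multiple: int = 64) -> int:
    """Count parameters of an assumed lit-gpt-style LLaMA model."""
    head_size = n_embd // n_head
    vocab = find_multiple(vocab_size, padding_multiple)
    # Fused QKV: n_head query heads plus n_query_groups key heads and value heads.
    qkv = n_embd * (n_head + 2 * n_query_groups) * head_size
    attn_out = n_embd * n_embd
    mlp = 3 * n_embd * intermediate_size  # fc_1 and fc_2 (n_embd -> inter), proj (inter -> n_embd)
    norms = 2 * n_embd                    # norm_1 and norm_2 weights per block
    per_layer = qkv + attn_out + mlp + norms
    embeddings = 2 * vocab * n_embd       # token embedding + untied lm_head
    return n_layer * per_layer + embeddings + n_embd  # + final norm weight


DEVICES = 8
# name -> (analytic full count, count printed by this repo after fabric.setup)
cases = {
    "22L2048H (tiny_LLaMA_1b, GQA)": (llama_param_count(22, 2048, 32, 4, 5632), 137_506_048),
    "6L2048H (assumed MHA)": (llama_param_count(6, 2048, 32, 32, 5632), 54_922_496),
    "24L2048H (assumed MHA)": (llama_param_count(24, 2048, 32, 32, 5632), 170_537_216),
}
for name, (full, reported) in cases.items():
    print(f"{name}: full={full:,} full/{DEVICES}={full // DEVICES:,} reported={reported:,}")
```

With these assumptions, the 22-layer config comes out at 1,100,048,384 parameters (TinyLlama's "1.1B"), and in every case the full count divided by 8 equals the number this repository printed.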

Thanks again and looking forward to your reply.

knight-fzq commented 5 days ago

Maybe I found the bug in the code. When I use this code:

```python
fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
fabric.print(f"Total parameters {num_parameters(model):,}")
model = fabric.setup(model)
```

the output is "Total parameters 1,364,297,728". However, when I use this code:

```python
model = fabric.setup(model)
fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
fabric.print(f"Total parameters {num_parameters(model):,}")
```

the output is "Total parameters 170,537,216". Maybe I am right, or maybe I have run into a problem with another package.
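One possible explanation, assuming fabric.setup wraps the model with a sharded strategy such as FSDP (an assumption, not confirmed against this repository): after setup, each rank holds only its own slice of the parameters, so summing parameter sizes yields roughly 1/world_size of the model. The toy sketch below mimics this in plain Python; count_params and fake_sharded_setup are hypothetical stand-ins, not Lightning or PyTorch APIs.

```python
WORLD_SIZE = 8  # matches --devices=8 in the command above


def count_params(module_sizes: dict[str, int]) -> int:
    """Total element count, analogous to summing p.numel() over model.parameters()."""
    return sum(module_sizes.values())


def fake_sharded_setup(module_sizes: dict[str, int], world_size: int) -> dict[str, int]:
    """Hypothetical FSDP-style setup: each module's parameters are flattened,
    padded up to a multiple of world_size and split evenly, so one rank keeps
    ceil(n / world_size) elements per module."""
    return {name: (n + world_size - 1) // world_size for name, n in module_sizes.items()}


# Toy model: a single entry holding the count printed before fabric.setup(model).
model = {"whole_model": 1_364_297_728}
print(f"{count_params(model):,}")  # 1,364,297,728 (counted before setup)

model = fake_sharded_setup(model, WORLD_SIZE)
print(f"{count_params(model):,}")  # 170,537,216 (one rank's shard, counted after setup)
```

Under this assumption, counting before fabric.setup(model), as in the first snippet, reports the full model size, while a count taken after setup would have to be summed across ranks to recover the total.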

Looking forward to your hyperparameter settings for base_model.sh and g_stack.sh to replicate the results in Figure 3.