allenai / open-instruct

Apache License 2.0

Training script or params for TULU2-13B. Cannot reproduce TULU2 fine-tuned model #103

Closed: ia3leonidshad closed this issue 7 months ago

ia3leonidshad commented 8 months ago

I tried to train my own version of Tulu2, starting from Llama-2-13b and using https://huggingface.co/datasets/allenai/tulu-v2-sft-mixture, but the final model is worse than yours: https://huggingface.co/allenai/tulu-2-13b

Can you share the training script or the hyperparameters you used?

Thanks!

hamishivi commented 8 months ago

Hi,

For training Tulu 2, we used a JAX repository for efficient TPU training: https://github.com/hamishivi/EasyLM/tree/main. But with similar hyperparameter choices, the two scripts should give similar results (I tested this for 7B and 13B models).

For finetuning:

One thing to note: if you use a lot of gradient accumulation, the loss can differ depending on how many accumulation steps you use, because Hugging Face averages the loss across the tokens of each mini-batch rather than across the whole effective batch. For best replication, you should match the per-batch size, but this is hard without lots of compute. Alternatively, you can better mimic weighting every token equally by setting reduce_loss=sum in the finetune script (although this will show higher loss values than other runs because of the summing).
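A minimal sketch of the weighting effect described above, assuming each micro-batch's loss is the mean over its target tokens (mean-reduction cross-entropy) and that micro-batches, whether accumulation steps or data-parallel GPUs, are then averaged with equal weight. The function `token_weights` is hypothetical: it only computes the relative weight a token of each sequence receives in one optimizer update and is not the open-instruct implementation.

```python
from typing import List


def token_weights(seq_lengths: List[int], micro_batch_size: int, reduction: str) -> List[float]:
    """Per-token weight for the tokens of each sequence in one optimizer update.

    seq_lengths: target-token count of every sequence in the effective batch.
    micro_batch_size: sequences per forward/backward pass (per GPU).
    reduction: "mean" averages over the tokens of each micro-batch and then
    over micro-batches; "sum" adds up every token's loss.
    """
    if reduction not in ("mean", "sum"):
        raise ValueError(f"unknown reduction: {reduction}")
    if micro_batch_size <= 0 or len(seq_lengths) % micro_batch_size != 0:
        raise ValueError("effective batch must split evenly into micro-batches")
    num_micro_batches = len(seq_lengths) // micro_batch_size
    weights: List[float] = []
    for start in range(0, len(seq_lengths), micro_batch_size):
        micro = seq_lengths[start:start + micro_batch_size]
        tokens_in_micro = sum(micro)
        for _ in micro:
            if reduction == "mean":
                # Divided by this micro-batch's token count, then by the
                # number of micro-batches averaged into the update.
                weights.append(1.0 / (tokens_in_micro * num_micro_batches))
            else:
                # Summed loss: every token counts the same.
                weights.append(1.0)
    return weights


if __name__ == "__main__":
    lengths = [10, 1000]  # one short and one long example
    print(token_weights(lengths, micro_batch_size=1, reduction="mean"))
    print(token_weights(lengths, micro_batch_size=2, reduction="mean"))
    print(token_weights(lengths, micro_batch_size=1, reduction="sum"))
```

With a micro-batch size of 1 and mean reduction, each token of the 10-token example carries 100 times the weight of each token of the 1000-token example. Putting both examples in one micro-batch, or switching to sum reduction, weights all tokens equally.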

I hope that helps!

ia3leonidshad commented 8 months ago

Hi @hamishivi, thanks for the response. A couple of follow-up questions, if you don't mind.

Did you use 4 devices or 4 gradient-accumulation steps, with a mini-batch size of 32 each, for 128 in total? If you have a config file or script with the training params, could you share it with me?

Params I used for training (8 GPUs in total, final batch size 128, base model llama2-13b):

    --dataset_name allenai/tulu-v2-sft-mixture \
    --max_seq_length 8192 \
    --learning_rate 2e-5 \
    --lr_scheduler_type linear \
    --warmup_ratio 0.03 \
    --weight_decay 0. \
    --num_train_epochs 2 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 16

But the final model is much worse: it scores around 6.0 on MT-Bench, versus 6.7 for tulu2-13b.

And regarding what you describe about loss accumulation, am I understanding this correctly? If I have a mini-batch size of 1, tokens in shorter samples get a higher relative weight in the final batch, because the loss is normalized by sample length in each mini-batch. That's a very good point; I'll follow up on that.

And a last question: llama2-13b was trained with a 4096 context size. Did you do anything to extend it to 8k (positional embedding interpolation, changing rotary embedding parameters, etc.), or did you just rely on the rotary embeddings produced for indices 4k-8k?

hamishivi commented 8 months ago

Hi, yeah, exactly: shorter samples end up weighted much more heavily with a mini-batch size of 1 than with larger mini-batches. I found that using the sum loss reduction can result in significantly higher AlpacaEval scores (an increase of more than 5 points) when the mini-batch size is 1. I haven't checked MT-Bench, but I imagine the effect would be similar!

Did you use 4 devices or 4 gradient-accumulation steps, with a mini-batch size of 32 each, for 128 in total?

It was 4 gradient accumulation steps in a distributed training setup. The JAX codebase and its distributed training work a bit differently from PyTorch, but here's the training command:

    python3 -m EasyLM.models.llama.llama_train \
    --mesh_dim='-1,16,8' \
    --dtype='bf16' \
    --num_epochs=2 \
    --log_freq=50 \
    --save_model_freq=1000 \
    --save_milestone_freq=0 \
    --load_llama_config='13b' \
    --update_llama_config='' \
    --load_dataset_state='' \
    --load_checkpoint='params::gs://hamishi-dev/easylm/llama2/13b' \
    --tokenizer.vocab_file='gs://hamishi-dev/easylm/llama/tokenizer.model' \
    --optimizer.type='adamw' \
    --optimizer.adamw_optimizer.weight_decay=0.0 \
    --optimizer.adamw_optimizer.lr=2e-5 \
    --optimizer.adamw_optimizer.end_lr=0 \
    --optimizer.adamw_optimizer.warmup_ratio=0.03 \
    --optimizer.accumulate_gradient_steps=4 \
    --train_dataset.type='tulu_json_torch' \
    --train_dataset.text_processor.fields='[prompt],completion' \
    --train_dataset.json_torch_dataset.path='../tulu_v2_data.jsonl' \
    --train_dataset.json_torch_dataset.seq_length=4096 \
    --train_dataset.json_torch_dataset.batch_size=32 \
    --checkpointer.save_optimizer_state=False \
    --logger.online=False \
    --logger.output_dir="gs://hamishi-dev/easylm/llama2/tulu2_13b_fixed/"
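A quick sanity check of the batch arithmetic in the question and in the command, as a sketch. It assumes, from this thread rather than from EasyLM's documentation, that `batch_size=32` is the global batch per step with the loss averaged over all of its tokens, while the Hugging Face run averages the loss per GPU over a single sequence. `BatchSetup` is a hypothetical helper.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BatchSetup:
    """Batch layout of one training run (values taken from the thread)."""
    name: str
    loss_groups: int          # groups whose mean losses are averaged together (e.g. GPUs)
    seqs_per_loss_mean: int   # sequences whose tokens share one mean-loss denominator
    grad_accum_steps: int
    max_seq_length: int

    def effective_batch_size(self) -> int:
        # Sequences contributing to one optimizer update.
        return self.loss_groups * self.seqs_per_loss_mean * self.grad_accum_steps


# EasyLM run (assumption: batch_size=32 is the global batch per step and the
# loss is averaged over all tokens in that global batch).
easylm = BatchSetup("EasyLM tulu2-13b", loss_groups=1, seqs_per_loss_mean=32,
                    grad_accum_steps=4, max_seq_length=4096)

# Hugging Face reproduction: 8 GPUs, per-device batch 1, 16 accumulation steps.
hf_run = BatchSetup("HF reproduction", loss_groups=8, seqs_per_loss_mean=1,
                    grad_accum_steps=16, max_seq_length=8192)

if __name__ == "__main__":
    for run in (easylm, hf_run):
        print(f"{run.name}: effective batch {run.effective_batch_size()}, "
              f"{run.seqs_per_loss_mean} sequence(s) per loss mean, "
              f"max length {run.max_seq_length}")
```

Both runs update on 128 sequences per optimizer step; they differ in how many sequences share one loss denominator (32 versus 1) and in maximum sequence length (4096 versus 8192).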

And a last question: llama2-13b was trained with a 4096 context size. Did you do anything to extend it to 8k (positional embedding interpolation, changing rotary embedding parameters, etc.), or did you just rely on the rotary embeddings produced for indices 4k-8k?

Nope, we haven't experimented with long-context extension methods yet; we just relied on the vanilla rotary embedding implementation.
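To make the distinction concrete, here is a sketch of standard rotary position embedding (RoPE) angles with Llama 2's settings (base 10000, head dimension 128 for the 13B model). With the vanilla implementation, positions 4096-8191 simply produce rotation angles larger than any seen in a 4096-token context. Linear position interpolation, a technique not used here and included only for contrast, rescales positions back into that range. The helper names are hypothetical.

```python
from typing import List


def rope_angles(position: float, head_dim: int = 128, base: float = 10000.0) -> List[float]:
    """Rotation angle (radians) for each dimension pair at `position`.

    Standard RoPE frequencies: theta_i = base ** (-2i / head_dim),
    and pair i is rotated by position * theta_i.
    """
    return [position * base ** (-2 * i / head_dim) for i in range(head_dim // 2)]


def interpolated_position(position: int, trained_ctx: int = 4096, target_ctx: int = 8192) -> float:
    """Linear position interpolation: rescale [0, target_ctx) into [0, trained_ctx).
    Shown only for contrast; the thread says Tulu 2 did not use it."""
    return position * trained_ctx / target_ctx


if __name__ == "__main__":
    # Lowest-frequency pair: the angle that grows most slowly with position.
    seen_max = rope_angles(4095)[-1]   # last position in a 4096-token context
    vanilla = rope_angles(8191)[-1]    # plain RoPE at an 8k position
    interp = rope_angles(interpolated_position(8191))[-1]
    print(f"lowest-frequency angle at position 4095:        {seen_max:.4f} rad")
    print(f"vanilla RoPE at position 8191:                  {vanilla:.4f} rad")
    print(f"interpolated RoPE at position 8191:             {interp:.4f} rad")
```

Under vanilla RoPE the slowest-rotating pair reaches roughly twice the angle it ever reached in a 4096-token context, which is what "relying on rotary embeddings for indices 4k-8k" amounts to.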

ia3leonidshad commented 8 months ago

Thanks for the info @hamishivi, I'll come back in a few days with an update.

ia3leonidshad commented 7 months ago

Hi @hamishivi, it worked! I got a huge improvement after fixing the mini-batch-size=1 issue we discussed.

hamishivi commented 7 months ago

Awesome!