huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0
RLOO Checkpoint Issue #2342

Open asparius opened 5 days ago

asparius commented 5 days ago

Reproduction

python rloo_tldr.py --output_dir models/minimal/rloo_tldr --dataset_name trl-internal-testing/tldr-preference-sft-trl-style --dataset_test_split validation --num_ppo_epochs 2 --num_mini_batches 2 --learning_rate 3e-6 --per_device_train_batch_size 4 --gradient_accumulation_steps 16 --total_episodes 1000 --model_name_or_path EleutherAI/pythia-1b-deduped --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr --local_rollout_forward_batch_size 16 --missing_eos_penalty 1.0 --stop_token eos --kl_coef 0.03 --save_strategy steps --save_steps 10000 --eval_strategy steps --eval_steps 1000 --report_to wandb

Expected behavior

The script above should not save any checkpoints, because the number of training episodes is very low (--total_episodes 1000 against --save_steps 10000). It still produces a checkpoint every 5 steps, similar to the previous RLOO issue #2124.
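To make the expectation concrete, here is a rough sketch of the arithmetic behind it. It assumes a single process and that each policy update consumes `per_device_train_batch_size * gradient_accumulation_steps * num_mini_batches` episodes; that formula is an assumption about how the trainer batches episodes, not something confirmed in this thread, and `planned_updates` is a hypothetical helper name.

```python
import math

def planned_updates(total_episodes, per_device_train_batch_size,
                    gradient_accumulation_steps, num_mini_batches,
                    world_size=1):
    """Estimate how many policy updates a run performs.

    Assumption (not confirmed in the thread): one update consumes
    per_device_train_batch_size * gradient_accumulation_steps
    * num_mini_batches * world_size episodes.
    """
    batch_size = (per_device_train_batch_size
                  * gradient_accumulation_steps
                  * num_mini_batches
                  * world_size)
    return math.ceil(total_episodes / batch_size)

# Values taken from the reproduction command.
updates = planned_updates(
    total_episodes=1000,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=16,
    num_mini_batches=2,
)
save_steps = 10000

print(updates)                # 8
print(updates >= save_steps)  # False
```

Under that assumption the run performs about 8 updates, far fewer than 10000, so no step-based checkpoint would be expected.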

qgallouedec commented 4 days ago

It should be fixed by #2325. Could you confirm?

asparius commented 4 days ago

The saving issue is solved, but training time has increased significantly: 1 million episodes now take 300+ hours on an A100. Is this expected? Is there a reference number to compare against?
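For comparing against other runs, the figure above can be turned into a throughput number. This is only a back-of-the-envelope sketch: `episodes_per_second` is a hypothetical helper, and since the report says "300+" hours, the result is an upper bound on the observed rate.

```python
def episodes_per_second(total_episodes, hours):
    """Convert a run's episode count and wall-clock hours to episodes/s."""
    return total_episodes / (hours * 3600)

# 1 million episodes in (at least) 300 hours, as reported above.
rate = episodes_per_second(1_000_000, 300)
print(f"{rate:.2f} episodes/s")  # 0.93 episodes/s
```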

qgallouedec commented 1 day ago

I can't reproduce:

# TRL v0.12.1 (includes the fix); transformers v4.47 dev (blue)
/fsx/qgallouedec/trl/examples/scripts/rloo/rloo_tldr.py --output_dir models/minimal/rloo_tldr --dataset_name trl-internal-testing/tldr-preference-sft-trl-style --dataset_test_split validation --num_ppo_epochs 2 --num_mini_batches 2 --learning_rate 3e-6 --per_device_train_batch_size 4 --gradient_accumulation_steps 16 --total_episodes 1000 --model_name_or_path EleutherAI/pythia-1b-deduped --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr --local_rollout_forward_batch_size 16 --missing_eos_penalty 1.0 --stop_token eos --kl_coef 0.03 --save_strategy steps --save_steps 10000 --eval_strategy steps --eval_steps 1000 --report_to wandb
# TRL v0.11 (doesn't include the fix); transformers v4.45 (red)
/fsx/qgallouedec/trl/examples/scripts/rloo/rloo_tldr.py --output_dir models/minimal/rloo_tldr --num_ppo_epochs 2 --num_mini_batches 2 --learning_rate 3e-6 --per_device_train_batch_size 4 --gradient_accumulation_steps 16 --total_episodes 1000 --model_name_or_path EleutherAI/pythia-1b-deduped --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr --local_rollout_forward_batch_size 16 --missing_eos_penalty 1.0 --stop_token eos --kl_coef 0.03 --save_strategy steps --save_steps 10000 --eval_strategy steps --eval_steps 1000 --report_to wandb

[Figure: W&B chart, 14/11/2024 12:08:20, comparing the blue and red runs above]