huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0
9.93k stars, 1.25k forks

TRL ORPO gives everything NaN #1473

Closed: gagan3012 closed this issue 4 months ago

gagan3012 commented 7 months ago

{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 2.5000000000000002e-08, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': -3.114448070526123, 'logits/chosen': -3.114448070526123, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 5.0000000000000004e-08, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': nan, 'logits/chosen': nan, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 7.500000000000001e-08, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': nan, 'logits/chosen': nan, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 1.0000000000000001e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': nan, 'logits/chosen': nan, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 1.2500000000000002e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': nan, 'logits/chosen': nan, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 1.5000000000000002e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': nan, 'logits/chosen': nan, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 1.7500000000000002e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': nan, 'logits/chosen': nan, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 2.0000000000000002e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': nan, 'logits/chosen': nan, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 2.2500000000000002e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': nan, 'logits/chosen': nan, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}

Using the example script orpo.py, I get this error.

hbin0701 commented 7 months ago

I'm experiencing the same issue :( It seems like grad_norm suddenly diverges to infinity after some iterations.

kashif commented 7 months ago

@gagan3012 @hbin0701 do you see this with some specific dataset? Here is my run of orpo.py:

https://wandb.ai/krasul/huggingface/runs/rqu2awe3?nw=nwuserkrasul

using:

python examples/scripts/orpo.py \
    --model_name_or_path=gpt2 \
    --per_device_train_batch_size 4 \
    --max_steps 1000 \
    --learning_rate 1e-3 \
    --gradient_accumulation_steps 1 \
    --logging_steps 10 \
    --eval_steps 500 \
    --output_dir="orpo_anthropic_hh" \
    --optim rmsprop \
    --warmup_steps 150 \
    --report_to wandb \
    --bf16 \
    --logging_first_step \
    --no_remove_unused_columns \
    --use_peft \
    --lora_r=16 \
    --lora_alpha=16

gagan3012 commented 7 months ago

I was using Mistral 0.2.

jiwooya1000 commented 7 months ago

Hello @gagan3012, I just saw this issue and would like to add some comments!

Although I do not know the specific environment or dataset you are using, it is generally recommended that you use a lower learning rate and beta for larger models.

For example, this code for reproducing kaist-ai/mistral-orpo-capybara-7k uses a maximum learning rate of 5e-6 and a beta of 0.05 (this code is not for the TRL ORPOTrainer, by the way):

accelerate launch --config_file ./src/accelerate/fsdp.yaml main.py \
    --lr 5e-6 \
    --torch_compile False \
    --beta 0.05 \
    --lr_scheduler_type inverse_sqrt \
    --warmup_steps 100 \
    --model_name mistralai/Mistral-7B-v0.1 \
    --data_name argilla/distilabel-capybara-dpo-7k-binarized \
    --num_train_epochs 3 \
    --optim adamw_bnb_8bit \
    --gradient_accumulation_steps 1 \
    --prompt_max_length 1792 \
    --response_max_length 2048 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --num_proc 8 \
    --flash_attention_2

I am not sure which dataset you are training on, but I would start with a beta of 0.1 and a learning rate of 5e-6. I will add some general guidelines for selecting the learning rate and beta by model size and dataset style to this repo this week!

gagan3012 commented 7 months ago

Hello, when using the ORPO repo I don't face this issue, but I do face it when I use TRL, which is very puzzling.

RonanKMcGovern commented 7 months ago

Is your prompt preparation correct?

TRL expects the "chosen" and "rejected" columns to be a) formatted (but not tokenized) and b) to EXCLUDE the prompt.

TRL also does not add any BOS or EOS tokens, so you need to add them in the chat_template. Further, since you'll be formatting the chosen and rejected columns without the prompt, you need to ensure that the BOS token is NOT included there.
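A minimal sketch of the column layout described in this comment, assuming a Mistral-style `[INST] ... [/INST]` template and `<s>`/`</s>` as the BOS/EOS strings (both hypothetical here; read the real ones from your tokenizer's `bos_token` and `eos_token`). `format_orpo_example` and `check_record` are illustrative helpers, not TRL API: the prompt carries the BOS, while chosen and rejected hold only the reply and end with EOS.

```python
# Hypothetical special tokens; take the real strings from tokenizer.bos_token / tokenizer.eos_token.
BOS = "<s>"
EOS = "</s>"


def format_orpo_example(user_message: str, chosen_reply: str, rejected_reply: str) -> dict:
    """Build one preference record with the prompt kept out of chosen/rejected."""
    # The prompt carries BOS and the instruction wrapper (Mistral-style [INST] tags assumed).
    prompt = f"{BOS}[INST] {user_message} [/INST]"
    # Completions contain only the reply text: no prompt, no BOS, terminated by EOS.
    return {
        "prompt": prompt,
        "chosen": f" {chosen_reply}{EOS}",
        "rejected": f" {rejected_reply}{EOS}",
    }


def check_record(record: dict) -> list[str]:
    """List violations of the formatting rules from the comment above."""
    problems = []
    for key in ("chosen", "rejected"):
        text = record[key]
        if text.startswith(record["prompt"]):
            problems.append(f"{key} repeats the prompt")
        if BOS in text:
            problems.append(f"{key} contains BOS")
        if not text.endswith(EOS):
            problems.append(f"{key} does not end with EOS")
    return problems


record = format_orpo_example("What is 2 + 2?", "4.", "5.")
print(check_record(record))  # []
```

Running `check_record` over a dataset before training is a cheap way to catch rows where the prompt or BOS leaked into the completions.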

github-actions[bot] commented 6 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

conseq2 commented 6 months ago

When the prompt exceeds the max_length, the log probabilities for both chosen and rejected turn to NaN. Consider filtering out cases where the prompt is longer than the max_length or max_prompt_len. The reason for trimming cases where the prompt exceeds max_prompt_len is that if the chosen or rejected segments are significantly shorter than the prompt, it may hinder effective learning.
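A simplified model of the mechanism described above, assuming (as ORPO does) that each response's log-probability is averaged over its completion tokens only, with prompt tokens masked out. If truncation leaves no completion tokens, the average becomes 0/0, which tensor libraries evaluate to NaN. `completion_mask` and `masked_mean_logp` are hypothetical helpers, and the naive right-truncation here is not TRL's actual truncation logic.

```python
import math


def completion_mask(prompt_len: int, completion_len: int, max_length: int) -> list[int]:
    """1 marks a scored completion token, 0 a prompt token; naive right-truncation to max_length."""
    mask = [0] * prompt_len + [1] * completion_len
    return mask[:max_length]


def masked_mean_logp(token_logps: list[float], mask: list[int]) -> float:
    """Average per-token log-prob over scored (completion) tokens only."""
    total = sum(lp * m for lp, m in zip(token_logps, mask))
    count = sum(mask)
    # A tensor library evaluates 0.0 / 0.0 to NaN; plain Python raises ZeroDivisionError,
    # so return NaN explicitly to mirror the tensor behaviour.
    return total / count if count else math.nan


logps = [-0.5] * 20
ok = masked_mean_logp(logps, completion_mask(prompt_len=4, completion_len=8, max_length=10))
bad = masked_mean_logp(logps, completion_mask(prompt_len=12, completion_len=8, max_length=10))
print(ok, bad)  # -0.5 nan
```

Once one response's log-prob is NaN, every metric derived from it (rewards, log-odds, loss) is NaN too, which matches the pattern in the original log.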

paulcx commented 6 months ago

I also have a similar problem, but it's different from what is mentioned above. My dataset doesn't have prompts, and all the prompts are concatenated with chosen/rejected in the dataset.
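If each chosen/rejected string already contains the prompt, one way to produce the prompt-free columns described earlier in this thread is to split every pair at the marker the chat template places right before the assistant reply. A minimal sketch, assuming a Mistral-style `[/INST]` marker (hypothetical; substitute your template's) and that chosen and rejected share an identical prompt. `split_prompt` is an illustrative helper, not TRL API.

```python
RESPONSE_MARKER = "[/INST]"  # hypothetical: whatever your template emits just before the reply


def split_prompt(chosen_full: str, rejected_full: str, marker: str = RESPONSE_MARKER) -> dict:
    """Split concatenated prompt+response strings into prompt/chosen/rejected columns."""
    # rpartition splits at the LAST marker, so earlier turns of a multi-turn chat stay in the prompt.
    c_prompt, c_sep, chosen = chosen_full.rpartition(marker)
    r_prompt, r_sep, rejected = rejected_full.rpartition(marker)
    if not c_sep or not r_sep:
        raise ValueError("response marker not found")
    if c_prompt != r_prompt:
        raise ValueError("chosen and rejected do not share the same prompt")
    return {"prompt": c_prompt + marker, "chosen": chosen, "rejected": rejected}


row = split_prompt("<s>[INST] Hi [/INST] Hello!</s>", "<s>[INST] Hi [/INST] Go away.</s>")
print(row)  # {'prompt': '<s>[INST] Hi [/INST]', 'chosen': ' Hello!</s>', 'rejected': ' Go away.</s>'}
```

With a Hugging Face `datasets.Dataset`, a function like this could be applied via `dataset.map(lambda ex: split_prompt(ex["chosen"], ex["rejected"]))`, which overwrites the two existing columns and adds `prompt`.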

poutyface commented 6 months ago

This line: https://github.com/huggingface/trl/blob/main/trl/trainer/orpo_trainer.py#L618

There is a possibility that torch.exp(policy_chosen_logps) or torch.exp(policy_rejected_logps) will be 1. Then the torch.log1p term becomes -inf, which results in NaN.
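Assuming TRL follows the ORPO paper's definition, with $\ell_w$ and $\ell_l$ the length-averaged log-probabilities of the chosen and rejected responses (`policy_chosen_logps`, `policy_rejected_logps`), the log-odds ratio expands as below; `torch.log1p(-torch.exp(l))` computes the $\log(1 - e^{\ell})$ term. If $\ell_w = 0$ that term is $-\infty$ and the chosen log-odds is $+\infty$; if the same happens to $\ell_l$, the difference is $\infty - \infty$, which is NaN. An $e^{\ell}$ that rounds above 1 gives the log of a negative number, also NaN.

```latex
\mathrm{odds}_\theta(y \mid x) = \frac{P_\theta(y \mid x)}{1 - P_\theta(y \mid x)},
\qquad
\ell = \log P_\theta(y \mid x) = \frac{1}{m}\sum_{t=1}^{m} \log P_\theta(y_t \mid x, y_{<t})

\log \frac{\mathrm{odds}_\theta(y_w \mid x)}{\mathrm{odds}_\theta(y_l \mid x)}
  = \bigl(\ell_w - \log(1 - e^{\ell_w})\bigr) - \bigl(\ell_l - \log(1 - e^{\ell_l})\bigr)

\ell_w \to 0 \;\Rightarrow\; \log(1 - e^{\ell_w}) = \log(1 - 1) = -\infty
```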

paulcx commented 5 months ago

> This line: https://github.com/huggingface/trl/blob/main/trl/trainer/orpo_trainer.py#L618
>
> There is a possibility that torch.exp(policy_chosen_logps) or torch.exp(policy_rejected_logps) will be "1". Then torch.log1p results NaN.

Any solution?

poutyface commented 5 months ago

Adding eps=1e-5 to the log1p argument works fine for me.

paulcx commented 5 months ago

> Adding eps=1e-5 to log1p param work fine for me

You mean torch.log1p(x + eps)?
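The thread leaves open exactly where eps should go. One hedged reading, sketched here in plain Python rather than TRL's tensor code, is to clamp each length-averaged log-probability to at most -eps before computing log(1 - exp(logp)), and to evaluate that term with the standard numerically stable log1mexp switch at -ln 2. `eps`, `log1mexp` and `orpo_log_odds_ratio` are illustrative names, not TRL API, and this is not necessarily the fix TRL adopted.

```python
import math


def log1mexp(x: float) -> float:
    """Numerically stable log(1 - exp(x)) for x < 0, switching at x = -ln 2."""
    if x > -math.log(2.0):
        # Near 0, 1 - exp(x) suffers cancellation; -expm1(x) computes it accurately.
        return math.log(-math.expm1(x))
    # Far from 0, exp(x) is small and log1p is accurate.
    return math.log1p(-math.exp(x))


def orpo_log_odds_ratio(chosen_logp: float, rejected_logp: float, eps: float = 1e-5) -> float:
    """log(odds_chosen / odds_rejected) from length-averaged log-probs, with an eps clamp."""
    # Clamp so that exp(logp) stays strictly below 1 and log(1 - exp(logp)) stays finite.
    c = min(chosen_logp, -eps)
    r = min(rejected_logp, -eps)
    return (c - log1mexp(c)) - (r - log1mexp(r))


print(orpo_log_odds_ratio(0.0, 0.0))        # 0.0 instead of NaN
print(orpo_log_odds_ratio(-0.5, -2.0) > 0)  # True: the chosen response has higher odds
```

The clamp only changes log-probs that are already within eps of 0, so ordinary training steps are unaffected. It cannot rescue log-probs that are NaN to begin with (for example from an empty completion after truncation).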

github-actions[bot] commented 5 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

KoutchemeCharles commented 3 months ago

> When the prompt exceeds the max_length, the log probabilities for both chosen and rejected turn to NaN. Consider filtering out cases where the prompt is longer than the max_length or max_prompt_len. The reason for trimming cases where the prompt exceeds max_prompt_len is that if the chosen or rejected segments are significantly shorter than the prompt, it may hinder effective learning.

I was facing similar issues with NaNs, and the problem went away when I filtered out of my dataset the examples where the length of the prompt + chosen/rejected exceeded a certain length.
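The two length-based fixes in this thread (drop examples whose prompt exceeds the prompt budget, and drop examples whose prompt plus response exceeds the total budget) can be combined into one predicate. A sketch, assuming you pass in your own token counter; the 1024/512 budgets are placeholders that should match the max_length and max_prompt_length you give the trainer, and `within_length_budget` is a hypothetical helper.

```python
from typing import Callable


def within_length_budget(
    example: dict,
    count_tokens: Callable[[str], int],
    max_length: int = 1024,        # placeholder: match your trainer's max_length
    max_prompt_length: int = 512,  # placeholder: match your trainer's max_prompt_length
) -> bool:
    """Keep an example only if nothing would have to be truncated."""
    prompt_len = count_tokens(example["prompt"])
    if prompt_len > max_prompt_length:
        return False
    # The longer of the two responses decides whether the pair fits.
    longest = max(count_tokens(example["chosen"]), count_tokens(example["rejected"]))
    return prompt_len + longest <= max_length


def whitespace_count(text: str) -> int:
    """Crude stand-in for a real tokenizer, for demonstration only."""
    return len(text.split())


examples = [
    {"prompt": "a b c", "chosen": "d e", "rejected": "f"},             # fits
    {"prompt": "word " * 600, "chosen": "x", "rejected": "y"},         # prompt too long
    {"prompt": "p " * 500, "chosen": "c " * 600, "rejected": "r"},     # total too long
]
kept = [ex for ex in examples if within_length_budget(ex, whitespace_count)]
print(len(kept))  # 1
```

With a real tokenizer, the counter could be `lambda text: len(tokenizer(text, add_special_tokens=False)["input_ids"])`, and the predicate applied with `dataset.filter(...)` on a Hugging Face `datasets.Dataset`.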