Closed: gagan3012 closed this issue 4 months ago
I'm experiencing the same issue :( It seems like the grad_norm suddenly diverges to infinity after some iterations.
@gagan3012 @hbin0701 do you see this with a specific dataset? Here is my run of orpo.py:
https://wandb.ai/krasul/huggingface/runs/rqu2awe3?nw=nwuserkrasul
using:
python examples/scripts/orpo.py \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 1e-3 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="orpo_anthropic_hh" \
--optim rmsprop \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns \
--use_peft \
--lora_r=16 \
--lora_alpha=16
I was using Mistral 0.2.
Hello @gagan3012, I just saw this issue and would like to add some comments!
Although I do not know the specific environment or dataset you are using, it is generally recommended that you use a lower learning rate and beta for larger models.
For example, this command for reproducing kaist-ai/mistral-orpo-capybara-7k uses a maximum learning rate of 5e-6 and a beta of 0.05 (this command is not for the TRL ORPOTrainer, by the way):
accelerate launch --config_file ./src/accelerate/fsdp.yaml main.py \
--lr 5e-6 \
--torch_compile False \
--beta 0.05 \
--lr_scheduler_type inverse_sqrt \
--warmup_steps 100 \
--model_name mistralai/Mistral-7B-v0.1 \
--data_name argilla/distilabel-capybara-dpo-7k-binarized \
--num_train_epochs 3 \
--optim adamw_bnb_8bit \
--gradient_accumulation_steps 1 \
--prompt_max_length 1792 \
--response_max_length 2048 \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 8 \
--num_proc 8 \
--flash_attention_2
I am not sure which dataset you are training on, but I would start with a beta of 0.1 and a learning rate of 5e-6 for a first run. I will add some general guidelines for selecting the learning rate and beta by model size and dataset style to this repo within the week!
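If you are on the TRL side, here is a minimal sketch of how those more conservative values could be passed through ORPOConfig (assuming a TRL version that ships ORPOConfig; the output directory, batch size, and length limits are placeholders to adapt):

from trl import ORPOConfig

# Sketch only: conservative starting hyperparameters for a ~7B model,
# following the recommendation above (beta 0.1, lr 5e-6).
config = ORPOConfig(
    output_dir="orpo_mistral_capybara",   # placeholder path
    beta=0.1,                             # try 0.05 if training is still unstable
    learning_rate=5e-6,                   # far lower than the 1e-3 in the failing run
    lr_scheduler_type="inverse_sqrt",
    warmup_steps=100,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    max_prompt_length=1792,
    max_length=2048,
    bf16=True,
)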
Hello, when using the ORPO repo I don't face this issue, but I do face it when I use TRL, which is very puzzling.
Is your prompt preparation correct?
TRL expects the "chosen" and "rejected" columns to be a) formatted (but not tokenized) and b) to EXCLUDE the prompt.
TRL also does not add any bos or eos tokens, so you need to do that in the chat_template. Further, since you'll be formatting chosen and rejected columns without the prompt, you need to ensure that the bos is NOT included there...
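For reference, a minimal sketch of a mapping function that produces the layout described above; the raw field names (prompt_messages, chosen_response, rejected_response) are hypothetical and only illustrate the shape, not any particular dataset:

def format_example(example, tokenizer):
    # "prompt" is the fully formatted prompt (chat template applied, not tokenized).
    prompt = tokenizer.apply_chat_template(
        example["prompt_messages"],            # hypothetical raw field
        tokenize=False,
        add_generation_prompt=True,
    )
    # "chosen"/"rejected" hold ONLY the completion text: no prompt, no leading bos.
    # Per the comment above, eos has to be supplied by you (here, appended manually).
    return {
        "prompt": prompt,
        "chosen": example["chosen_response"] + tokenizer.eos_token,      # hypothetical field
        "rejected": example["rejected_response"] + tokenizer.eos_token,  # hypothetical field
    }

# dataset = dataset.map(format_example, fn_kwargs={"tokenizer": tokenizer})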
When the prompt exceeds the max_length, the log probabilities for both chosen and rejected turn into NaN. Consider filtering out cases where the prompt is longer than max_length or max_prompt_length. The reason for dropping cases where the prompt exceeds max_prompt_length is that if the chosen or rejected segments end up much shorter than the prompt, it may hinder effective learning.
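A sketch of that kind of filter, assuming a datasets.Dataset with a "prompt" column and a loaded tokenizer; the 1792 limit is only an example and should match whatever max_prompt_length you pass to the trainer:

MAX_PROMPT_LENGTH = 1792  # example value; keep in sync with your trainer config

def prompt_fits(example, tokenizer, max_prompt_length=MAX_PROMPT_LENGTH):
    # Drop examples whose tokenized prompt would overflow max_prompt_length,
    # since those are the ones whose chosen/rejected logps can turn into NaN.
    n_prompt_tokens = len(tokenizer(example["prompt"])["input_ids"])
    return n_prompt_tokens <= max_prompt_length

# dataset = dataset.filter(prompt_fits, fn_kwargs={"tokenizer": tokenizer})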
I also have a similar problem, but it's different from what is mentioned above. My dataset doesn't have a separate prompt column; the prompts are concatenated with chosen/rejected in the dataset.
This line: https://github.com/huggingface/trl/blob/main/trl/trainer/orpo_trainer.py#L618
There is a possibility that torch.exp(policy_chosen_logps) or torch.exp(policy_rejected_logps) will be exactly 1 (i.e. a log probability of 0). Then torch.log1p(-1) evaluates to -inf, which propagates into NaN in the loss.
Any solution?
Adding eps=1e-5 to the log1p argument works fine for me.
You mean torch.log1p(x + eps)?
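Something along these lines; a sketch of the log-odds term with the suggested eps, written out from the ORPO formulation (log p - log(1 - p) computed via log1p(-exp(log p))). The exact expression in trl's orpo_trainer.py may differ slightly, and clamping the log-probabilities just below 0 would work as well:

import torch

def stable_log_odds(policy_chosen_logps, policy_rejected_logps, eps=1e-5):
    # Without eps: exp(logps) == 1 makes log1p(-1) = -inf, and the
    # chosen-minus-rejected difference can then become NaN.
    # The eps keeps the argument of log1p strictly above -1.
    return (policy_chosen_logps - policy_rejected_logps) - (
        torch.log1p(-torch.exp(policy_chosen_logps) + eps)
        - torch.log1p(-torch.exp(policy_rejected_logps) + eps)
    )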
I was facing similar issues with NaNs, and the problem went away when I filtered out of my dataset the examples where the length of the prompt + chosen/rejected exceeded a certain length.
Using the example script orpo.py I get this output:
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 2.5000000000000002e-08, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': -3.114448070526123, 'logits/chosen': -3.114448070526123, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 5.0000000000000004e-08, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': nan, 'logits/chosen': nan, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 7.500000000000001e-08, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': nan, 'logits/chosen': nan, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 1.0000000000000001e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': nan, 'logits/chosen': nan, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 1.2500000000000002e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': nan, 'logits/chosen': nan, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 1.5000000000000002e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': nan, 'logits/chosen': nan, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 1.7500000000000002e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': nan, 'logits/chosen': nan, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 2.0000000000000002e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': nan, 'logits/chosen': nan, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 2.2500000000000002e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/accuracies': 0.0, 'rewards/margins': nan, 'logps/rejected': nan, 'logps/chosen': nan, 'logits/rejected': nan, 'logits/chosen': nan, 'nll_loss': nan, 'log_odds_ratio': nan, 'log_odds_chosen': nan, 'epoch': 0.0}