CarperAI / trlx

A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)

Flan-t5-large Deepspeed OVERFLOW! issues + bad outputs after trlx training #383

Open chainyo opened 1 year ago

chainyo commented 1 year ago

🐛 Describe the bug

Hi, I'm trying to use ILQL training on custom data to fine-tune flan-t5-large and flan-t5-xl with RLHF, using gpt-j-6B as a reward model.

  1. I have completed the 1st sft step for both flan-t5 models with a custom training script and a custom dataset. -> The models work: I can load them and use them for inference.

  2. I have completed the 2nd step, getting the gpt-j reward_model checkpoints. The reward model was also fine-tuned in a first sft step, like the flan-t5 ones, so I now have the pytorch_model.bin of this new reward model ready to be used. I also used custom data with prompt, chosen, and rejected fields, as mentioned in the blog post (see the sketch after this list).

  3. I'm here 👋
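
For reference, one comparison record in that prompt/chosen/rejected format looks roughly like this (a minimal sketch; the strings are placeholders, not examples from the actual custom dataset):

# Rough shape of one comparison record for the reward-model step, following the
# prompt / chosen / rejected format mentioned in the blog post. All strings are
# placeholders standing in for the custom data.
record = {
    "prompt": "<input text given to the model>",
    "chosen": "<completion preferred by annotators>",
    "rejected": "<completion rejected by annotators>",
}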

I'm using this script from your examples: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/ilql_summarize_t5.py
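
For context, the training entry point in that example boils down to something like the sketch below (based on my reading of the script; the sample strings, reward values, config path, and checkpoint path are placeholders for the custom data, the gpt-j reward scores, and the sft checkpoint from step 1, not values confirmed to work):

import trlx
from trlx.data.configs import TRLConfig

# Offline ILQL setup: each training sample is a full prompt+completion string
# paired with a scalar reward, here assumed to come from the gpt-j reward model.
samples = ["<prompt text> <completion text>"]   # placeholder custom data
rewards = [0.42]                                # placeholder reward-model scores

config = TRLConfig.load_yaml("configs/ilql_config.yml")  # assumed path to the example's ILQL config
config.model.model_path = "path/to/flan-t5-large-sft"    # placeholder path to the sft checkpoint

trainer = trlx.train(samples=samples, rewards=rewards, config=config)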

So I'm trying to improve the flan-t5 models' performance using RLHF, but I run into problems at this step:

[screenshot: CleanShot 2023-03-20 at 17 17 23]

Deepspeed keeps reporting an OVERFLOW! issue and scaling down the loss, and after 8-9 loss-scale reductions the training crashes.

So I also trained with the models/datasets suggested in the example script to see how it behaves by default. There are OVERFLOW! issues there as well, but after 4 loss-scale reductions it runs fine.

[screenshot: CleanShot 2023-03-20 at 17 20 48]

Any hint would be so much appreciated 🤗

Which trlX version are you using?

trlx main branch

Additional system and package information

GCP Pytorch 1.13.1, CUDA 11.6, 2xA100 40GBs, transformers>=4.26.1, deepspeed>=0.8.2, accelerate>=0.16.0

maxreciprocate commented 1 year ago

Hey! For a quick fix, can you try training with bf16 instead, by setting mixed_precision: bf16 in the accelerate config, or, if you use a custom deepspeed config, by replacing the "fp16" entry with:

"bf16": {
     "enabled": true
 }
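
If it helps, a complete minimal DeepSpeed config with that change could be written out like this (a sketch only; every field besides "bf16" is a typical illustrative value, not something taken from this thread):

import json

# Minimal DeepSpeed config using bf16 in place of the default "fp16" entry.
# The remaining fields are illustrative "auto" placeholders resolved by accelerate.
ds_config = {
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 2},
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
}

with open("ds_config_bf16.json", "w") as f:
    json.dump(ds_config, f, indent=2)
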
chainyo commented 1 year ago

Hi @reciprocated, thanks for the tip, but I managed to make it work with flan-t5-xl only, and without bf16.

When I launch the trlx training with flan-t5-large, I can't get past the overflow and loss scale-down issue. So for the moment, I will stick to the xl version of my model.

Another issue appears when I look at the output examples. They are super trash:

I don't understand why the outputs are so bad, because when I run simple generate inference with the model, the outputs are great.

maxreciprocate commented 1 year ago

Okay then, can you share a wandb link to your run? I'm fairly certain that outputs are generated properly in general, so it's likely the training itself (or rather its hyperparameters) that is the culprit behind the behavior you're seeing.

chainyo commented 1 year ago

I use the default config I found in the example file:

[screenshot: CleanShot 2023-03-22 at 22 50 35]

Could it come from the fact that I don't specify the max_length for my tokenizer? I use 1664 and not 512 because I have long input sequences.

I don't see whether it's possible to specify a max_length for the TokenizerConfig here:

https://github.com/CarperAI/trlx/blob/b0c4ea9fd4dd2ed09b80aa33c12bc349f02551a1/trlx/data/configs.py#L77-L100

I'm relaunching a training run using a custom tokenizer config file instead of the one from the base repo. 🤔
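
For what it's worth, the maximum length appears to live on the train section of the config rather than on TokenizerConfig, so the override might look like this (a sketch under that assumption; the config path is hypothetical):

from trlx.data.configs import TRLConfig

# Assumption: trlx takes the maximum token length from TrainConfig.seq_length,
# while TokenizerConfig only carries tokenizer_path / padding_side / truncation_side.
config = TRLConfig.load_yaml("configs/ilql_config.yml")  # hypothetical path to the ILQL config
config.train.seq_length = 1664              # long input sequences instead of the default 512
config.tokenizer.truncation_side = "right"  # drop tokens from the end of over-long inputs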

chainyo commented 1 year ago

Oh god. It was only a tokenizer max_length problem.

Look at this clear output 🤗

[screenshot: CleanShot 2023-03-22 at 23 06 59]

I will train a flan-t5-xl and then see if it fixes the OVERFLOW issue for flan-t5-large. If so, I will close this issue.

Edit: After more than 10 hours of training (not finished yet), the output examples are still trash. 🧠

[screenshot: CleanShot 2023-03-23 at 09 15 40]

chainyo commented 1 year ago

So after 5000 steps:

wandb: Waiting for W&B process to finish... (success).
wandb: 
wandb: Run history:
wandb:        awac_weight/max ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:       awac_weight/mean ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:        awac_weight/min ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:        awac_weight/std ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:  learning_rate_group_0 ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:            losses/loss █▆▄▄▄▃▃▃▃▃▃▃▃▂▃▃▂▂▂▂▃▂▂▂▁▂▂▂▂▁▁▁▁▁▁▁▁▂▂▁
wandb:       losses/loss_awac █▆▃▄▃▃▃▂▃▂▂▂▃▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁
wandb:        losses/loss_cql ████▇▇▇▆▆▆▅▅▅▄▄▄▃▄▃▃▄▄▃▃▂▃▃▃▃▃▂▂▂▂▂▂▂▂▃▁
wandb:          losses/loss_q ▇▄▄▂▂▂▂▂▂▂▂▃▄▃▄█▄▃▂▂▃▆▃▃▂▄▃▃▃▂▂▃▃▅▃▂▁▆▃▅
wandb:          losses/loss_v ▄▃▃▂▁▄▃▅▃▅▆▆▆▄▇█▅▆▄▅▆▃▂▄▃▄▃▃▂▃▂▂▄▁▃▂▃▂▂▃
wandb: metrics/rewards@beta=1 ▁█▃▆█▃
wandb: metrics/rewards@beta=2 ▅▄▁▂▅█
wandb: metrics/rewards@beta=3 █▁▄▃▅▅
wandb:          qvalues/0/max ▆▃▂▂▂▂▂▁▂▃▁▁▂▃▁▃▅▇▃▄█▆▆▄▇▄▅▅▅▅▆▄▅▆▆▆▅▅▆▆
wandb:         qvalues/0/mean █▇▇▇▇▇▆▇▆▆▆▅▅▄▄▂▄▄▅▄▅▃▃▅▁▁▂▄▄▅▅▃▃▄▃▃▂▃▃▅
wandb:          qvalues/0/min ▆▇▇███▇▇▇▇▇▅▄▅▅▅▅▅▅▄▆▅▄▅▃▄▅▁▅▄▄▄▄▅▅▃▄▃▅▅
wandb:          qvalues/0/std ▄▂▂▁▁▁▂▁▁▂▂▂▃▃▃▄▅▄▃▄▄▅▆▅█▆▆▆▅▄▇▅▇▆▆▆▇▆▄▇
wandb:          qvalues/1/max ▄▃▄▃▃▂▂▁▂▃▂▁▁▄▁▄▅▇▅▅█▆▆▅▅▅▅▅▅▆▆▆▆█▆▆▆▄▇▇
wandb:         qvalues/1/mean █▇▆▆▆▆▆▆▅▅▅▅▄▄▃▂▃▃▅▄▄▃▃▅▁▁▂▃▃▄▄▃▃▄▃▃▂▃▃▄
wandb:          qvalues/1/min ▇▅▇▇█▆▇▆▇▇▇▆▄▃▃▅▅▅▆▅▆▅▅▅▃▄▅▁▅▅▅▅▅▅▅▄▄▃▅▅
wandb:          qvalues/1/std ▄▄▃▂▁▂▂▂▂▂▂▂▄▄▄▅▆▅▃▄▅▅▆▆█▆▆▇▆▅█▆▇▇▇▆▇▇▅█
wandb:          time/backward ▂▂▂▂▂▂▃▂▂▂▃▃▃▄▃▃███▇▇▆▆▇▂▂▃▂▂▁▁▅▁▁▁▂▁▂▂▁
wandb:           time/forward ▂▁▁▂▁▂▄▂▃▂▇▆▂█▄▃▃▂█▁▂▂▂▂▂▂▅▃▃▂▂▆▂▃▃▆▂▃▇▄
wandb:          time/generate █▁▅▄▅▂
wandb:            time/metric ▇▄▁▄█▆
wandb:             values/max ▁▁▁▁▂▁▁▂▂▂▂▃▃▄▂▄▄▃▄▃▅▇▂▆▄▄▆▅▇▅▆▅▆▇▆▆▇▆▇█
wandb:            values/mean ▇▇▇██▇▇█▆█▆▆▆▅▅▅▃▅▆▄▇▆▄▇▂▁▃▆▆▆▆▆▅▅▅▂▂▇▅▆
wandb:             values/min ████▇▇█▇▇█▆▆▆▆▅▅▅▆▄▄▆▅▄▆▃▂▄▂▃▁▄▄▅▂▂▁▂▁▁▄
wandb:             values/std ▁▁▁▁▂▂▁▂▂▂▃▃▃▄▃▄▅▃▅▅▄▄▅▅▇▇▆▆▆▇▇▅▆█▆▇██▇█
wandb: 
wandb: Run summary:
wandb:        awac_weight/max 1.0
wandb:       awac_weight/mean 1.0
wandb:        awac_weight/min 1.0
wandb:        awac_weight/std 0.0
wandb:  learning_rate_group_0 0.0
wandb:            losses/loss 1.7004
wandb:       losses/loss_awac 0.23572
wandb:        losses/loss_cql 11.51562
wandb:          losses/loss_q 0.28403
wandb:          losses/loss_v 0.02929
wandb: metrics/rewards@beta=1 -1.93066
wandb: metrics/rewards@beta=2 -1.82422
wandb: metrics/rewards@beta=3 -1.93164
wandb:          qvalues/0/max 1.07812
wandb:         qvalues/0/mean -0.29688
wandb:          qvalues/0/min -1.75977
wandb:          qvalues/0/std 0.49976
wandb:          qvalues/1/max 1.16504
wandb:         qvalues/1/mean -0.30566
wandb:          qvalues/1/min -1.78906
wandb:          qvalues/1/std 0.49463
wandb:          time/backward 6.62192
wandb:           time/forward 0.16124
wandb:          time/generate 147.83096
wandb:            time/metric 95.48744
wandb:             values/max 1.17578
wandb:            values/mean -0.16565
wandb:             values/min -3.125
wandb:             values/std 0.62158
wandb: 
wandb: You can sync this run to the cloud by running:
wandb: wandb sync /home/chainyo/code/llms-on-gcp/wandb/offline-run-20230322_220645-4t12ycqh
wandb: Find logs at: ./wandb/offline-run-20230322_220645-4t12ycqh/logs

Unfortunately, the outputs are still bad. I must be missing something in the hyper-parameter configuration.

[screenshot: CleanShot 2023-03-23 at 11 02 03]
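
If the hyperparameters really are the issue, the ILQL-specific knobs sit under the method section of the config; a sketch of the kind of sweep one might try is below (field names follow trlx's ILQLConfig, but the values are illustrative guesses, not settings confirmed in this thread):

from trlx.data.configs import TRLConfig

config = TRLConfig.load_yaml("configs/ilql_config.yml")  # hypothetical path to the ILQL config

# ILQL-specific hyperparameters (names from trlx's ILQLConfig). The values below
# are illustrative starting points for a sweep, not recommendations from this thread.
config.method.alpha = 0.001                # weight of the Q/value losses relative to the LM loss
config.method.cql_scale = 0.1              # scale of the conservative (CQL) term
config.method.tau = 0.7                    # expectile used for the value loss
config.method.steps_for_target_q_sync = 5  # how often the target Q heads are synced
config.optimizer.kwargs["lr"] = 1e-5       # assumption: the optimizer kwargs hold the learning rate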

aleksandr-smechov commented 1 year ago

Same issue @reciprocated ^

princetyagi1 commented 1 year ago

@chainyo did you face any issues when trying to use the t5 model for inference after training, since the config file and the .bin model are missing from the saved checkpoint?