chainyo opened this issue 1 year ago
Hey! For a quick fix, can you try training with bf16 instead, by setting `mixed_precision: bf16` in the accelerate config, or, if you use a custom deepspeed config, by replacing the `"fp16"` entry with:
"bf16": {
"enabled": true
}
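For reference, a minimal accelerate config with that change might look something like this (a sketch; the machine and process counts here are illustrative):

```yaml
# Minimal accelerate config sketch (illustrative; adjust to your setup).
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
mixed_precision: bf16   # the relevant change
num_machines: 1
num_processes: 2        # e.g. 2x A100
```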
Hi @reciprocated, thanks for the tip, but I managed to make it work with `flan-t5-xl` only, and without bf16. When I launch the trlx training with `flan-t5-large`, I can't get past the overflow and loss scale-down issue. So for the moment, I will stick to the xl version of my model.
Another issue I see when looking at the output examples: they are super trash.
Examples using CarperAI models:
Examples using my sft flan-t5 model:
I don't understand why the outputs are so bad, because when I do simple `generate` inference with the model, the outputs are great.
Okay then, can you share a wandb link to your run? I'm fairly certain that outputs are generated properly in general, so it's likely the training itself (or rather its hyperparameters) that is the culprit of the behavior you're seeing.
I use the default config I found in the example file.

Could it come from the fact that I don't specify the `max_length` for my tokenizer? I use 1664 and not 512 because I have long input sequences. I don't see whether it's possible to specify a `max_length` in the `TokenizerConfig` here:

I'm relaunching a training run using a custom tokenizer config file instead of the one from the base repo. 🤔
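Concretely, by "custom tokenizer config" I mean bumping the tokenizer's maximum length before training, roughly like this (a sketch using Hugging Face's `model_max_length` attribute; the path and length are from my setup):

```python
from transformers import AutoTokenizer

# flan-t5 tokenizers default to model_max_length=512, which silently
# truncates my long inputs; raise it to match my sequences.
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
print(tokenizer.model_max_length)  # 512 by default
tokenizer.model_max_length = 1664
tokenizer.save_pretrained("./tokenizer-1664")  # point the trlx config at this path
```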
Oh god. It was only a tokenizer `max_length` problem. Look at this clear output 🤗

I will train a `flan-t5-xl` and then see if it fixes the OVERFLOW issue for `flan-t5-large`. If so, I will close this issue.
Edit: After more than 10 hours of training (not finished yet), the output examples are still trash. 🧠
So after 5000 steps:
```
wandb: Waiting for W&B process to finish... (success).
wandb:
wandb: Run history:
wandb: awac_weight/max ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb: awac_weight/mean ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb: awac_weight/min ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb: awac_weight/std ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb: learning_rate_group_0 ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb: losses/loss █▆▄▄▄▃▃▃▃▃▃▃▃▂▃▃▂▂▂▂▃▂▂▂▁▂▂▂▂▁▁▁▁▁▁▁▁▂▂▁
wandb: losses/loss_awac █▆▃▄▃▃▃▂▃▂▂▂▃▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁
wandb: losses/loss_cql ████▇▇▇▆▆▆▅▅▅▄▄▄▃▄▃▃▄▄▃▃▂▃▃▃▃▃▂▂▂▂▂▂▂▂▃▁
wandb: losses/loss_q ▇▄▄▂▂▂▂▂▂▂▂▃▄▃▄█▄▃▂▂▃▆▃▃▂▄▃▃▃▂▂▃▃▅▃▂▁▆▃▅
wandb: losses/loss_v ▄▃▃▂▁▄▃▅▃▅▆▆▆▄▇█▅▆▄▅▆▃▂▄▃▄▃▃▂▃▂▂▄▁▃▂▃▂▂▃
wandb: metrics/rewards@beta=1 ▁█▃▆█▃
wandb: metrics/rewards@beta=2 ▅▄▁▂▅█
wandb: metrics/rewards@beta=3 █▁▄▃▅▅
wandb: qvalues/0/max ▆▃▂▂▂▂▂▁▂▃▁▁▂▃▁▃▅▇▃▄█▆▆▄▇▄▅▅▅▅▆▄▅▆▆▆▅▅▆▆
wandb: qvalues/0/mean █▇▇▇▇▇▆▇▆▆▆▅▅▄▄▂▄▄▅▄▅▃▃▅▁▁▂▄▄▅▅▃▃▄▃▃▂▃▃▅
wandb: qvalues/0/min ▆▇▇███▇▇▇▇▇▅▄▅▅▅▅▅▅▄▆▅▄▅▃▄▅▁▅▄▄▄▄▅▅▃▄▃▅▅
wandb: qvalues/0/std ▄▂▂▁▁▁▂▁▁▂▂▂▃▃▃▄▅▄▃▄▄▅▆▅█▆▆▆▅▄▇▅▇▆▆▆▇▆▄▇
wandb: qvalues/1/max ▄▃▄▃▃▂▂▁▂▃▂▁▁▄▁▄▅▇▅▅█▆▆▅▅▅▅▅▅▆▆▆▆█▆▆▆▄▇▇
wandb: qvalues/1/mean █▇▆▆▆▆▆▆▅▅▅▅▄▄▃▂▃▃▅▄▄▃▃▅▁▁▂▃▃▄▄▃▃▄▃▃▂▃▃▄
wandb: qvalues/1/min ▇▅▇▇█▆▇▆▇▇▇▆▄▃▃▅▅▅▆▅▆▅▅▅▃▄▅▁▅▅▅▅▅▅▅▄▄▃▅▅
wandb: qvalues/1/std ▄▄▃▂▁▂▂▂▂▂▂▂▄▄▄▅▆▅▃▄▅▅▆▆█▆▆▇▆▅█▆▇▇▇▆▇▇▅█
wandb: time/backward ▂▂▂▂▂▂▃▂▂▂▃▃▃▄▃▃███▇▇▆▆▇▂▂▃▂▂▁▁▅▁▁▁▂▁▂▂▁
wandb: time/forward ▂▁▁▂▁▂▄▂▃▂▇▆▂█▄▃▃▂█▁▂▂▂▂▂▂▅▃▃▂▂▆▂▃▃▆▂▃▇▄
wandb: time/generate █▁▅▄▅▂
wandb: time/metric ▇▄▁▄█▆
wandb: values/max ▁▁▁▁▂▁▁▂▂▂▂▃▃▄▂▄▄▃▄▃▅▇▂▆▄▄▆▅▇▅▆▅▆▇▆▆▇▆▇█
wandb: values/mean ▇▇▇██▇▇█▆█▆▆▆▅▅▅▃▅▆▄▇▆▄▇▂▁▃▆▆▆▆▆▅▅▅▂▂▇▅▆
wandb: values/min ████▇▇█▇▇█▆▆▆▆▅▅▅▆▄▄▆▅▄▆▃▂▄▂▃▁▄▄▅▂▂▁▂▁▁▄
wandb: values/std ▁▁▁▁▂▂▁▂▂▂▃▃▃▄▃▄▅▃▅▅▄▄▅▅▇▇▆▆▆▇▇▅▆█▆▇██▇█
wandb:
wandb: Run summary:
wandb: awac_weight/max 1.0
wandb: awac_weight/mean 1.0
wandb: awac_weight/min 1.0
wandb: awac_weight/std 0.0
wandb: learning_rate_group_0 0.0
wandb: losses/loss 1.7004
wandb: losses/loss_awac 0.23572
wandb: losses/loss_cql 11.51562
wandb: losses/loss_q 0.28403
wandb: losses/loss_v 0.02929
wandb: metrics/rewards@beta=1 -1.93066
wandb: metrics/rewards@beta=2 -1.82422
wandb: metrics/rewards@beta=3 -1.93164
wandb: qvalues/0/max 1.07812
wandb: qvalues/0/mean -0.29688
wandb: qvalues/0/min -1.75977
wandb: qvalues/0/std 0.49976
wandb: qvalues/1/max 1.16504
wandb: qvalues/1/mean -0.30566
wandb: qvalues/1/min -1.78906
wandb: qvalues/1/std 0.49463
wandb: time/backward 6.62192
wandb: time/forward 0.16124
wandb: time/generate 147.83096
wandb: time/metric 95.48744
wandb: values/max 1.17578
wandb: values/mean -0.16565
wandb: values/min -3.125
wandb: values/std 0.62158
wandb:
wandb: You can sync this run to the cloud by running:
wandb: wandb sync /home/chainyo/code/llms-on-gcp/wandb/offline-run-20230322_220645-4t12ycqh
wandb: Find logs at: ./wandb/offline-run-20230322_220645-4t12ycqh/logs
```
Unfortunately, the outputs are bad. I must be missing something in the hyperparameter configuration.
Same issue @reciprocated ^
@chainyo did you face any issues while trying to use the t5 model for inference after training, since the config file and the .bin model are missing from the saved checkpoint?
🐛 Describe the bug
Hi, I'm trying to use `ilql` training on custom data with `flan-t5-large` and `flan-t5-xl` models to fine-tune them using RLHF, with `gpt-j-6B` as a reward model.

I have completed the 1st `sft` step for both `flan-t5` models with a custom training script and a custom dataset. -> The model is working. I can load it and use it for inference.

I have completed the 2nd step, getting `gpt-j` reward_model checkpoints. The model was also fine-tuned in a first `sft` step like the `flan-t5` ones. So now I have the pytorch_model.bin of this new reward model, ready to be used. I also used custom data with `prompt`, `chosen`, and `rejected` entries, as mentioned in the blog post.
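For context, each sample in my comparison dataset looks roughly like this (illustrative values, just to show the `prompt`/`chosen`/`rejected` shape from the blog post):

```json
{
  "prompt": "Summarize the following document: ...",
  "chosen": "A concise, accurate summary.",
  "rejected": "A vague or off-topic summary."
}
```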
I'm here 👋 I'm using this script from your examples: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/ilql_summarize_t5.py
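Roughly, the call I'm making boils down to this (a sketch; my real dataset loading and metric function are omitted, and the config path is hypothetical):

```python
import trlx
from trlx.data.configs import TRLConfig

# Tiny placeholder data; in my case these come from my custom dataset
# and the gpt-j reward model's scores.
samples = ["Summarize: the cat sat on the mat. TL;DR: A cat sat down."]
rewards = [0.7]

config = TRLConfig.load_yaml("configs/ilql_config.yml")  # hypothetical path

# Offline ILQL training on (sample, reward) pairs, following the example script.
trainer = trlx.train(samples=samples, rewards=rewards, config=config)
```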
So I'm trying to improve the `flan-t5` models' performance using RLHF, but I face problems when it comes to this step: Deepspeed keeps telling me there is an `OVERFLOW!` issue and that it needs to scale down the loss, and after 8-9 loss-scale attempts, the training crashes.

So I trained using the suggested models/datasets from the example script to see how it goes on your side by default. There are also `OVERFLOW!` issues, but after 4 loss-scale attempts, it runs well.
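For reference, a default-style deepspeed `fp16` block looks roughly like this (values are illustrative defaults from the deepspeed docs); this dynamic loss scaler is what keeps backing off on `OVERFLOW!`:

```json
"fp16": {
    "enabled": true,
    "initial_scale_power": 16,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
}
```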
Were the `flan-t5` models trained with bf16 or fp32?

Any hint would be so much appreciated 🤗
Which trlX version are you using?

`trlx` main branch

Additional system and package information
GCP Pytorch 1.13.1, CUDA 11.6, 2xA100 40GBs, transformers>=4.26.1, deepspeed>=0.8.2, accelerate>=0.16.0