huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Got an abnormally high loss when training Gemma-7B. #1709

Closed smartliuhw closed 1 month ago

smartliuhw commented 2 months ago

I was training Gemma-2B and Gemma-7B with SFTTrainer, with packing=True set. Gemma-2B's loss was quite normal, but Gemma-7B's was abnormally high. I'm not sure why this happens, since both models should use the same tokenizer.

Gemma-7B: (loss curve screenshot)
Gemma-2B: (loss curve screenshot)

Environment: CUDA version 12.4, torch version 2.3.0, trl version 0.8.6

launch command:

torchrun --nproc_per_node=10 ../src/train.py \
    --model_type ${MODEL_TYPE} \
    --train_data $TRAIN_DATA \
    --output_dir ${OUTPUT_DIR} \
    --num_train_epochs 5 \
    --per_device_train_batch_size 1 \
    --learning_rate ${LR} \
    --lr_scheduler_type "cosine" \
    --warmup_ratio 0.1 \
    --warmup_steps 20 \
    --report_to wandb \
    --logging_dir ${LOG_DIR} \
    --logging_strategy steps \
    --logging_steps 1 \
    --save_strategy steps \
    --save_steps 200 \
    --save_total_limit 2 \
    --save_safetensors \
    --deepspeed ../src/deepspeed_config.json \
    --seed 725 \
    --bf16 \
    --do_train \
    --save_only_model \
    --max_seq_length ${MAX_LENGTH}

trainer setup:

from trl import SFTTrainer

trainer = SFTTrainer(
    model,
    args=training_args,
    train_dataset=train_data,
    formatting_func=formatting_constant_length_func,
    packing=True,
    max_seq_length=training_args.max_seq_length,
)
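The definition of formatting_constant_length_func is not shown in the issue. With packing=True, SFTTrainer expects the formatting function to map one dataset example to a single text string; a hypothetical sketch (field names are assumed, not taken from the original script):

# Hypothetical sketch only: the real formatting_constant_length_func is not
# shown in this issue. With packing=True, the function should turn one dataset
# example into one training text string.
def formatting_constant_length_func(example):
    # "instruction" and "response" are assumed field names; adjust to the
    # actual dataset schema.
    return example["instruction"] + "\n" + example["response"]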
vwxyzjn commented 2 months ago

How reproducible is this? I am wondering if it has anything to do with https://github.com/huggingface/transformers/pull/29285. Maybe try upgrading your transformers version?

smartliuhw commented 2 months ago

> How reproducible is this? I am wondering if it has anything to do with huggingface/transformers#29285. Maybe try upgrading your transformers version?

I tried running the script again just now and hit the same situation; the loss was around 40 at the beginning. (loss curve screenshot)

My transformers version is 4.41.2, which should be the latest version (or the latest as of a week ago). Gemma-7B does take more GPU memory than other 7B models, so I think my transformers version already includes that fix. Even if there really is some problem with the RoPE, it still can't explain why Gemma-2B works well 😂

vwxyzjn commented 2 months ago

Fair point. Maybe cc @danielhanchen since he is an expert on Gemma-7B fixes 😂

smartliuhw commented 2 months ago

> Fair point. Maybe cc @danielhanchen since he is an expert on Gemma-7B fixes 😂

Hi, I have done some tests and have some findings. It's important for Gemma models to add a <bos> token in front of every input, but the packing strategy splits a complete sequence into chunks, so some of the training samples don't start with the <bos> token, which leads to the high loss at the beginning. I'm not sure whether this affects the model's final performance; could you please give me some advice? Also, Gemma-2B faces the same problem but still gets a small loss. Is that because I set a larger seq_len for Gemma-2B (twice that of Gemma-7B), since I want to take full advantage of the GPU memory?
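For what it's worth, one can measure how often this actually happens. The sketch below is not from the thread; it assumes the Gemma tokenizer from transformers and an iterable of packed examples that expose an "input_ids" field, and counts the fraction of packed chunks that begin with <bos>:

# Rough sketch (assumptions: Gemma tokenizer from transformers, and an
# iterable `packed_dataset` whose items expose "input_ids").
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
bos_id = tokenizer.bos_token_id

def bos_fraction(packed_dataset, num_samples=1000):
    # Fraction of packed chunks whose first token is <bos>.
    hits, total = 0, 0
    for i, example in enumerate(packed_dataset):
        if i >= num_samples:
            break
        total += 1
        if example["input_ids"][0] == bos_id:
            hits += 1
    return hits / max(total, 1)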

danielhanchen commented 2 months ago

@vwxyzjn Thanks for tagging me :) Hi :)

@smartliuhw Oh yes we noticed the same issue with packing=True causing high losses in our blog: https://unsloth.ai/blog/gemma-bugs

(loss comparison chart from the Unsloth blog)

The most important caveat for finetuning is that you must add the <bos> token (red loss). The blue loss is without the <bos> token. Packing with TRL works, but has a higher base loss; it's possible Gemma does not use the T5 packing trick anymore! See the T5 paper page 12 or the TensorFlow implementation.
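As a rough illustration of the <bos> point above (a sketch under assumptions, not an official TRL fix): force every packed chunk to start with <bos> by overwriting its first token. This keeps the chunk length fixed but drops one original token per chunk.

# Sketch only (not an official TRL fix): make every packed chunk start with
# <bos>. Assumes packed examples expose "input_ids" as a list of token ids.
def ensure_bos(example, bos_id):
    ids = list(example["input_ids"])
    if ids and ids[0] != bos_id:
        ids = [bos_id] + ids[:-1]  # keep length constant; drops the last token
    example["input_ids"] = ids
    return example

# e.g. packed_dataset = packed_dataset.map(lambda ex: ensure_bos(ex, tokenizer.bos_token_id))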

smartliuhw commented 2 months ago

> @vwxyzjn Thanks for tagging me :) Hi :)
>
> @smartliuhw Oh yes we noticed the same issue with packing=True causing high losses in our blog: https://unsloth.ai/blog/gemma-bugs
>
> (loss comparison chart from the Unsloth blog)
>
> The most important caveat for finetuning is that you must add the <bos> token (red loss). The blue loss is without the <bos> token. Packing with TRL works, but has a higher base loss; it's possible Gemma does not use the T5 packing trick anymore! See the T5 paper page 12 or the TensorFlow implementation.

Thanks a lot for your reply! It's true that Gemma does not use the trick you mentioned. I manually printed part of the input_ids and decoded them; the results are as follows:

(screenshot of the decoded packed input_ids)

May I ask how I can fix these problems? Or should I use unsloth to train my model? (By the way, I didn't find docs for multi-GPU training with unsloth; could you please show me some examples?)

Thanks a lot 🥰

kdcyberdude commented 2 months ago

Despite adding the <bos> token and loading the model and tokenizer using unsloth, I am still getting a very high loss. Single-sequence example:

import torch
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/gemma-7b-it-bnb-4bit",
    max_seq_length = 2048,
    dtype = torch.bfloat16,
    load_in_4bit = True,
)

text = """<bos><start_of_turn>user
How many days are there in a week?<end_of_turn>
<start_of_turn>model
There are a total of 7 days in a week. They are Sunday, Monday, Tuesday, Wednesday, Thursday, Friday, and Saturday.<eos>"""
input_ids = tokenizer(text, return_tensors="pt")['input_ids'].to('cuda')
labels = input_ids.clone()
labels[:, :-1] = labels[:, 1:].clone()
labels[:, -1] = -100  # We use -100 to mask the loss at the last position

with torch.no_grad():
    outputs = model(input_ids, labels=labels)
    loss = outputs.loss

loss.item() # loss is -> 38.96431350708008

cc: @danielhanchen

kdcyberdude commented 2 months ago

Hi @danielhanchen, isn't it better to use group_by_length to group short sequences if packing is disabled?
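For reference, group_by_length is a standard transformers TrainingArguments flag; a minimal sketch of enabling it when packing is off (all values here are placeholders):

# Minimal sketch: group_by_length in TrainingArguments (placeholder values).
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,
    group_by_length=True,  # batch sequences of similar length to cut padding
    bf16=True,
)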

danielhanchen commented 2 months ago

@kdcyberdude Oh, you're not supposed to do labels[:, :-1] = labels[:, 1:].clone() --> Gemma and Unsloth already do that internally, so you're doing it twice now. Yes, group_by_length is good.

@smartliuhw Sorry, we currently don't have multi-GPU support; it'll come in a future release!
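Following that correction, the single-sequence check above can be rerun without the manual shift; a sketch, reusing model, tokenizer, and text from the earlier snippet and assuming the model shifts labels internally:

# Sketch of the corrected check: pass unshifted input_ids as labels and let
# the model do the shift internally (reuses model/tokenizer/text from above).
import torch

input_ids = tokenizer(text, return_tensors="pt")["input_ids"].to("cuda")
labels = input_ids.clone()  # no manual shift

with torch.no_grad():
    loss = model(input_ids, labels=labels).loss
print(loss.item())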

kdcyberdude commented 2 months ago

Got it @danielhanchen, that does reduce the loss, but it's still 15.05, way more than it is supposed to be.

smartliuhw commented 2 months ago

@danielhanchen Looking forward to it!

@kdcyberdude I tried setting packing=False and got a similar result to yours. If you make any progress on this, could you please tell me how you solved it? I would really appreciate it ❤️

kdcyberdude commented 2 months ago

Hey @smartliuhw, I was able to obtain the following loss curve. It starts at 2.5, increases to 6, and then converges while fluctuating a lot. I'm using the openchat variant of Gemma, fine-tuning it for low-resource languages, specifically Panjabi, with a batch size of 7*4=28 and a learning rate of 1e-5. You can expect better performance when using English datasets.

(loss curve screenshot)

The gradients are high as well. (gradient norm screenshot)

smartliuhw commented 2 months ago

Hi @kdcyberdude, thanks a lot for sharing! I will try the variant you mentioned; maybe there's just some problem with the Gemma-7B base 😂. I have checked its HF homepage and found lots of unresolved discussions.

github-actions[bot] commented 1 month ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.