huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

A bug when SFT mistral-7B using DataCollatorForCompletionOnlyLM #2192

Open smartliuhw opened 2 weeks ago

smartliuhw commented 2 weeks ago

System Info

trl 0.8.6, transformers 4.41.2

Reproduction

Since DataCollatorForCompletionOnlyLM needs a pad token and the Mistral-7B tokenizer doesn't define one, I set the unknown token as the pad token:

tokenizer.pad_token = tokenizer.unk_token

But when I train the model, the loss doesn't drop; instead it increased (from 1.x to 5.x) after training for 2 epochs. The loss dropped well when I used the same code to SFT the Gemma-7B model (with Gemma's own pad token). Is there any other setting I missed? Or do you have any other suggestions for setting the right pad token?

Expected behavior

The loss decreases normally.

alvarobartt commented 2 weeks ago

Hi @smartliuhw! When running SFT, you should ideally set the tokenizer's PAD token to the EOS (End of Sentence) token rather than to the unknown token. Sequences will then be padded with the EOS token, unless the tokenizer already defines an explicit padding token, which is fine too.

So the following should do the job:

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

Note: You need to set this right after instantiating the tokenizer, i.e. before tokenizing the inputs and therefore before training. Also, when saving the model, make sure you also save the tokenizer you used for that training run.
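To make the padding behavior described above concrete, here is a minimal pure-Python sketch of right-padding a batch (as with `padding_side="right"`) when the EOS token doubles as the pad token. `pad_right` is a hypothetical helper, not a transformers API, and the token ids are illustrative assumptions rather than values read from a real tokenizer.

```python
from typing import List, Tuple

def pad_right(batch: List[List[int]], pad_id: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-pad token-id sequences to the longest one in the batch (hypothetical helper)."""
    max_len = max(len(seq) for seq in batch)
    input_ids, attention_mask = [], []
    for seq in batch:
        n_pad = max_len - len(seq)
        input_ids.append(seq + [pad_id] * n_pad)
        # 1 marks a real token, 0 marks padding the model should not attend to
        attention_mask.append([1] * len(seq) + [0] * n_pad)
    return input_ids, attention_mask

# Illustrative ids only (assumed): 1 = BOS, 2 = EOS, 10-12 = ordinary tokens
EOS_ID = 2
batch = [[1, 10, 11, EOS_ID], [1, 12, EOS_ID]]
ids, mask = pad_right(batch, pad_id=EOS_ID)  # i.e. pad_token = eos_token
print(ids)   # [[1, 10, 11, 2], [1, 12, 2, 2]]
print(mask)  # [[1, 1, 1, 1], [1, 1, 1, 0]]
```

Once EOS doubles as the pad token, a padded position looks exactly like a real EOS in `input_ids`; only `attention_mask` tells them apart.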

smartliuhw commented 2 weeks ago

@alvarobartt Hi! Thanks for your response! I tried setting the PAD token to the EOS token and ran another task on Mistral-7B, but the loss was still abnormal: it started at 1.x, rose to 4.x after the warmup and stayed at that level, while Gemma's loss dropped to 0.02. I'm going to try the same code on Llama3, since Llama has no PAD token either. I'll let you know if I get new results.

Btw, could you please help me check whether there's anything wrong with my code?

import os
import sys
import random
import json
import argparse
from dataclasses import dataclass, field

import torch
from tqdm import tqdm
import pandas as pd
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from trl.trainer import ConstantLengthDataset
from datasets import load_from_disk, concatenate_datasets, Dataset, DatasetDict
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, HfArgumentParser

from utils import get_train_data, formatting_prompts_func, formatting_constant_length_func, model_name_to_path

@dataclass
class ModelArguments:
    model_type: str = field(default=None, metadata={"help": "The model name."})

@dataclass
class DataArguments:
    train_data: str = field(default=None, metadata={"help": "Choose the training data, split by comma."})

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    max_seq_length: int = field(default=8192, metadata={"help": "Maximum sequence length passed to SFTTrainer."})

def train():
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Set random seed
    random.seed(training_args.seed)

    # Load model and tokenizer
    model_path = model_name_to_path[model_args.model_type]
    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="right")
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16)
    # if "mistral" in model_args.model_type.lower() or "llama" in model_args.model_type.lower():
        # tokenizer.add_special_tokens({"pad_token": "<pad>"})
        # model.resize_token_embeddings(len(tokenizer))
        # tokenizer.pad_token = tokenizer.unk_token
        # tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load training data
    print("Loading training data...")
    train_data = get_train_data(data_args.train_data, training_args.seed)

    # set completion only data collator
    response_template = "### Response:"
    collator = DataCollatorForCompletionOnlyLM(
        tokenizer=tokenizer,
        response_template=response_template,
        # max_length=training_args.max_seq_length,
    )

    # Load trainer
    trainer = SFTTrainer(
        model,
        training_args,
        tokenizer=tokenizer,
        train_dataset=train_data,
        formatting_func=formatting_prompts_func,
        # packing=True,
        data_collator=collator,
        max_seq_length=training_args.max_seq_length,
    )

    # Train the model
    trainer.train()

    # Save the model
    trainer.save_model(training_args.output_dir)

if __name__ == "__main__":
    train()
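As I understand it, DataCollatorForCompletionOnlyLM works on token ids: it searches each example's labels for the token ids of `response_template` and sets every label up to the end of that match to -100, so only the response contributes to the loss; if no match is found, the whole example is ignored and TRL logs a warning. The sketch below approximates that logic in plain Python. `completion_only_labels`, `find_last_subsequence` and the token ids are hypothetical illustrations, not TRL's actual implementation.

```python
from typing import List, Optional

IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default

def find_last_subsequence(seq: List[int], pattern: List[int]) -> Optional[int]:
    """Return the start index of the last occurrence of pattern in seq, or None."""
    start = None
    for i in range(len(seq) - len(pattern) + 1):
        if seq[i:i + len(pattern)] == pattern:
            start = i  # keep scanning so the last match wins
    return start

def completion_only_labels(input_ids: List[int], template_ids: List[int]) -> List[int]:
    """Mask every label up to and including the response template (hypothetical helper)."""
    start = find_last_subsequence(input_ids, template_ids)
    if start is None:
        # Template not found: the whole example is excluded from the loss.
        return [IGNORE_INDEX] * len(input_ids)
    cut = start + len(template_ids)
    return [IGNORE_INDEX] * cut + input_ids[cut:]

# Illustrative ids (assumed): [6, 7] stands in for the tokens of "### Response:"
print(completion_only_labels([1, 5, 6, 7, 8, 9, 2], [6, 7]))  # [-100, -100, -100, -100, 8, 9, 2]
print(completion_only_labels([1, 5, 8, 9, 2], [6, 7]))        # [-100, -100, -100, -100, -100]
```

One thing that may be worth checking with SentencePiece-based tokenizers such as Mistral's and Llama's: `"### Response:"` tokenized on its own can produce different ids than the same text tokenized in context, in which case no match is found. For that case, the TRL documentation suggests passing `response_template` as token ids encoded with some surrounding context.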

smartliuhw commented 2 weeks ago

Hi @alvarobartt! I ran two more experiments, on Llama3-8B and Llama2-7B, and both work fine with the same code. I think Mistral-7B may need some special setting. Do you have any idea what it might be?

zwhe99 commented 1 week ago

Hi @smartliuhw! The docs say that pad_token_id should be different from eos_token_id. Do you know the right way to set it?
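A sketch of why this matters, assuming the behavior of DataCollatorForLanguageModeling, which DataCollatorForCompletionOnlyLM subclasses with `mlm=False`: labels are copied from `input_ids`, and every position equal to the tokenizer's `pad_token_id` is set to -100. When the pad token is the EOS token, the real EOS at the end of each example is masked along with the padding. `lm_labels` and the ids below are hypothetical illustrations.

```python
from typing import List

IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default

def lm_labels(input_ids: List[int], pad_id: int) -> List[int]:
    """Causal-LM labels: copy input_ids and hide every position equal to pad_id."""
    return [IGNORE_INDEX if tok == pad_id else tok for tok in input_ids]

EOS_ID = 2             # illustrative EOS id (assumed)
DEDICATED_PAD = 32000  # hypothetical id of a separate, newly added pad token

# Pad token == EOS: the real EOS after token 11 is masked along with the padding.
print(lm_labels([1, 10, 11, EOS_ID, EOS_ID, EOS_ID], pad_id=EOS_ID))
# [1, 10, 11, -100, -100, -100]

# Distinct pad token: the real EOS keeps its label.
print(lm_labels([1, 10, 11, EOS_ID, DEDICATED_PAD, DEDICATED_PAD], pad_id=DEDICATED_PAD))
# [1, 10, 11, 2, -100, -100]
```

This mainly affects whether the model learns to emit EOS and stop generating; on its own it does not obviously account for a rising training loss.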

smartliuhw commented 1 week ago

Hi @zwhe99! I ran some more experiments and got these findings:

Hope those findings can help you out!

zwhe99 commented 1 week ago

Thanks, @smartliuhw!