unslothai / unsloth

Finetune Llama 3.2, Mistral, Phi, Qwen 2.5 & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0
18.37k stars, 1.28k forks

Fix orpo/dpo trainer #1286

Open dame-cell opened 1 week ago

This draft is a temporary fix for issue #1285.

Since the latest version of trl (0.12.0) takes processing_class instead of tokenizer, we need to change the DPOTrainer call to:

from trl import DPOConfig, DPOTrainer
from unsloth import is_bfloat16_supported

# trl 0.12.0+ takes processing_class instead of tokenizer.
# model, tokenizer and raw_datasets come from the earlier cells of the DPO Colab notebook.
dpo_trainer = DPOTrainer(
    model=model,
    ref_model=None,
    args=DPOConfig(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_ratio=0.1,
        num_train_epochs=3,
        learning_rate=5e-6,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.0,
        lr_scheduler_type="linear",
        seed=42,
        output_dir="outputs",
        report_to="none",  # Use this for WandB etc.
        # DPO-specific settings are DPOConfig fields, not DPOTrainer arguments
        beta=0.1,
        max_length=1024,
        max_prompt_length=512,
    ),
    train_dataset=raw_datasets["train"],
    # tokenizer=tokenizer,  # old argument name (trl < 0.12.0)
    processing_class=tokenizer,
)
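If the same notebook has to run on both older and newer trl releases, one option is to pick the keyword at runtime instead of editing the DPOTrainer call by hand. The sketch below is an assumption-based helper, not part of unsloth or trl: tokenizer_kwarg is a hypothetical name, and it assumes older trainers expose a tokenizer parameter while trl 0.12.0+ exposes processing_class in the constructor signature. OldTrainer and NewTrainer are hypothetical stand-ins so the example runs without trl installed.

```python
import inspect
from typing import Any, Callable, Dict


def tokenizer_kwarg(trainer_cls: Callable[..., Any], tokenizer: Any) -> Dict[str, Any]:
    """Return the keyword argument a trainer expects for its tokenizer.

    Hypothetical helper. Assumption: trl >= 0.12.0 trainers accept
    `processing_class`, older ones accept `tokenizer`.
    """
    # inspect.signature follows __wrapped__ by default, so deprecation
    # decorators built with functools.wraps still expose the real parameters.
    params = inspect.signature(trainer_cls.__init__).parameters
    if "processing_class" in params:
        return {"processing_class": tokenizer}
    if "tokenizer" in params:
        return {"tokenizer": tokenizer}
    raise TypeError(
        f"{trainer_cls.__name__} accepts neither 'processing_class' nor 'tokenizer'"
    )


class OldTrainer:  # hypothetical stand-in for a trl < 0.12.0 trainer
    def __init__(self, model=None, tokenizer=None):
        self.tokenizer = tokenizer


class NewTrainer:  # hypothetical stand-in for a trl >= 0.12.0 trainer
    def __init__(self, model=None, processing_class=None):
        self.processing_class = processing_class


print(tokenizer_kwarg(OldTrainer, "tok"))  # {'tokenizer': 'tok'}
print(tokenizer_kwarg(NewTrainer, "tok"))  # {'processing_class': 'tok'}
```

With trl installed, the call above would pass **tokenizer_kwarg(DPOTrainer, tokenizer) in place of the explicit processing_class=tokenizer line. Inspecting the signature avoids hard-coding a version number, but it relies on any deprecation decorator on DPOTrainer.__init__ preserving the wrapped signature.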

For some reason, the tokenizer returned by unsloth's FastLanguageModel.from_pretrained does not work well as processing_class, so we need to load the original tokenizer with AutoTokenizer:

from transformers import AutoTokenizer

# For the DPO Colab notebook
tokenizer = AutoTokenizer.from_pretrained("unsloth/zephyr-sft-bnb-4bit")

# For the ORPO Colab notebook
tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-3-8b-bnb-4bit")

instead of using the tokenizer returned alongside the model by:

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_name,
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

@danielhanchen