unslothai / unsloth

Finetune Llama 3.2, Mistral, Phi, Qwen & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0

unsloth and trl weights are inconsistent #580

Closed: MachineGunLin closed this issue 5 months ago

MachineGunLin commented 5 months ago

I use this code to fine-tune Llama-3-8B-Instruct (with unsloth):

from unsloth import FastLanguageModel
import torch
from trl import SFTTrainer
from transformers import TrainingArguments
from datasets import load_dataset
from transformers import set_seed as transformers_set_seed

max_seq_length = 1024 # Supports RoPE Scaling internally, so choose any!
# Get LAION dataset
url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
dataset = load_dataset("json", data_files = {"train" : url}, split = "train")

transformers_set_seed(3407)

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "/data/Meta-Llama-3-8B-Instruct",
    max_seq_length = max_seq_length,
    dtype = None,
    load_in_4bit = False,
    attn_implementation="flash_attention_2",
)

# Do model patching and add fast LoRA weights
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    # use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    use_gradient_checkpointing = False, # True or "unsloth" for very long context
    random_state = 3407,
    max_seq_length = max_seq_length,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

trainer = SFTTrainer(
    model = model,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    tokenizer = tokenizer,
    args = TrainingArguments(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 1,
        warmup_steps = 0,
        max_steps = 2,
        save_steps = 2,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 1,
        output_dir = "outputs_unsloth",
        optim = "adamw_8bit",
        seed = 3407,
    ),
)
trainer.train()

For comparison, I used the following code to fine-tune Llama-3-8B-Instruct with trl:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig, TaskType, get_peft_model

from transformers import set_seed as transformers_set_seed

model_name = "/data/Meta-Llama-3-8B-Instruct"
max_seq_length = 1024 # Supports RoPE Scaling internally, so choose any!
# Get LAION dataset
url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
dataset = load_dataset("json", data_files = {"train" : url}, split = "train")

transformers_set_seed(3407)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map = "auto",
    attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    model_max_length = max_seq_length,
    attn_implementation="flash_attention_2",
)
tokenizer.pad_token = tokenizer.eos_token

lora_config = LoraConfig(
    r=16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)
model = get_peft_model(model, lora_config)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer = tokenizer,
    args = TrainingArguments(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 1,
        warmup_steps = 0,
        max_steps = 2,
        save_steps = 2,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 1,
        output_dir = "outputs_hf",
        optim = "adamw_8bit",
        seed = 3407,
    ),
)

trainer.train()

Then, I compared whether the weights of the two methods are consistent after two steps:

from transformers import AutoModelForCausalLM
import torch

model_hf = AutoModelForCausalLM.from_pretrained("outputs_hf/checkpoint-2")
model_unsloth = AutoModelForCausalLM.from_pretrained("outputs_unsloth/checkpoint-2")

state_dict_hf = model_hf.state_dict()
state_dict_unsloth = model_unsloth.state_dict()

total_num_of_params = len(state_dict_hf)
num_of_diff_params = 0

tolerance = 1e-3

for key in state_dict_hf:
    if key not in state_dict_unsloth:
        print("!" * 50)
        print(f"{key} not found in state_dict_unsloth")
        print("!" * 50)
        continue  # skip the comparison; indexing a missing key would raise KeyError

    param_hf = state_dict_hf[key]
    param_unsloth = state_dict_unsloth[key]

    # Count the tensor as different if any element exceeds the absolute tolerance
    if not torch.allclose(param_hf, param_unsloth, atol=tolerance):
        num_of_diff_params += 1
        print(f"Parameters {key} differ more than {tolerance}")

print(f"total_num_of_params: {total_num_of_params}")
print(f"num_of_diff_params: {num_of_diff_params}")

The results for Meta-Llama-3-8B-Instruct are as follows:

Parameters model.layers.26.mlp.up_proj.lora_B.default.weight differ more than 0.001
Parameters model.layers.29.self_attn.q_proj.lora_B.default.weight differ more than 0.001
Parameters model.layers.30.mlp.gate_proj.lora_B.default.weight differ more than 0.001
Parameters model.layers.30.mlp.up_proj.lora_B.default.weight differ more than 0.001
Parameters model.layers.31.mlp.up_proj.lora_B.default.weight differ more than 0.001
total_num_of_params: 739
num_of_diff_params: 5

After only two steps, 5 of the 739 weight tensors differ between unsloth and trl by more than 1e-3. May I ask what might be the cause?

MachineGunLin commented 5 months ago

If I set atol to 1e-4, 224 of the 739 tensors differ. I don't know whether this is within the normal margin of error.
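The counts above say how many tensors exceed a threshold, but not by how much. A minimal sketch of a follow-up check, assuming each weight tensor has been flattened to a plain list of floats (for example with `tensor.flatten().tolist()`): it reports the largest absolute and relative difference per tensor. The helper names `diff_stats` and `rank_by_difference` and the toy values are hypothetical, not part of unsloth or trl.

```python
from typing import Dict, List, Sequence, Tuple


def diff_stats(a: Sequence[float], b: Sequence[float]) -> Tuple[float, float]:
    """Return (max absolute difference, max relative difference)."""
    if len(a) != len(b):
        raise ValueError("sequences must have the same length")
    max_abs = 0.0
    max_rel = 0.0
    for x, y in zip(a, b):
        abs_diff = abs(x - y)
        max_abs = max(max_abs, abs_diff)
        scale = max(abs(x), abs(y))
        if scale > 0.0:  # avoid dividing by zero when both values are 0
            max_rel = max(max_rel, abs_diff / scale)
    return max_abs, max_rel


def rank_by_difference(
    params_a: Dict[str, Sequence[float]],
    params_b: Dict[str, Sequence[float]],
) -> List[Tuple[str, float, float]]:
    """Sort the shared parameter names by max absolute difference, largest first."""
    shared = params_a.keys() & params_b.keys()
    rows = [(name, *diff_stats(params_a[name], params_b[name])) for name in shared]
    return sorted(rows, key=lambda row: row[1], reverse=True)


# Hypothetical toy values standing in for flattened weight tensors.
hf = {"layer.0.lora_B": [0.010, -0.020, 0.030], "layer.1.lora_B": [0.500, 0.250]}
unsloth = {"layer.0.lora_B": [0.010, -0.020, 0.032], "layer.1.lora_B": [0.500, 0.250]}

for name, max_abs, max_rel in rank_by_difference(hf, unsloth):
    print(f"{name}: max_abs={max_abs:.2e} max_rel={max_rel:.2e}")
```

Looking at the relative difference alongside the absolute one helps when the weights themselves are small, since a fixed `atol` treats a 1e-3 gap the same whether the value is 0.001 or 10.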

MachineGunLin commented 5 months ago

You mentioned that one of the key features of unsloth is "no approximation methods - all exact", so the weights should be the same. I don't understand what's wrong with my code.

danielhanchen commented 5 months ago

Oh, this is entirely normal. Due to different internal upcasting and downcasting in the Triton kernels, you will see some small fractional differences. Sometimes you might see a slightly lower loss or a slightly higher loss; it all depends.

Indeed, there is no approximation and everything is exact (i.e. attention etc. is not approximated). You will see some precision differences, but that's entirely normal :)
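To make the upcasting and downcasting point concrete, here is a toy sketch in plain Python. It is not Unsloth's Triton code; it only simulates bfloat16 rounding on finite values and shows that the same exact sum gives different results depending on where the downcast happens. The helper names and the input values are made up for illustration, and the example is deliberately exaggerated.

```python
import struct


def to_bf16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 keeps the top 16 bits of a float32; the bias turns truncation
    # of the low 16 bits into round-to-nearest-even.
    bias = 0x7FFF + ((bits >> 16) & 1)
    rounded = (bits + bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", rounded))[0]


def sum_round_once(values):
    """Accumulate in high precision and downcast to bfloat16 only at the end."""
    return to_bf16(sum(values))


def sum_round_each_step(values):
    """Downcast to bfloat16 after every single addition."""
    acc = 0.0
    for v in values:
        acc = to_bf16(acc + v)
    return acc


# One large value followed by many small ones (hypothetical inputs).
values = [1.0] + [0.001] * 256

once = sum_round_once(values)
each_step = sum_round_each_step(values)
print(f"round once at the end: {once}")
print(f"round after each add:  {each_step}")
```

Both functions compute the same mathematical sum with no approximation of the algorithm itself; only the rounding points differ. Real kernels differ far less dramatically than this, but the mechanism is the same.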

I normally compare the training losses, and see if mostly they match - that's the most important part.
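A minimal sketch of that loss comparison, assuming the per-step logs are available as a list of dicts in the shape the Hugging Face `Trainer` stores in `trainer.state.log_history` (entries with a `"loss"` key for training steps). The helper names, the 1% relative tolerance, and the numbers below are illustrative assumptions, not measured results.

```python
import math


def extract_losses(log_history):
    """Collect the per-step training losses from a list of log dicts."""
    return [entry["loss"] for entry in log_history if "loss" in entry]


def losses_match(losses_a, losses_b, rel_tol=0.01):
    """True if both runs logged the same number of steps and every pair of
    losses agrees within the relative tolerance."""
    if len(losses_a) != len(losses_b):
        return False
    return all(
        math.isclose(a, b, rel_tol=rel_tol) for a, b in zip(losses_a, losses_b)
    )


# Made-up log entries in the same shape as trainer.state.log_history.
hf_log = [
    {"loss": 1.5023, "step": 1},
    {"loss": 1.4871, "step": 2},
    {"train_runtime": 12.3, "step": 2},  # summary entry without a "loss" key
]
unsloth_log = [
    {"loss": 1.5019, "step": 1},
    {"loss": 1.4880, "step": 2},
    {"train_runtime": 6.1, "step": 2},
]

print(losses_match(extract_losses(hf_log), extract_losses(unsloth_log)))
```

With real runs, the two lists would come from each trainer after `trainer.train()`; the tolerance is a judgment call and would likely need loosening for longer runs, where small rounding differences accumulate.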

x6p2n9q8a4 commented 5 months ago

> If I set atol to 1e-4, the result will be 224 / 739, I don't know if this is within the normal margin of error

Hi,

Did you find that when running the same training procedure with unsloth, for example "fine-tune Llama-3-8B-Instruct (with unsloth)", the training loss is different and the performance on the test dataset is also inconsistent? It is very strange.

danielhanchen commented 5 months ago

@x6p2n9q8a4 What do you mean by "different" training losses? Can you provide the example you ran with Unsloth and the one you ran with normal HF? I can take a look to see if there are issues.