unslothai / unsloth

Finetune Llama 3.2, Mistral, Phi, Qwen & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0
17.87k stars · 1.24k forks

Discrepancy in LLaMA 3.1 performance when using custom trainer and SFTTrainer #1025

Closed. salokr closed this issue 1 month ago.

salokr commented 1 month ago

Hi,

Thank you for making this library available.

I am using LLaMA-3.1 (the unsloth/llama-3-8b-Instruct-bnb-4bit version) to instruction-tune a model on GSM8K. First, I used the SFTTrainer to train a model and defined a custom metric via compute_metrics. Inside that metric, I pass the logits through an argmax and then call batch_decode to get the output text. Unfortunately, this produces gibberish output such as:

INPUT SENTENCE

<|im_start|>user
Janet buys a multi-flavor pack of cheese sticks. 15 of the sticks are cheddar, 30 are mozzarella, and 45 are pepperjack. If Janet picks a cheese stick at random, what is the percentage chance it will be pepperjack?
<|im_start|>assistant

OUTPUT SENTENCE

Questiontemplate agination|> <:Imy to goate his house he He He are 3 rooms,  bedroom has 2 days to paintate.  There living takes 6 hours more than the bedroom,  The living room takes  as long time as the else..  How many did the take to|The|im_end|>user
LetassistantLet

As the next step, I used the model.generate function instead, and the model started generating not only sensible but also correct answers:

INPUT SENTENCE

 ['<|im_start|>user\nJanet buys a multi-flavor pack of cheese sticks. 15 of the sticks are cheddar, 30 are mozzarella, and 45 are pepperjack. If Janet picks a cheese stick at random, what is the percentage chance it will be pepperjack?\n<|im_start|>assistant\n\n']

OUTPUT SENTENCE

["<|im_start|>user\nJanet buys a multi-flavor pack of cheese sticks. 15 of the sticks are cheddar, 30 are mozzarella, and 45 are pepperjack. If Janet picks a cheese stick at random, what is the percentage chance it will be pepperjack?\n<|im_start|>assistant\n\nLet's calculate the percentage chance it will be pepperjack.\n\nThere are a total of 15 + 30 + 45 = 90 cheese sticks.\n\nThe number of pepperjack cheese sticks is 45.\n\nTo find the percentage chance, divide the number of pepperjack cheese sticks by the total number of cheese sticks and multiply by 100:\n\n(45 / 90) x 100 = 50%\n\nSo, the percentage chance of Janet picking a pepperjack cheese stick is 50%."]

To reproduce this, I am posting the code below:

Using the custom metric

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    dataset_text_field="text",
    max_seq_length=args.max_seq_length,
    dataset_num_proc=2,
    packing=False,
    args=Seq2SeqTrainingArguments(
        fp16_full_eval=True,
        eval_accumulation_steps=4,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        learning_rate=args.learning_rate,
        lr_scheduler_type="linear",
        num_train_epochs=100,  # Use args or dynamic based on early stopping
        evaluation_strategy="steps",  # steps or epoch
        save_strategy="steps",
        eval_steps=1,
        save_steps=1,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        output_dir=f"checkpoints/{args.dataset_name}",
        logging_steps=4,
        optim="adamw_8bit",
        # report_to=["wandb"],  # Or tensorboard
        logging_dir="logs",
        weight_decay=0.01,
        warmup_steps=500,
        seed=1337,
        load_best_model_at_end=True,
        predict_with_generate=True,
        save_total_limit=5,
        include_inputs_for_metrics=True,
        metric_for_best_model="eval_accuracy",
    ),
    compute_metrics=compute_metrics_wrapper(val_dataset, tokenizer),
    callbacks=[
        EarlyStoppingCallback(early_stopping_patience=args.patience),
        # WandbCallback()  # Add more callbacks if needed
    ],
)
trainer.train()

The compute_metrics function:

def compute_metrics_wrapper(real_inputs, tokenizer, key="val"):
    def compute_metrics(eval_preds):
        print("-*" * 100)
        logs = []
        # eval_preds contains predictions, references (labels), and inputs
        # print(">>>>>", eval_preds, dir(eval_preds))
        predictions, labels, inputs = eval_preds.predictions, eval_preds.label_ids, eval_preds.inputs
        predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        print("Check lengths; Hint <--> All should be equal ---> ", len(real_inputs), len(predictions), len(labels), len(inputs))
        if isinstance(predictions, torch.Tensor):  # Ensure predictions are in tensor format
            predictions = torch.argmax(predictions, dim=-1).cpu().numpy()
        elif isinstance(predictions, np.ndarray):
            # If it's a numpy array, also apply argmax
            predictions = np.argmax(predictions, axis=-1)
        elif isinstance(predictions, list):
            # If it's a list, convert to a numpy array first
            predictions = np.array(predictions)
            predictions = np.argmax(predictions, axis=-1)
        # print(predictions[0])
        # print(">>>labels", labels[0])
        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        # print(decoded_preds[0], flush=True)
        # print("decoded_labels>>>>", decoded_labels[0], flush=True)
        correct_count = 0
        for pred, labl, real_input in tqdm(zip(decoded_preds, decoded_labels, real_inputs), total=len(decoded_labels), desc="Calculating Metrics; Hold on ..."):
            pred_num = extract_last_number_gsm8k(pred)
            labl_num = extract_last_number_gsm8k(labl)
            is_correct = int(pred_num == labl_num)
            correct_count += is_correct
            # logs.append({"input": real_input["text"], "prediction": pred, "label": labl, "is_correct": is_correct, "extracted_num_pred": pred_num, "extracted_num_labl": labl_num})
            logs.append([real_input["text"], pred, labl, is_correct, pred_num, labl_num])
            print(f"Pred: {pred}")
            print(f"Labl: {labl}")
            print(f"Real Input: {real_input}")
            print("----------------------------------------------------------------------------------------------")
        accuracy = correct_count / len(real_inputs)
        wandb.log({
            f"{key}_accuracy": accuracy,
            "predictions": wandb.Table(columns=["Input", "Prediction", "Label", "is_correct", "extracted_num_pred", "extracted_num_labl"], data=logs),
        })
        print("*-" * 100)
        return {"accuracy": accuracy}
    return compute_metrics
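A side note, not the fix for the discrepancy itself: compute_metrics otherwise receives the full (batch, seq_len, vocab) logit tensors, so a commonly used pattern is to hand the trainer a preprocess_logits_for_metrics callable (an argument of the underlying transformers Trainer) that reduces the logits to token ids on the GPU before they are gathered. A sketch, under the assumption that your trl version forwards this argument:

import torch

def preprocess_logits_for_metrics(logits, labels):
    # Reduce (batch, seq_len, vocab) logits to (batch, seq_len) token ids
    # before the trainer gathers them and pads across batches with -100.
    if isinstance(logits, tuple):  # some model outputs wrap the logits in a tuple
        logits = logits[0]
    return logits.argmax(dim=-1)

# Passed as: SFTTrainer(..., compute_metrics=..., preprocess_logits_for_metrics=preprocess_logits_for_metrics)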

But when I replace the SFTTrainer with my custom trainer, which calls model.generate, the outputs look fine:

from LLaMATrainer import LLaMATrainer
trainer = LLaMATrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    dataset_text_field="text",
    max_seq_length=args.max_seq_length,
    dataset_num_proc=2,
    packing=False,
    args=Seq2SeqTrainingArguments(
        fp16_full_eval=True,
        eval_accumulation_steps=4,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        learning_rate=args.learning_rate,
        lr_scheduler_type="linear",
        num_train_epochs=100,  # Use args or dynamic based on early stopping
        evaluation_strategy="steps",  # steps or epoch
        save_strategy="steps",
        eval_steps=1,
        save_steps=1,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        output_dir=f"checkpoints/{args.dataset_name}",
        logging_steps=4,
        optim="adamw_8bit",
        # report_to=["wandb"],  # Or tensorboard
        logging_dir="logs",
        weight_decay=0.01,
        warmup_steps=500,
        seed=1337,
        load_best_model_at_end=True,
        predict_with_generate=True,
        save_total_limit=5,
        include_inputs_for_metrics=True,
        metric_for_best_model="eval_accuracy",
    ),
    compute_metrics=compute_metrics_wrapper(val_dataset, tokenizer),
    callbacks=[
        EarlyStoppingCallback(early_stopping_patience=args.patience),
        # WandbCallback()  # Add more callbacks if needed
    ],
)
# Start Training
trainer.max_seq_length=args.max_seq_length
trainer.gt_val_data = val_dataset
trainer.gt_test_data = test_dataset
trainer.dataset_name = args.dataset_name
logger.info("Starting model training...")
trainer.train()

LLaMATrainer is defined as:


class LLaMATrainer(SFTTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # print(">>", args)
        # print(kwargs, ">>")

    def save_output(self, output_file, outputs):
        with open(output_file, "w") as f:
            json.dump(outputs, f, indent=4)

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        # print(">>>>", self.tokenizer.model_max_length)
        # print("<<<<", self.max_seq_length)
        max_seq_length = self.max_seq_length[0] if isinstance(self.max_seq_length, tuple) else self.max_seq_length
        logger.info("Running evaluation loop...")
        FastLanguageModel.for_inference(self.model)  # Enable native 2x faster inference
        # We deliberately do not call super().evaluate(eval_dataset, ignore_keys, metric_key_prefix);
        # the generation loop below replaces the default evaluation.
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        outputs = []
        all_decoder_ids = []
        picked_dataset = self.gt_val_data if metric_key_prefix == "eval" else self.gt_test_data
        # print(len(eval_dataset), len(picked_dataset))
        # print(self.gt_val_data[0])
        current_epoch = int(self.state.epoch) if self.state.epoch else "final"
        output_dir = self.args.output_dir
        output_file = os.path.join(output_dir, f"{self.dataset_name}_val_results_epoch_{current_epoch}.json")
        os.makedirs(output_dir, exist_ok=True)
        for idx, eval_set in tqdm(enumerate(eval_dataset), total=len(eval_dataset), desc=f"Hold on :D doing {metric_key_prefix} ..."):
            input_ids = torch.tensor(eval_set["input_ids"]).unsqueeze(0).to(self.model.device)
            decoder_ids = self.model.generate(input_ids, max_new_tokens=1000, pad_token_id=self.tokenizer.eos_token_id)
            decoded_output = self.tokenizer.batch_decode(decoder_ids, skip_special_tokens=True)
            input_sentence = self.tokenizer.batch_decode(input_ids, skip_special_tokens=True)
            ground_truth = picked_dataset[idx]["conversations"][-1]["value"]
            print(">>>> INPUT SENTENCE", input_sentence, "<<<")
            print(">>>> DECODED OUTPUT", decoded_output, "<<<")
            print(">>>> Ground Truth", ground_truth, "<<<")
            outputs.append({"Input": input_sentence, "Prediction": decoded_output[0], "Label": ground_truth})
            self.save_output(output_file, outputs)
            all_decoder_ids.extend(decoder_ids)
        # Add custom evaluation logic here, e.g. log custom metrics, change evaluation logic, etc.
        logger.info("Custom evaluation complete")
        logger.info(f"Saved evaluation results to {output_file}")
        eval_pred = EvalPrediction(predictions=outputs, label_ids=all_decoder_ids)
        output_metrics = self.compute_metrics(eval_pred)
        for key in list(output_metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                output_metrics[f"{metric_key_prefix}_{key}"] = output_metrics.pop(key)
        self.log(output_metrics)
        return output_metrics

> Other important functions:

**map_dataset_to_template**

def map_dataset_to_template(dataset, tokenizer, key="conversations"):
    def apply_template(examples, tokenizer, key):
        messages = examples[key]
        text = [tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=False) for message in messages]
        return {"text": text}

    # Log the number of examples processed
    logger.info(f"Original dataset size: {len(dataset)}")
    processed_dataset = dataset.map(partial(apply_template, tokenizer=tokenizer, key=key), batched=True)
    logger.info(f"Processed dataset size: {len(processed_dataset)}")
    return processed_dataset


I agree that save_steps and eval_steps = 2 might not be enough to get decent output, but the parameters are the same in both cases, and the generate function still produces sensible output.
danielhanchen commented 1 month ago

I'm assuming a FastLanguageModel.for_inference(model) call was maybe missing?
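For reference, in setups that evaluate mid-training, the unsloth inference switch is usually paired with a switch back to training mode afterwards. A minimal sketch of that pattern, loosely reusing the names from the evaluate() override above (the for_training call and generation arguments are assumptions, not a confirmed fix for this issue):

FastLanguageModel.for_inference(model)   # switch to faster, inference-only behaviour before generating
generated = model.generate(input_ids, max_new_tokens=1000, pad_token_id=tokenizer.eos_token_id)
FastLanguageModel.for_training(model)    # restore training behaviour before the next training step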

salokr commented 1 month ago

Hi, thank you for responding. I added the call, and there was still no difference.

I have tried to write a minimal script to reproduce the error (since I had to provide a custom Trainer and a custom evaluator, the code is still fairly large; sorry for that).

The code is available at: https://drive.google.com/file/d/1mkNf8zWSr82KIpcP-VaD75z5rM8fsJ16/view?usp=sharing

If you want to skip over the notebook, please look only at LLaMATrainer and compute_metrics_wrapper. The rest of the code is the same, with no major alterations (the same input/output template).

I also printed the logits; because of the large output, I have kept them hidden. Click on "Show Output" to take a look at them.

salokr commented 1 month ago

I minimized the code even more. Here is the minimal script (a screenshot is also attached):

from unsloth import FastLanguageModel
import torch
print(torch.cuda.device_count())
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Meta-Llama-3.1-8B",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

model = FastLanguageModel.get_peft_model(
    model,
    r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

from datasets import load_dataset
dataset_gsm8k = load_dataset("gsm8k", "main")

sample = dataset_gsm8k["train"][0]
# print(sample)
encoding = tokenizer(sample["question"], return_tensors="pt")
# print(encoding["input_ids"][0].shape)
logits = model(encoding["input_ids"]).logits
import torch
probabilities = torch.nn.functional.softmax(logits, dim=-1)
predictions = torch.argmax(probabilities, dim=-1)
print("Predictions:", predictions)
print("Decoded predictions:", tokenizer.batch_decode(predictions))

FastLanguageModel.for_inference(model)
op_gen = model.generate(encoding["input_ids"], return_dict_in_generate=True, output_scores=True)
tokenizer.batch_decode(op_gen["sequences"])


salokr commented 1 month ago

NVM. I believe I was only running a single forward pass, without including the BOS token.
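A quick sanity check for anyone hitting the same thing (a sketch reusing tokenizer and sample from the minimal script above): confirm that the encoded prompt actually begins with the tokenizer's BOS id before running the forward pass.

ids = tokenizer(sample["question"], return_tensors="pt")["input_ids"][0]
print("BOS token:", tokenizer.bos_token, tokenizer.bos_token_id)
print("First ids:", ids[:5].tolist(), "| starts with BOS:", ids[0].item() == tokenizer.bos_token_id)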

diazr04 commented 1 month ago

Hello, I am having the same issue as you; I do not know how you solved it.