ROCm / bitsandbytes

8-bit CUDA functions for PyTorch
MIT License
34 stars 3 forks source link

Exception: cublasLt ran into an error during fine-tuning LLM in 8bit mode on AMD MI300x GPUs #40

Open jerin-scalers-ai opened 3 months ago

jerin-scalers-ai commented 3 months ago

System Info

Reproduction

Multi-GPU fine-tuning with INT8 model precision fails on AMD MI300X GPUs. Fine-tuning with INT8 does work on a single GPU, though.

Issue: Exception: cublasLt ran into an error!

import time

import torch
from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_config,
    get_peft_model_state_dict,
    prepare_model_for_kbit_training,
)
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoTokenizer, pipeline, TrainingArguments, Trainer
from transformers import TrainerCallback

# Load the base model. The dtype and quantization flags are the knobs that
# are switched between runs (see the inline comments for the alternatives).
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    device_map="auto",
    torch_dtype=torch.float32,  # Values: torch.bfloat16, torch.float16
    load_in_8bit=False,  # True for INT8
    load_in_4bit=False,  # True for FP4
)
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf"
)
model = prepare_model_for_kbit_training(model)
# get_peft_model() needs a LoraConfig *instance*; passing the class itself
# does not work. The LoRA hyperparameters used in the original run are not
# shown in this snippet, so library defaults are used here.
peft_model = get_peft_model(model, LoraConfig(task_type="CAUSAL_LM"))
# get_dataset() is a helper that is not part of this snippet.
train_data, val_data = get_dataset()

training_arguments = TrainingArguments(
    # per_gpu_train_batch_size is deprecated in favour of the per-device name.
    per_device_train_batch_size=16,  # value: Batch size
    warmup_steps=10,
    max_steps=200,
    fp16=False,  # True for FP16,FP4,INT8
    bf16=False,  # True for BF16
)

class MetricsClass(TrainerCallback):
    """Trainer callback that measures training throughput.

    Wall-clock time is sampled at the start and end of every training step
    and accumulated, so that ``get_metrics`` can report the average number
    of steps completed per second.
    """

    def __init__(self):
        super().__init__()
        # Initialised here as well so get_metrics() is safe before training.
        self.step_count = 0
        self.total_step_time = 0.0

    def on_train_begin(self, args, state, control, **kwargs):
        """Reset the counters at the start of a training run."""
        self.step_count = 0
        self.total_step_time = 0.0

    def on_step_begin(self, args, state, control, **kwargs):
        """Record the wall-clock time at which the current step starts."""
        self.start_time = time.time()

    def on_step_end(self, args, state, control, **kwargs):
        """Add the duration of the finished step to the running totals."""
        self.end_time = time.time()
        # Accumulate (+=): a plain assignment would keep only the last step's
        # duration while step_count keeps growing, inflating the throughput.
        self.total_step_time += self.end_time - self.start_time
        self.step_count += 1

    def get_metrics(self):
        """Return the average throughput in steps per second.

        Returns:
            ``step_count / total_step_time``, or ``0.0`` when no step time
            has been recorded yet (avoids a division by zero).
        """
        if self.total_step_time <= 0:
            return 0.0
        return self.step_count / self.total_step_time

# Keep a handle on the callback so its metric can be read after training.
metcl = MetricsClass()

trainer = Trainer(
    model=peft_model,
    train_dataset=train_data,
    eval_dataset=val_data,
    args=training_arguments,
    callbacks=[metcl],
)

trainer.train()
# Print the metric; a bare call would compute the value and discard it.
print("steps per second:", metcl.get_metrics())

Expected behavior

No error.