unslothai / unsloth

Finetune Llama 3.1, Mistral, Phi & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0

RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16 #212

Open hcsolakoglu opened 6 months ago

hcsolakoglu commented 6 months ago

In the Gemma 7b notebook, with rsLoRA and DoRA enabled, 4-bit and 8-bit loading turned off, r=8 and alpha=16, I encounter the error below. I targeted all linear layers with LoRA and trained on data in the Alpaca format. The issue does not occur every time, and I am not sure what triggers it, but the main cause seems to be enabling DoRA. Training appears to run fine; the error only shows up during inference through Unsloth. When I switch to another inference engine, the error disappears.
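For reference, the r=8, alpha=16 setting above scales the adapter update differently depending on whether rsLoRA is enabled. A minimal sketch, assuming PEFT's convention of alpha / r for standard LoRA and alpha / sqrt(r) for rsLoRA; `lora_scaling` is a hypothetical helper, not a PEFT or Unsloth function:

```python
import math


def lora_scaling(lora_alpha: float, r: int, use_rslora: bool) -> float:
    """Hypothetical helper: scale factor applied to the LoRA update B @ A."""
    # Standard LoRA divides alpha by the rank; rsLoRA divides by sqrt(rank).
    return lora_alpha / math.sqrt(r) if use_rslora else lora_alpha / r


print(lora_scaling(16, 8, use_rslora=False))  # → 2.0
print(lora_scaling(16, 8, use_rslora=True))   # 16 / sqrt(8), roughly 5.66
```

So with rsLoRA on, this configuration applies a noticeably larger multiplier to the adapter output than plain LoRA would.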


RuntimeError                              Traceback (most recent call last)

in ()
     12 from transformers import TextStreamer
     13 text_streamer = TextStreamer(tokenizer)
---> 14 _ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 512)

20 frames

/usr/local/lib/python3.10/dist-packages/peft/tuners/lora/layer.py in _apply_dora(self, x, lora_A, lora_B, scaling, active_adapter)
    214         mag_norm_scale = (magnitude / weight_norm).view(1, -1)
    215         result_dora = (mag_norm_scale - 1) * (
--> 216             F.linear(x, transpose(weight, self.fan_in_fan_out))
    217         ) + mag_norm_scale * lora_B(lora_A(x)) * scaling
    218 

RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16
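The failing call is a plain `F.linear` whose input is float32 while the base weight is bfloat16. A minimal CPU sketch, assuming a recent PyTorch with bfloat16 CPU matmul and CPU autocast support, that reproduces the same class of error and shows two ways to align the dtypes (the tensor shapes are arbitrary):

```python
import torch
import torch.nn.functional as F

# Frozen base weight stored in bfloat16, as in a bf16-loaded model.
weight = torch.randn(4, 8, dtype=torch.bfloat16)
# Activations left in float32 (e.g. no mixed-precision context active).
x = torch.randn(2, 8, dtype=torch.float32)

try:
    F.linear(x, weight)
    mismatch_raised = False
except RuntimeError as err:
    # Same family of error as in the traceback: mat1/mat2 dtypes differ.
    mismatch_raised = True
    print("Error:", err)

# Option 1: explicitly cast the input to the weight's dtype.
out_cast = F.linear(x.to(weight.dtype), weight)

# Option 2: run under autocast, which casts both operands of eligible
# ops such as linear to the autocast dtype.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out_autocast = F.linear(x, weight)

print(out_cast.dtype, out_autocast.dtype)
```

Option 2 is roughly what enabling mixed precision in a trainer does around the forward pass, which is why setting a precision flag can make this class of error go away.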
danielhanchen commented 6 months ago

@hcsolakoglu Yep, it's most likely DoRA. Unsloth doesn't technically support it yet, so inference is probably breaking.
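For context, the failing line in the traceback belongs to the extra DoRA term that PEFT adds on top of the base layer output. A simplified, hypothetical re-creation with plain tensors (assuming `fan_in_fan_out=False`, no dropout, and a row-wise norm of the merged weight as in PEFT's Linear layers) shows where the dtypes must agree. Here the base weight is upcast to the activation dtype, which is one possible choice and not necessarily what PEFT or Unsloth does:

```python
import torch
import torch.nn.functional as F


def dora_extra_term(x, weight, magnitude, lora_A, lora_B, scaling):
    """Hypothetical re-creation of the DoRA term from the traceback.

    x: (batch, in_features) activations
    weight: (out_features, in_features) frozen base weight, possibly bfloat16
    magnitude: (out_features,) learned DoRA magnitude vector
    lora_A: (r, in_features), lora_B: (out_features, r)
    """
    # Align the base weight with the activation dtype so the F.linear
    # call below never receives mismatched operands.
    weight = weight.to(x.dtype)
    lora_weight = lora_B @ lora_A
    # Row-wise L2 norm of the merged weight W + scaling * B @ A.
    weight_norm = torch.linalg.norm(weight + scaling * lora_weight, dim=1)
    mag_norm_scale = (magnitude / weight_norm).view(1, -1)
    base = F.linear(x, weight)
    lora = F.linear(F.linear(x, lora_A), lora_B)
    return (mag_norm_scale - 1) * base + mag_norm_scale * lora * scaling


torch.manual_seed(0)
W = torch.randn(4, 8, dtype=torch.bfloat16)
x = torch.randn(2, 8)
A = torch.randn(2, 8)
B = torch.zeros(4, 2)                   # LoRA B starts at zero
m = torch.linalg.norm(W.float(), dim=1)  # magnitude initialised to ||W||
out = dora_extra_term(x, W, m, A, B, scaling=2.0)
```

At initialisation (B = 0, magnitude = ||W||) the extra term is zero, so the adapted layer starts out identical to the base layer.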

hcsolakoglu commented 6 months ago


Training with DoRA is compatible with Unsloth, correct? I assume only the inference phase has this problem?

danielhanchen commented 6 months ago

@hcsolakoglu Hmm, yes, it's compatible for training; it's just the inference phase that breaks. I'll probably add a check.
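A check like this could start by detecting DoRA adapters before generation. A hypothetical sketch, assuming PEFT registers DoRA magnitudes under a submodule or parameter whose name contains `lora_magnitude_vector` (the case in recent PEFT releases, but verify against your installed version); `has_dora_adapters` and `ToyDoraLayer` are illustrative names, not Unsloth or PEFT APIs:

```python
import torch
import torch.nn as nn


def has_dora_adapters(model: nn.Module) -> bool:
    """Hypothetical check: does any parameter look like a DoRA magnitude?"""
    # Assumption: PEFT names DoRA magnitude parameters "...lora_magnitude_vector...".
    return any("lora_magnitude_vector" in name for name, _ in model.named_parameters())


class ToyDoraLayer(nn.Module):
    """Hypothetical stand-in for a PEFT LoRA layer with DoRA enabled."""

    def __init__(self):
        super().__init__()
        self.base_layer = nn.Linear(8, 4)
        # Mimics PEFT's per-adapter magnitude storage for adapter "default".
        self.lora_magnitude_vector = nn.ParameterDict(
            {"default": nn.Parameter(torch.ones(4))}
        )


print(has_dora_adapters(ToyDoraLayer()), has_dora_adapters(nn.Linear(8, 4)))
```

A real check would then either warn or fall back to a non-Unsloth generation path when this returns True.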

lakshmid13579 commented 3 months ago

Same issue with DPO:

from unsloth import FastLanguageModel
import torch
from trl import DPOConfig, DPOTrainer

max_seq_length = 1024 # Supports automatic RoPE Scaling, so choose any number.
compute_dtype = getattr(torch, "float16")

# Load model
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/mistral-7b-v0.3",
    torch_dtype = compute_dtype,
    max_seq_length = max_seq_length,
    device_map = "auto",
    dtype = None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
    load_in_4bit = True, # Use 4bit quantization to reduce memory usage. Can be False.
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

# Do model patching and add fast LoRA weights
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Dropout = 0 is currently optimized
    bias = "none",    # Bias = "none" is currently optimized
    use_gradient_checkpointing = True,
    random_state = 3407,
)

training_args = DPOConfig(
    output_dir = "/opt/data/enrichment/StructuredInformationExtraction/models/dpo/",
    beta = 0.1, run_name = "dpo",
    max_length = 1024, max_prompt_length = 512,
    per_device_train_batch_size = 4, gradient_accumulation_steps = 1,
    gradient_checkpointing = True, learning_rate = 6e-6,
    num_train_epochs = 1, report_to = "wandb",
)

dpo_trainer = DPOTrainer(
    model, ref_model = None, args = training_args,
    train_dataset = train_ds, tokenizer = tokenizer, # train_ds: preference dataset defined elsewhere
)
dpo_trainer.train()

danielhanchen commented 3 months ago

Wait, you forgot to set fp16 = True or bf16 = True in the trainer arguments.
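One way to follow this advice without hard-coding the choice is to pick the flag from the hardware. A minimal sketch, assuming bf16 when the GPU supports it and fp16 otherwise (matching the notebook comment "Float16 for Tesla T4, V100, Bfloat16 for Ampere+"); `precision_flags` is a hypothetical helper whose result would be unpacked into `DPOConfig(..., **flags)` or `SFTConfig(..., **flags)`, both of which accept `fp16` and `bf16`:

```python
import torch


def precision_flags(bf16_supported: bool) -> dict:
    """Hypothetical helper: enable exactly one of fp16 / bf16 for the trainer."""
    return {"fp16": not bf16_supported, "bf16": bf16_supported}


# Short-circuit so is_bf16_supported() is only queried when CUDA is present.
bf16_ok = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
flags = precision_flags(bf16_ok)
print(flags)
```

Keeping the decision in one place also avoids enabling both flags at once, which the Hugging Face trainers reject.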

brthor commented 2 months ago

I had a similar error training CodeLlama, and adding bf16=True to SFTConfig (I am using SFTTrainer) solved it.

A check for this would be helpful. The exact same code ran fine before I swapped a PeftModel for the Unsloth equivalent.
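Such a check could run before training starts. A hypothetical sketch, assuming the failure mode discussed here (half-precision weights with neither `fp16` nor `bf16` enabled); `check_mixed_precision` is not an Unsloth or TRL function, and `SimpleNamespace` stands in for the real `SFTConfig` / `DPOConfig` objects, which expose `fp16` and `bf16` attributes:

```python
from types import SimpleNamespace

import torch
import torch.nn as nn


def check_mixed_precision(model: nn.Module, args) -> None:
    """Hypothetical pre-training check: half-precision weights need fp16 or bf16 set."""
    half_dtypes = {
        p.dtype for p in model.parameters()
        if p.dtype in (torch.float16, torch.bfloat16)
    }
    if half_dtypes and not (getattr(args, "fp16", False) or getattr(args, "bf16", False)):
        names = ", ".join(sorted(str(d) for d in half_dtypes))
        raise ValueError(
            f"Model has {names} weights but neither fp16 nor bf16 "
            "is enabled in the training arguments."
        )


model = nn.Linear(8, 4).to(torch.bfloat16)
check_mixed_precision(model, SimpleNamespace(fp16=False, bf16=True))  # passes
```

Failing fast with a readable message is the point here: the same misconfiguration otherwise surfaces deep inside a matmul as the dtype error from this issue.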

danielhanchen commented 2 months ago

Apologies for the delay! Yes, an automatic check is what I was hoping for. I'll see what I can do.