antonioalegria opened this issue 6 months ago
Hey, this seems related to this issue. Also, make sure the model is in fp16 or bfloat16.
The training args have fp16=True. After training, I also tried converting the model, and every module inside it, to torch.float16. It doesn't make a difference: I get the same error. I haven't found any workaround that works.
I have the same issue (generating with FA2). I checked the inputs to the FA2 module and they are all fp16. In my case, using phi-2, I tracked the issue down to the flash attention code in the modeling code:
```python
# before >> q: torch.bfloat16, k: torch.bfloat16, v: torch.bfloat16
if past_key_value is not None:
    cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# after >> q: torch.bfloat16, k: torch.float32, v: torch.float32
```
This past_key_value.update call somehow upcasts the KVs to fp32.
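The following minimal sketch shows why one fp32 entry in the cache is enough to break FA2. It uses a hypothetical helper, concat_kv, as a stand-in for a cache update; this is not the DynamicCache implementation. Promoting bf16 together with fp32 yields fp32, so the concatenated keys come back as fp32 while the query stays bf16. The assumption is that the cached keys were already upcast somewhere upstream.

```python
import torch


def concat_kv(past: torch.Tensor, new: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for a cache update: append `new` to `past` along the sequence axis."""
    # Promote both tensors to a common dtype (bfloat16 + float32 -> float32),
    # then concatenate along the sequence dimension of (batch, heads, seq, head_dim).
    common = torch.promote_types(past.dtype, new.dtype)
    return torch.cat([past.to(common), new.to(common)], dim=-2)


batch, heads, head_dim = 1, 2, 8
query = torch.randn(batch, heads, 1, head_dim, dtype=torch.bfloat16)
new_key = torch.randn(batch, heads, 1, head_dim, dtype=torch.bfloat16)

# Assumption: the cached keys were upcast to fp32 somewhere upstream.
past_key = torch.randn(batch, heads, 4, head_dim, dtype=torch.float32)

key = concat_kv(past_key, new_key)
print(query.dtype, key.dtype, tuple(key.shape))
```

If the cached keys are kept in bf16, the same helper returns bf16 keys. The mismatch therefore comes entirely from the dtype of what was stored, not from the concatenation itself.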
Hey @ArthurZucker!
I'm facing the same issue! It seems to come from using DynamicCache together with autocast: somewhere along the way, the past key values stored in the Cache get upcast. It's probably happening where @imirzadeh identified above.
I don't see a clear way of dealing with this, as I'm not an autocast expert, but an ugly fix could be to add key_values.dtype == torch.float32 to the condition here:
https://github.com/huggingface/transformers/blob/60dea593edd0b94ee15dc3917900b26e3acfbbee/src/transformers/models/mistral/modeling_mistral.py#L431-L449
Let me know what you think and whether I can help further.
cc @gante and @zucchini-nlp, this is something we should pay attention to!
@antonioalegria In this code snippet the model is loaded in float32, and I believe running trainer.train() does not change the model itself to the fp16 dtype. I can confirm that your script fails for me, but the script below, which loads the model in fp16, works fine:
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

checkpoint = "trainer_ckpt_path"
model_kwargs = {
    "attn_implementation": "flash_attention_2",
    "torch_dtype": torch.float16,  # need to indicate fp16 here to work with FA2
}
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, **model_kwargs, device_map="cuda")

test_input = "I really liked this movie because"
test_input = tokenizer(test_input, return_tensors="pt").to(model.device)
out = model.generate(test_input.input_ids, max_length=100, num_return_sequences=1)
print(tokenizer.batch_decode(out))
```
@antonioalegria, as @zucchini-nlp wrote :) The fp16=True flag turns on autocast for training, so it doesn't crash at train time. However, at inference time there is no autocast, so the model is still seen as FP32. FA2 does not support FP32, and thus it crashes.
You can also train in FP32 (with autocast) and then manually cast the model to FP16, or reload the trained model in FP16.
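The following minimal sketch illustrates the autocast point on CPU, using a toy nn.Linear as a stand-in for the model. This is an assumption for illustration: it uses CPU autocast with bfloat16 so it runs without a GPU, whereas the thread is about fp16/bf16 on CUDA. Under autocast, the fp32 layer produces low-precision outputs. Without autocast, it produces fp32 outputs. Casting the parameters makes low precision the default with no autocast needed.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(16, 16)  # toy stand-in for a model whose weights stay in fp32
x = torch.randn(2, 16)

# Under autocast (what fp16=True / bf16=True enable during training), linear layers
# compute in the lower-precision dtype even though the weights are fp32.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y_autocast = model(x)

# Outside autocast (e.g. a plain generate() call), the same layer computes in fp32.
y_plain = model(x)

# Casting the parameters themselves makes low precision the default without autocast.
model = model.to(torch.bfloat16)
y_cast = model(x.to(torch.bfloat16))

print(y_autocast.dtype, y_plain.dtype, y_cast.dtype)
```

For a transformers model, the equivalent options are calling model.to(torch.float16) after training, or reloading the checkpoint with torch_dtype=torch.float16 as in the script above.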
I got the same problem when running model inference with FlashAttention-2. Below is code to reproduce the error, @zucchini-nlp @gante @ArthurZucker:
```python
import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

checkpoint = "codellama/CodeLlama-7b-hf"
model_kwargs = {
    "attn_implementation": "flash_attention_2",
    "torch_dtype": torch.bfloat16,
}
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(checkpoint, **model_kwargs, device_map="cuda")

# Create a small dataset with 5 sentences
texts = [
    "The quick brown fox jumps over the lazy dog.",
    "I love eating pizza and watching movies.",
    "The sun is shining brightly today.",
    "She plays the piano beautifully.",
    "He enjoys reading books in his free time."
]
encodings = tokenizer(texts, padding=True, return_tensors="pt")

# Create the dataset
dataset = Dataset.from_dict({
    "input_ids": encodings.input_ids,
    "attention_mask": encodings.attention_mask,
    "labels": encodings.input_ids
})

# Define the training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="output",
    bf16=True,  # adding this leads to the error
    per_device_eval_batch_size=1,
    predict_with_generate=True,
    generation_max_length=50,
    generation_num_beams=1,
)

# Create the Trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
)

# Generating directly with the model works
test_input = "I really liked this movie because"
test_input = tokenizer(test_input, return_tensors="pt").to(model.device)
out = model.generate(test_input.input_ids, max_length=100, num_return_sequences=1)
print(tokenizer.batch_decode(out))  # Okay

# Run prediction on the dataset
predictions = trainer.predict(dataset)  # Error
print(tokenizer.batch_decode(predictions.predictions, skip_special_tokens=True))
```
If I set bf16=True in the training arguments, trainer.predict crashes, but model.generate works fine!
```
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
RuntimeError: query and key must have the same dtype
```
It works fine if I remove bf16 from the arguments. The problem is that I need bf16; otherwise, inference is very slow.
@edchengg I was able to localize the error. When using Trainer, there is a line that prepares the model with accelerate, which in turn adds something like model.forward = convert_outputs_to_fp32(new_forward), casting all model outputs to fp32. That is why the past_key_values cache was cast to fp32, causing errors when generating with the KV cache.
@gante, I think the issue will be resolved once we get rid of the legacy cache format, so I will leave it to you here :) UPD: Or maybe not. I just found this issue, which causes an error if a cache class is used in the prediction loop.
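The following minimal sketch shows the effect described above. It uses a hypothetical wrapper, upcast_outputs_to_fp32, that mimics the behavior; this is not accelerate's convert_outputs_to_fp32. The wrapper recursively casts every floating-point tensor in a forward pass's output to fp32. Because a legacy cache is just nested tuples of tensors, it gets upcast along with the logits. The toy forward function is also an assumption made for illustration.

```python
import torch


def upcast_outputs_to_fp32(fn):
    """Hypothetical wrapper: cast every floating-point tensor in the (possibly nested)
    output of `fn` to float32. Mimics the described effect; not accelerate's code."""
    def cast(obj):
        if isinstance(obj, torch.Tensor) and obj.is_floating_point():
            return obj.float()
        if isinstance(obj, (tuple, list)):
            return type(obj)(cast(o) for o in obj)
        if isinstance(obj, dict):
            return {k: cast(v) for k, v in obj.items()}
        return obj  # anything else (ints, custom objects) passes through unchanged

    def wrapper(*args, **kwargs):
        return cast(fn(*args, **kwargs))

    return wrapper


def toy_forward(x):
    # Hypothetical forward: bf16 logits plus a legacy-style cache,
    # i.e. a tuple holding one (key, value) pair per layer.
    k = torch.zeros(1, 2, 3, 4, dtype=torch.bfloat16)
    v = torch.zeros(1, 2, 3, 4, dtype=torch.bfloat16)
    return {"logits": x.to(torch.bfloat16), "past_key_values": ((k, v),)}


wrapped = upcast_outputs_to_fp32(toy_forward)
out = wrapped(torch.randn(1, 5))
print(out["logits"].dtype, out["past_key_values"][0][0].dtype)
```

In this toy, an object that is not a tensor, tuple, list, or dict (for example, a custom cache class) would pass through untouched. That is consistent with the later observation in this thread that avoiding the legacy tuple format avoids the upcast.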
Same here: I got it running with fp16, but bf16 raises the error on an A100 cluster with the Idefics2 model!
Good luck with the fix!
I'm also running into this error, but only when quantizing the model. I made an edit much like the one proposed here by @ylacombe,
but that simply ensures that all three states have the same dtype:
```python
input_dtype = query_states.dtype
if input_dtype == torch.float32 or not (query_states.dtype == key_states.dtype == value_states.dtype):
    if torch.is_autocast_enabled():
        target_dtype = torch.get_autocast_gpu_dtype()
    # Handle the case where the model is quantized
    elif hasattr(self.config, "_pre_quantization_dtype"):
        target_dtype = self.config._pre_quantization_dtype
    else:
        target_dtype = self.q_proj.weight.dtype

    logger.warning_once(
        f"The input hidden states seems to be silently casted in float32, this might be related to"
        f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
        f" {target_dtype}."
    )

    query_states = query_states.to(target_dtype)
    key_states = key_states.to(target_dtype)
    value_states = value_states.to(target_dtype)
```
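The following is a standalone sketch of the same guard that can be tested outside the attention module. The function name align_qkv_dtypes is hypothetical. In this sketch, target_dtype is passed in explicitly instead of being derived from autocast, the pre-quantization dtype, or q_proj's weight as in the snippet above.

```python
import torch


def align_qkv_dtypes(query, key, value, target_dtype):
    """Hypothetical free-function version of the guard: if the query is fp32 or the three
    tensors disagree on dtype, cast all of them to `target_dtype`; otherwise leave them."""
    if query.dtype == torch.float32 or not (query.dtype == key.dtype == value.dtype):
        return query.to(target_dtype), key.to(target_dtype), value.to(target_dtype)
    return query, key, value


# A bf16 query with fp32 keys/values, as produced by an upcast cache.
q = torch.randn(1, 2, 1, 8, dtype=torch.bfloat16)
k = torch.randn(1, 2, 5, 8, dtype=torch.float32)
v = torch.randn(1, 2, 5, 8, dtype=torch.float32)

q2, k2, v2 = align_qkv_dtypes(q, k, v, target_dtype=torch.bfloat16)
print(q2.dtype, k2.dtype, v2.dtype)
```

The mismatch branch is what the proposed extra condition adds. Without it, a bf16 query would skip the cast and the fp32 keys would reach FA2 unchanged.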
@zucchini-nlp, are we just waiting for the deprecation of the to_legacy_cache function?
@ArthurZucker, sorry, I totally forgot to reply here. I was thinking we might need to pass a "compute_dtype" to all cache classes, but then I found that #30536 should actually solve the problem.
In the linked PR, we never convert back and forth between the legacy and new cache formats, so there is no incorrect casting to fp32. I will verify once the PR is merged and then close this issue :)
Update: the linked PR was merged, and I verified that the dtype mismatch error no longer appears when doing inference with Trainer. I used one of the scripts above for testing. All new LLMs can now be used for inference with flash-attn in Trainer!
However, this only works for models that support the new cache format, which excludes some relatively older models like OPT or GPT2. I will keep the issue open and extend the new cache format to other decoder-only models. cc @gante, not sure if there is a tracker for making models compatible with DynamicCache 🤔
This is the script used for verifying that it works:
```python
import torch
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

checkpoint = "meta-llama/Llama-2-7b-chat-hf"  # facebook/opt-125M
model_kwargs = {
    "attn_implementation": "flash_attention_2",
    "torch_dtype": torch.bfloat16,
}
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(checkpoint, **model_kwargs, device_map="cuda")

# Create a small dataset with 5 sentences
texts = [
    "The quick brown fox jumps over the lazy dog.",
    "I love eating pizza and watching movies.",
    "The sun is shining brightly today.",
    "She plays the piano beautifully.",
    "He enjoys reading books in his free time."
]
encodings = tokenizer(texts, padding=True, return_tensors="pt")

# Create the dataset
dataset = Dataset.from_dict({
    "input_ids": encodings.input_ids,
    "attention_mask": encodings.attention_mask,
    "labels": encodings.input_ids
})

# Define the training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="output",
    bf16=True,  # adding this used to trigger the error
    per_device_eval_batch_size=1,
    predict_with_generate=True,
    generation_max_length=50,
    generation_num_beams=1,
)

# Create the Trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
)

# Generating directly with the model works
test_input = "I really liked this movie because"
test_input = tokenizer(test_input, return_tensors="pt").to(model.device)
out = model.generate(test_input.input_ids, max_length=100, num_return_sequences=1)
print(tokenizer.batch_decode(out))  # Okay

# Run prediction on the dataset
predictions = trainer.predict(dataset)  # Error (NOT ANYMORE)
print(tokenizer.batch_decode(predictions.predictions, skip_special_tokens=True))
```
System Info
transformers version: 4.38.2
Who can help?
@gante @muellerz @pacman100
Reproduction
Here is the exception:
I don't think this is an SFTTrainer problem, which is why I didn't report it there.
Expected behavior
It should generate without any problems. Otherwise, it should give an error or indication of what needs to be done so that the model can be used for generation.
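As an illustration of the "indication of what needs to be done" requested above, here is a hypothetical pre-flight check, check_flash_attn_dtype. It is not part of transformers. It raises a readable error before generation if any floating-point parameter is outside the dtypes FA2 accepts. The check is deliberately strict: models that intentionally keep some layers (e.g. layer norms) in fp32 would need a looser rule.

```python
import torch
from torch import nn

# Assumption: the half-precision dtypes FlashAttention-2 accepts.
FA2_DTYPES = (torch.float16, torch.bfloat16)


def check_flash_attn_dtype(model: nn.Module) -> None:
    """Hypothetical check: raise if any floating-point parameter is not fp16/bf16."""
    bad = {
        p.dtype
        for p in model.parameters()
        if p.is_floating_point() and p.dtype not in FA2_DTYPES
    }
    if bad:
        raise ValueError(
            f"FlashAttention-2 needs fp16 or bf16 weights, found {sorted(str(d) for d in bad)}; "
            "cast the model (e.g. model.to(torch.bfloat16)) or reload it with a half-precision torch_dtype."
        )


# Usage on a toy module standing in for a model loaded in fp32.
fp32_model = nn.Linear(4, 4)
try:
    check_flash_attn_dtype(fp32_model)
except ValueError as err:
    print(err)

check_flash_attn_dtype(fp32_model.to(torch.bfloat16))  # passes silently after the cast
```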