huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0
133.67k stars 26.7k forks source link

Model trained with Flash Attention 2.0 raises "RuntimeError: query and key must have the same dtype" when generating #30019

Open antonioalegria opened 6 months ago

antonioalegria commented 6 months ago

System Info

Who can help?

@gante @muellerz @pacman100

Information

Tasks

Reproduction

# Reproduction script: fine-tune OPT-125M with Flash Attention 2 and fp16=True
# via SFTTrainer, then call generate() directly on the trained model.
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from datasets import load_dataset
from trl import SFTTrainer

dataset = load_dataset("imdb")

checkpoint = "facebook/opt-125M"
# Only the attention implementation is requested here; no torch_dtype is
# passed to from_pretrained() below.
model_kwargs = {
    "attn_implementation": "flash_attention_2",
}

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, **model_kwargs)

def tokenize_function(examples):
    """Tokenize the "text" column, padding/truncating each example to 512 tokens."""
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=512)

# batched=True: tokenize_function receives batches of examples per call.
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# fp16=True is the only precision-related setting in this script.
training_args = TrainingArguments(
    output_dir="./test_trainer",
    evaluation_strategy="epoch",
    report_to="none",
    fp16=True,
    save_strategy="epoch",
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    dataset_text_field="text",
)

trainer.train() # works
trainer.evaluate() # works

# generate() is called on the same `model` object, outside of the trainer;
# this is the call that raises the RuntimeError shown in the traceback.
test_input = "I really liked this movie because"
test_input = tokenizer(test_input, return_tensors="pt").to("cuda")
model.generate(test_input.input_ids, max_length=100, num_return_sequences=1) # Blows up!

Here is the exception:

Traceback (most recent call last):
  File "/workspace/storage/hyperml/notebooks/test_trainer/x.py", line 43, in <module>
    model.generate(test_input.input_ids, max_length=100, num_return_sequences=1) # Blows up!
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/storage/env/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/storage/env/lib/python3.11/site-packages/transformers/generation/utils.py", line 1544, in generate
    return self.greedy_search(
           ^^^^^^^^^^^^^^^^^^^
  File "/workspace/storage/env/lib/python3.11/site-packages/transformers/generation/utils.py", line 2404, in greedy_search
    outputs = self(
              ^^^^^
  File "/workspace/storage/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/storage/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/storage/env/lib/python3.11/site-packages/accelerate/utils/operations.py", line 822, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/storage/env/lib/python3.11/site-packages/accelerate/utils/operations.py", line 810, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/storage/env/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/storage/env/lib/python3.11/site-packages/transformers/models/opt/modeling_opt.py", line 1145, in forward
    outputs = self.model.decoder(
              ^^^^^^^^^^^^^^^^^^^
  File "/workspace/storage/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/storage/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/storage/env/lib/python3.11/site-packages/transformers/models/opt/modeling_opt.py", line 911, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/workspace/storage/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/storage/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/storage/env/lib/python3.11/site-packages/transformers/models/opt/modeling_opt.py", line 552, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                          ^^^^^^^^^^^^^^^
  File "/workspace/storage/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/storage/env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/storage/env/lib/python3.11/site-packages/transformers/models/opt/modeling_opt.py", line 384, in forward
    attn_output = self._flash_attention_forward(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/storage/env/lib/python3.11/site-packages/transformers/models/opt/modeling_opt.py", line 450, in _flash_attention_forward
    attn_output = flash_attn_func(
                  ^^^^^^^^^^^^^^^^
  File "/workspace/storage/env/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 825, in flash_attn_func
    return FlashAttnFunc.apply(
           ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/storage/env/lib/python3.11/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/storage/env/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 507, in forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
                                                                ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/storage/env/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 51, in _flash_attn_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
                                                                ^^^^^^^^^^^^^^^^^^^^
RuntimeError: query and key must have the same dtype

I don't think this is an SFTTrainer problem, which is why I didn't report there.

Expected behavior

It should generate without any problems. Otherwise, it should give an error or indication of what needs to be done so that the model can be used for generation.

ArthurZucker commented 6 months ago

Hey, this seems related to this issue. Also, make sure the model is in fp16 or bfloat16.

antonioalegria commented 6 months ago

The training args have fp16=True. I have also tried, after training, converting the model (as well as every module inside it) to torch.float16, but it doesn't make a difference: I get the same error. I haven't found any workaround that works.

imirzadeh commented 6 months ago

I have the same issue (generating with FA2). I checked the input to FA2 module and they are all fp16. In my case using phi-2, I tracked down the issue to the flash attention code in modeling:

# Observed dtypes before the cache update >>  q:torch.bfloat16, k:torch.bfloat16, v:torch.bfloat16
if past_key_value is not None:
   # Extra arguments handed to the cache's update() call together with the layer index.
   cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
   # key_states / value_states are replaced by whatever update() returns.
   key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# Observed dtypes after the cache update >> q:torch.bfloat16, k:torch.float32, v:torch.float32
# (the query is unchanged, while keys and values come back as float32)

This past_key_value.update somehow upcasts the KVs to fp32.

ylacombe commented 6 months ago

Hey @ArthurZucker ! I'm facing the same issue! Seems like it's coming from using DynamicCache and autocast, somewhere along the way the past key values stored in the Cache get upcasted. It's probably happening where @imirzadeh identified above.

I don't see a clear way of dealing with this, as I'm not an autocast expert, but an ugly fix could be to add key_values.dtype == torch.float32 to the condition here: https://github.com/huggingface/transformers/blob/60dea593edd0b94ee15dc3917900b26e3acfbbee/src/transformers/models/mistral/modeling_mistral.py#L431-L449

Let me know what you think and whether I can help further.

ArthurZucker commented 6 months ago

cc @gante and @zucchini-nlp that's something we should pay attention to!

zucchini-nlp commented 6 months ago

@antonioalegria In the case of this code snippet the model is loaded in float32, and I believe running a trainer.train() does not change the model itself to fp16 dtype. Can confirm that running your script fails for me, but the below script which loads model in fp16 works fine

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder name; presumably the path of the checkpoint saved by the trainer.
checkpoint = "trainer_ckpt_path"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# The dtype has to be given explicitly at load time: fp16 is needed here for
# the model to work with Flash Attention 2.
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
    device_map="cuda",
)

# Tokenize a single prompt and move it to wherever the model was placed.
prompt = "I really liked this movie because"
encoded = tokenizer(prompt, return_tensors="pt").to(model.device)

generated = model.generate(encoded.input_ids, max_length=100, num_return_sequences=1)
print(tokenizer.batch_decode(generated))
gante commented 5 months ago

@antonioalegria as @zucchini-nlp wrote :) The fp16=True flag turns on autocast for training, so it doesn't crash at train time. However, at inference time, there is no autocast, so the model is still seen as FP32. FA2 does not support FP32, and thus it will crash.

You can also train in FP32 (with autocast) and then manually cast to FP16 / reload the trained model in FP16.

edchengg commented 5 months ago

I got the same problem when running model inference with flashattention2. Below is the code to reproduce the error @zucchini-nlp @gante @ArthurZucker :

# Reproduction: with a bf16 Flash Attention 2 model, model.generate() works,
# but Seq2SeqTrainer.predict() with bf16=True and predict_with_generate=True
# raises "RuntimeError: query and key must have the same dtype".
import torch
from datasets import load_dataset, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Seq2SeqTrainer,
    # Previously used without being imported, which made the script stop with
    # a NameError before it could reach the reported error.
    Seq2SeqTrainingArguments,
)

checkpoint = "codellama/CodeLlama-7b-hf"
model_kwargs = {
    "attn_implementation": "flash_attention_2",
    "torch_dtype": torch.bfloat16,
}

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Use EOS as the padding token (padding=True is requested below).
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(checkpoint, **model_kwargs, device_map="cuda")

# Create a small dataset with 5 sentences
texts = [
    "The quick brown fox jumps over the lazy dog.",
    "I love eating pizza and watching movies.",
    "The sun is shining brightly today.",
    "She plays the piano beautifully.",
    "He enjoys reading books in his free time."
]

encodings = tokenizer(texts, padding=True, return_tensors="pt")

# Create the dataset; the labels are simply the input ids.
dataset = Dataset.from_dict({
    "input_ids": encodings.input_ids,
    "attention_mask": encodings.attention_mask,
    "labels": encodings.input_ids
})

# Define the training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="output",
    bf16=True, # adding this leads to the error in trainer.predict() below
    per_device_eval_batch_size=1,
    predict_with_generate=True,
    generation_max_length=50,
    generation_num_beams=1,
)

# Create the Trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
)

# Sanity check: plain generate() on the same model
test_input = "I really liked this movie because"
test_input = tokenizer(test_input, return_tensors="pt").to(model.device)
out = model.generate(test_input.input_ids, max_length=100, num_return_sequences=1)
print(tokenizer.batch_decode(out)) # Okay

# Run prediction on the dataset through the trainer
predictions = trainer.predict(dataset) # Error
print(tokenizer.batch_decode(predictions.predictions, skip_special_tokens=True))

If I set bf16=True in the training arguments, the trainer.predict will crash but model.generate is okay!

    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
RuntimeError: query and key must have the same dtype

It works fine if I remove bf16 in the argument... The problem is that I need to use bf16, otherwise the inference speed is very slow..

zucchini-nlp commented 5 months ago

@edchengg

I could localize the error. When using the trainer, there is a line that prepares the model with accelerate, which in turn does something like model.forward = convert_outputs_to_fp32(new_forward), casting all model outputs to fp32. That is why the past_key_values cache was cast to fp32, causing errors when generating with the KV cache.

@gante, I think the issue will be resolved when we get rid of the legacy cache format, so I will leave it to you here :) UPD: Or maybe not, I just found this issue, which causes an error if a cache class is used in the prediction loop

ManuelFay commented 5 months ago

Same here, got it running with fp16 but bf16 raises the error on a A100 cluster and the Idefics2 Model ! Good luck with the fix !

ambroser53 commented 5 months ago

I'm also running into this error but only when quantising the model. I made an edit much like the one referenced here by @ylacombe

Hey @ArthurZucker ! I'm facing the same issue! Seems like it's coming from using DynamicCache and autocast, somewhere along the way the past key values stored in the Cache get upcasted. It's probably happening where @imirzadeh identified above.

I don't see a clear way of dealing with this, as I'm not an autocast expert, but an ugly fix could be to add key_values.dtype == torch.float32 to the condition here:

https://github.com/huggingface/transformers/blob/60dea593edd0b94ee15dc3917900b26e3acfbbee/src/transformers/models/mistral/modeling_mistral.py#L431-L449

Let me know what you think and whether I can help further.

but that simply ensures each of the three states have the same dtype:

# Bring query/key/value to one common dtype before they are used further.
# The cast runs when the query is float32 OR when the three tensors do not
# all share the same dtype.
input_dtype = query_states.dtype
if input_dtype == torch.float32 or not (query_states.dtype == key_states.dtype == value_states.dtype):
    # Target dtype, in order of preference: the autocast dtype when autocast
    # is enabled, then the pre-quantization dtype recorded on the config,
    # otherwise the dtype of the q_proj weights.
    if torch.is_autocast_enabled():
        target_dtype = torch.get_autocast_gpu_dtype()
    # Handle the case where the model is quantized
    elif hasattr(self.config, "_pre_quantization_dtype"):
        target_dtype = self.config._pre_quantization_dtype
    else:
        target_dtype = self.q_proj.weight.dtype

    # NOTE(review): the message only describes the float32 case, but this
    # branch is also taken when q/k/v merely disagree on dtype.
    logger.warning_once(
        f"The input hidden states seems to be silently casted in float32, this might be related to"
        f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
        f" {target_dtype}."
    )

    # All three tensors are cast, so they end up with the same dtype.
    query_states = query_states.to(target_dtype)
    key_states = key_states.to(target_dtype)
    value_states = value_states.to(target_dtype)
ArthurZucker commented 5 months ago

@zucchini-nlp are we just waiting for the deprecation of the to_legacy_cache function?

zucchini-nlp commented 5 months ago

@ArthurZucker sorry, totally forgot to reply here. While I was thinking that we might need to pass in "compute_dtype" in all cache classes, I found that #30536 actually should solve the problem.

In the linked PR we never go back and forth between legacy-new cache formats, thus no problems with incorrect casting to fp32. Will verify when the PR is merged, and close this issue :)

zucchini-nlp commented 4 months ago

Update: The linked PR was merged and I tested that the dtype mismatch error doesn't appear anymore when doing inference with Trainer. I used one of the above scripts for testing. All new LLMs can be used for inference with flash-attn in Trainer!

But this works only for the models that support new cache format, which excludes some relatively older models like OPT or GPT2. I will keep the issue open and extend new cache format to other decoder-only models. cc @gante, not sure if there is a tracker for making models compatible with DynamicCache 🤔

This is the script used for verifying that it works:

# Verification script: checks that Seq2SeqTrainer.predict() with bf16=True no
# longer raises the q/k dtype-mismatch error for a model that supports the new
# cache format.
import torch
from datasets import load_dataset, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

# No leading/trailing whitespace in the repo id: it is not a valid model id
# otherwise. Older models without the new cache format (e.g. facebook/opt-125M)
# are not covered by the fix.
checkpoint = "meta-llama/Llama-2-7b-chat-hf"
model_kwargs = {
    "attn_implementation": "flash_attention_2",
    "torch_dtype": torch.bfloat16,
}

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Use EOS as the padding token (padding=True is requested below).
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(checkpoint, **model_kwargs, device_map="cuda")

# Create a small dataset with 5 sentences
texts = [
    "The quick brown fox jumps over the lazy dog.",
    "I love eating pizza and watching movies.",
    "The sun is shining brightly today.",
    "She plays the piano beautifully.",
    "He enjoys reading books in his free time."
]

encodings = tokenizer(texts, padding=True, return_tensors="pt")

# Create the dataset; the labels are simply the input ids.
dataset = Dataset.from_dict({
    "input_ids": encodings.input_ids,
    "attention_mask": encodings.attention_mask,
    "labels": encodings.input_ids
})

# Define the training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="output",
    bf16=True, # this flag is what used to trigger the error in trainer.predict()
    per_device_eval_batch_size=1,
    predict_with_generate=True,
    generation_max_length=50,
    generation_num_beams=1,
)

# Create the Trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
)

# Sanity check: plain generate() on the same model
test_input = "I really liked this movie because"
test_input = tokenizer(test_input, return_tensors="pt").to(model.device)
out = model.generate(test_input.input_ids, max_length=100, num_return_sequences=1)
print(tokenizer.batch_decode(out)) # Okay

# Run prediction on the dataset through the trainer
predictions = trainer.predict(dataset) # used to raise the dtype error; not anymore
print(tokenizer.batch_decode(predictions.predictions, skip_special_tokens=True))