Dao-AILab / flash-attention

Fast and memory-efficient exact attention

query and key must have the same dtype error #732

Open floyddcn opened 7 months ago

floyddcn commented 7 months ago

I'm using FlashAttention-2 for fine-tuning and inference with a GPTBigCodeModel (defog/sqlcoder2). After training with LoRA, I encounter an error when doing inference (error screenshot attached).

Besides, I get the same error when doing inference with another model (https://huggingface.co/defog/sqlcoder-34b-alpha) whose raw weights come from Hugging Face and were not fine-tuned by me. I've tried several things to resolve the error, like modifying your code and adding LoRA targets for q and k, but nothing works.

The original post is here: https://github.com/hiyouga/LLaMA-Factory/issues/1614. Any suggestion or solution is appreciated!

tridao commented 7 months ago

As suggested by the error message, you should make sure that q and k being passed to flash_attn functions have the same dtype (e.g. fp16 or bf16).
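A minimal sketch along those lines (an illustration, not the library's internals), assuming q/k/v are already laid out as (batch, seqlen, nheads, headdim) the way flash_attn_func expects:

    import torch
    from flash_attn import flash_attn_func

    def attention_with_matching_dtypes(q, k, v, causal=True):
        # FlashAttention kernels only accept fp16/bf16, and q, k, v must all agree.
        target = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        q, k, v = (t.to(target) for t in (q, k, v))
        return flash_attn_func(q, k, v, causal=causal)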

floyddcn commented 6 months ago

In the original post (https://github.com/hiyouga/LLaMA-Factory/issues/1614), I've tried downcasting to float16, but the generation result seemed to be incorrect. Should I try upcasting to float32? As far as I know, FlashAttention only supports fp16?

tridao commented 6 months ago

Only fp16 and bf16 are supported. I doubt the issue is fp16 vs fp32. You can try printing the attention output from FlashAttention as well as the attention output from a standard implementation (in fp32 or fp16) and make sure the outputs are approximately the same.
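A quick comparison sketch for that check (shapes and inputs here are made up for illustration), running flash_attn_func against a plain fp32 reference:

    import math
    import torch
    from flash_attn import flash_attn_func

    def reference_attention(q, k, v, causal=True):
        # Standard softmax attention in fp32; q/k/v are (batch, seqlen, nheads, headdim).
        q, k, v = (t.transpose(1, 2).float() for t in (q, k, v))  # -> (b, h, s, d)
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
        if causal:
            s = scores.shape[-1]
            mask = torch.triu(torch.ones(s, s, dtype=torch.bool, device=q.device), 1)
            scores = scores.masked_fill(mask, float("-inf"))
        return (torch.softmax(scores, dim=-1) @ v).transpose(1, 2)  # back to (b, s, h, d)

    q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    out_flash = flash_attn_func(q, k, v, causal=True)
    out_ref = reference_attention(q, k, v, causal=True)
    print((out_flash.float() - out_ref).abs().max())  # should be small, on the order of 1e-3 for fp16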

zhou6140919 commented 4 months ago

Same issue here. I tried to use flash attention when generating outputs with Llama-2-7b-chat-hf.

    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
    RuntimeError: query and key must have the same dtype

I checked all parameters in the merged model (model + adapter), and they were all torch.float16. I also printed the q, k, and v dtypes just before the flash_attn_cuda.varlen_fwd() call. It's weird. Where could the problem be?

q dtype: torch.float16, k dtype: torch.float16, v dtype: torch.float16
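One way to narrow this down (a debugging sketch; it assumes the CUDA extension is aliased as flash_attn_cuda inside flash_attn.flash_attn_interface as in the traceback above, which may differ across flash-attn versions) is to wrap varlen_fwd so every call site logs the dtypes of all tensor arguments right before the kernel's check runs, in case a different code path than the one instrumented is the one that raises:

    import torch
    from flash_attn import flash_attn_interface

    _orig_varlen_fwd = flash_attn_interface.flash_attn_cuda.varlen_fwd

    def logged_varlen_fwd(*args, **kwargs):
        # Print the dtype of every tensor argument before handing off to the CUDA kernel.
        dtypes = [a.dtype for a in args if isinstance(a, torch.Tensor)]
        print("varlen_fwd tensor dtypes:", dtypes)
        return _orig_varlen_fwd(*args, **kwargs)

    flash_attn_interface.flash_attn_cuda.varlen_fwd = logged_varlen_fwd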