Open floyddcn opened 7 months ago
As suggested by the error message, you should make sure that q and k being passed to flash_attn functions have the same dtype (e.g. fp16 or bf16).
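For reference, making the dtypes agree before the call could look like the sketch below. The tensor shapes are made up for illustration; `flash_attn_func` is the library's high-level entry point, and the actual call is left commented out since it requires a CUDA GPU with flash-attn installed.

```python
import torch

# Hypothetical q/k/v shaped (batch, seqlen, nheads, headdim); in the reported
# error, q somehow ends up fp32 while k/v are fp16.
q = torch.randn(2, 128, 8, 64, dtype=torch.float32)
k = torch.randn(2, 128, 8, 64, dtype=torch.float16)
v = torch.randn(2, 128, 8, 64, dtype=torch.float16)

# FlashAttention supports only fp16/bf16, and q, k, v must share one dtype.
target = torch.float16
q, k, v = (t.to(target) for t in (q, k, v))
assert q.dtype == k.dtype == v.dtype == target

# With flash-attn installed and tensors on GPU, the call would then be:
# from flash_attn import flash_attn_func
# out = flash_attn_func(q, k, v, causal=True)
```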
In the original post (https://github.com/hiyouga/LLaMA-Factory/issues/1614), I've tried downcasting to float16, but the generation results seemed incorrect. Should I try upcasting to float32? As far as I know, FlashAttention only supports fp16?
Only fp16 and bf16 are supported. I doubt it's the issue of fp16 vs fp32. You can try to print out the attention output by FlashAttention as well as attention output by a standard implementation (in fp32 or fp16) and make sure the outputs are approximately the same.
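As a concrete sketch of that check, here is a plain-PyTorch reference attention computed in fp32 that FlashAttention's output can be compared against. The shapes and tolerances are illustrative assumptions, and the `flash_attn_func` comparison itself is commented out since it needs a GPU.

```python
import torch

def attn_ref(q, k, v, causal=False):
    # Standard attention in fp32: softmax(q k^T / sqrt(d)) v,
    # with q/k/v shaped (batch, seqlen, nheads, headdim).
    d = q.shape[-1]
    scores = torch.einsum("bthd,bshd->bhts", q.float(), k.float()) / d**0.5
    if causal:
        t, s = scores.shape[-2:]
        mask = torch.triu(torch.ones(t, s, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))
    p = scores.softmax(dim=-1)
    return torch.einsum("bhts,bshd->bthd", p, v.float())

q = torch.randn(1, 32, 4, 64, dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

ref = attn_ref(q, k, v, causal=True)  # fp32 reference output

# On GPU, with flash-attn installed, compare against it, e.g.:
# from flash_attn import flash_attn_func
# out = flash_attn_func(q, k, v, causal=True)
# torch.testing.assert_close(out.float(), ref, atol=2e-3, rtol=2e-3)
```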
Same issue here. I tried to use flash attention when predicting outputs with Llama-2-7b-chat-hf.
```
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
RuntimeError: query and key must have the same dtype
```
I checked all parameters in the merged model (model + adapter), and they were all torch.float16. I printed q, k, and v dtypes just before flash_attn_cuda.varlen_fwd() function. It's weird. Where could the problem be?
q dtype: torch.float16, k dtype: torch.float16, v dtype: torch.float16
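One way to double-check the merged model is to count parameter and buffer dtypes and look for stray fp32 tensors (e.g. a freshly added adapter or norm layer that was never cast). The `dtype_report` helper below is my own sketch, not part of flash-attn, shown on a toy model:

```python
from collections import Counter

import torch
import torch.nn as nn

def dtype_report(model: nn.Module) -> Counter:
    # Count parameter and buffer dtypes; a lone torch.float32 entry often
    # points at the layer feeding fp32 activations into the attention call.
    counts = Counter(p.dtype for p in model.parameters())
    counts.update(b.dtype for b in model.buffers())
    return counts

# Toy model: two fp16 layers plus one layer left in fp32, like an un-cast adapter.
model = nn.Sequential(nn.Linear(8, 8).half(), nn.Linear(8, 8).half(), nn.Linear(8, 8))
print(dtype_report(model))  # expect 4 fp16 params and 2 fp32 params
```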
I'm using flash-attn 2 for a GPTBigCodeModel (defog/sqlcoder2) for finetuning and inference. After training with LoRA, I encounter an error when doing inference: ![image](https://github.com/Dao-AILab/flash-attention/assets/48442748/266fa7bf-8b1c-4d58-a9bd-0a44ec9cce8a)
Besides, I got the same error when doing inference with another model (https://huggingface.co/defog/sqlcoder-34b-alpha) whose raw weights are from Hugging Face, not finetuned by myself. And I've tried several things to resolve the error, like modifying your code, adding LoRA targets for q and k, etc. Nothing works~
The original post is here: https://github.com/hiyouga/LLaMA-Factory/issues/1614. Any suggestion or solution is appreciated!