[Open] jlamypoirier opened this issue 6 months ago
Thanks for the report, I can reproduce it. Investigating now. Might be because of the way torch (in C++) handles dtypes.
Hmm, compiling from scratch seems to work fine, so something is wrong with the wheel we built.
I'm guessing this is because 24.03 uses CUDA 12.4 and the wheels built with nvcc 12.2 are somehow not compatible.
What is the recommended fix, then? Rebuild flash-attention from source?
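If you do rebuild from source, it's worth confirming afterwards which copy Python will actually import, since a stale egg earlier on `sys.path` can shadow the fresh build. A stdlib-only check (no flash-attn import needed, so it won't crash even if the build is broken):

```python
# Stdlib-only check: which flash_attn install will Python import?
# Useful after a source build, since an old egg on sys.path can shadow it.
import importlib.util

spec = importlib.util.find_spec("flash_attn")
if spec is None:
    print("flash_attn is not importable")
else:
    print("flash_attn will load from:", spec.origin)
```

If the printed path still points at the old egg, uninstalling it (or removing it from `sys.path`) should make the new build take effect.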
I just did a fresh install using `python setup.py install`, but I still get the same error message. Either I did something wrong, or there is still a problem somewhere.
```
  File "/root/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3-small-128k-instruct/f80aaa30bfc64c2b8ab214b541d9050e97163bc4/modeling_phi3_small.py", line 624, in forward
    attn_function_output = self._apply_dense_attention(
  File "/root/.cache/huggingface/modules/transformers_modules/microsoft/Phi-3-small-128k-instruct/f80aaa30bfc64c2b8ab214b541d9050e97163bc4/modeling_phi3_small.py", line 441, in _apply_dense_attention
    attn_output_unpad = flash_attn_varlen_kvpacked_func(
  File "/usr/local/lib/python3.10/dist-packages/flash_attn-2.5.9.post1-py3.10-linux-aarch64.egg/flash_attn/flash_attn_interface.py", line 978, in flash_attn_varlen_kvpacked_func
    return FlashAttnVarlenKVPackedFunc.apply(
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 572, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/flash_attn-2.5.9.post1-py3.10-linux-aarch64.egg/flash_attn/flash_attn_interface.py", line 432, in forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
  File "/usr/local/lib/python3.10/dist-packages/flash_attn-2.5.9.post1-py3.10-linux-aarch64.egg/flash_attn/flash_attn_interface.py", line 86, in _flash_attn_varlen_forward
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
RuntimeError: FlashAttention only support fp16 and bf16 data type
```
Do you think I can comment out the code that does the check and reinstall, or is that just a sign that things will break down the line?
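For context, the check fires because the tensors reaching the kernel are fp32, so removing the assert would most likely just move the failure into the CUDA kernel. The upstream fix is to cast the inputs to a supported dtype before the call. A minimal sketch of that pattern (the tensor names and shapes here are placeholders, not the actual Phi-3 code):

```python
import torch

# Placeholder tensors standing in for the query and packed key/value
# that modeling code would hand to flash-attn; fp32 by default,
# which is exactly what trips the assert.
q = torch.randn(8, 16, 64)
kv = torch.randn(8, 2, 16, 64)

# flash-attn kernels accept only fp16/bf16; cast before the varlen call.
if q.dtype not in (torch.float16, torch.bfloat16):
    q = q.to(torch.bfloat16)
    kv = kv.to(torch.bfloat16)

assert q.dtype == torch.bfloat16 and kv.dtype == torch.bfloat16
```

In practice, loading the whole model in half precision (so the activations are already fp16/bf16) is usually cleaner than casting at the call site.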
Ok, it sounds like the issue was with the Phi3-small code, not with the library. The assert was triggering appropriately. Sorry for the noise.
How did you solve the issue? I am also trying to fine-tune Phi-3, but I get the same error message and honestly don't know what to do with it.
I'm hitting the same problem. @tridao, do you have a recommended way to solve it?
Sorry, I forgot, but I think it was that you had to set a value for `use_reentrant` other than the default.
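For anyone finding this later: I believe this refers to PyTorch's gradient-checkpointing flag `use_reentrant`, which fine-tuning code typically exercises via `torch.utils.checkpoint`. A minimal sketch of passing it explicitly rather than relying on the default (the toy `block` function is just a stand-in for a model layer):

```python
import torch
from torch.utils.checkpoint import checkpoint

def block(x):
    # Stand-in for a transformer layer.
    return torch.relu(x) * 2.0

x = torch.randn(4, requires_grad=True)

# Pass use_reentrant explicitly (False here) instead of the default;
# the non-reentrant path avoids several known autograd pitfalls.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```

With Hugging Face trainers, the same flag is usually forwarded through the gradient-checkpointing options rather than called directly.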
flash-attn 2.5.7 always complains about the input data type even when it's clearly a correct one. I'm using the base image nvcr.io/nvidia/pytorch:24.03-py3.
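To sanity-check the CUDA-mismatch theory from earlier in the thread inside that container, you can print the CUDA toolkit version torch was built against and compare it with the version the flash-attn wheel was compiled for (a quick sketch; I haven't verified the 24.03 image's exact values):

```python
import torch

# Version torch itself was built with; compare against the nvcc version
# used to build the flash-attn wheel (12.2 vs 12.4 in the discussion above).
print("torch:", torch.__version__)
print("CUDA toolkit torch was built with:", torch.version.cuda)
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
```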