Andcircle opened this issue 9 months ago
I'm not familiar with accelerate or with how transformers uses FlashAttention; you'd probably get better help asking on those repos.
I am getting a similar issue without training, with torch nightly on Llama, so I can confirm something's wrong! It might be on our side, but as far as I tested, all the inputs' dtypes were bfloat16 and I still got the issue. The reproducer is here with attn_implementation="flash_attention_2", along with the corresponding PR on transformers.
- `transformers` version: 4.38.0.dev0
- Platform: Linux-5.4.0-166-generic-x86_64-with-glibc2.31
- Python version: 3.10.0
- Huggingface_hub version: 0.20.3
- Safetensors version: 0.4.2
- Accelerate version: 0.27.0
- Accelerate config: not found
- PyTorch version (GPU?): 2.3.0.dev20240208+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>
flash_attn=2.5.3 + torch nightly (so 2.3-ish)

```python
>>> from flash_attn import flash_attn_func
>>> import torch
>>> print(torch.__version__)
2.3.0.dev20240208+cu121
>>> flash_attn_func(torch.ones((2,3), dtype=torch.bfloat16), torch.ones((2,3), dtype=torch.bfloat16), torch.ones((2,3), dtype=torch.bfloat16), 1, softmax_scale=1, causal=True)
....
File ~/miniconda3/envs/py310/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py:51, in _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax)
49 maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
50 q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
---> 51 out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
52 q,
53 k,
54 v,
55 None,
56 alibi_slopes,
57 dropout_p,
58 softmax_scale,
59 causal,
60 window_size[0],
61 window_size[1],
62 return_softmax,
63 None,
64 )
65 return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
RuntimeError: FlashAttention only support fp16 and bf16 data type
```
This doesn't work for me again, might be because I have… cc @tridao, not sure how relevant this is.
The q, k, v need to be on 'cuda' and have shape (batch, seqlen, nheads, headdim).
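For illustration, a call that satisfies those constraints might look roughly like the sketch below (shapes, head counts, and dtypes are made up for the example; it assumes an Ampere-or-newer GPU):

```python
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 128, 8, 64  # illustrative sizes

# q, k, v must be CUDA tensors in fp16/bf16 with shape (batch, seqlen, nheads, headdim)
q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.bfloat16, device="cuda")

out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
print(out.shape)  # torch.Size([2, 128, 8, 64])
```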
The error is before that, but it seems to be a torch nightly problem: the transformers snippet works with torch 2.2 (vs. getting the "FlashAttention only support fp16 and bf16 data type" error with nightly), so torch 2.2 seems more reliable. (With my standalone snippet on torch 2.2 I get RuntimeError: q must be on CUDA instead, so a different error.)
I can't run the reproducer right now because StaticCache is not in transformers 4.37.2 (the latest stable version).
> The q, k, v need to be on 'cuda' and have shape (batch, seqlen, nheads, headdim).
Yeah, FlashAttention uses (batch, seqlen, nheads, headdim) to represent inputs; however, in a lot of software (Triton kernels, for example) there are reasons to use (batch, nheads, seqlen, headdim) for easier layout arrangement.
Actually, the two are equivalent via this mapping:
```python
def permute(self, x: torch.Tensor) -> torch.Tensor:
    # (batch, seqlen, nheads * headdim) -> (batch, nheads, seqlen, headdim)
    new_x_shape = x.size()[:-1] + (self.nheads, self.headdim)
    x = x.view(new_x_shape)
    return x.permute(0, 2, 1, 3)
```
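Concretely, converting between the two layouts is just a transpose of the seqlen and nheads axes; a small illustrative sketch (names and sizes are made up):

```python
import torch

batch, nheads, seqlen, headdim = 2, 8, 128, 64  # illustrative sizes

# (batch, nheads, seqlen, headdim): the layout many Triton-style kernels use
x_bnsh = torch.randn(batch, nheads, seqlen, headdim, dtype=torch.bfloat16, device="cuda")

# (batch, seqlen, nheads, headdim): the layout flash_attn_func expects
x_bsnh = x_bnsh.transpose(1, 2).contiguous()

# transposing back recovers the original layout, so the two are equivalent up to a permute
assert x_bsnh.transpose(1, 2).shape == x_bnsh.shape
```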
But it is weird that the error (I have tested with the latest version) says "FlashAttention only support fp16 and bf16 data type".
```cpp
// mha_fwd: https://github.com/Dao-AILab/flash-attention/blob/6bbc532388e61185a92e2a563126739967b4c8c5/csrc/flash_attn/flash_api.cpp#L339-L339
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
// We will support Turing in the near future
// TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");

auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
            "FlashAttention only support fp16 and bf16 data type");
if (q_dtype == torch::kBFloat16) {
    TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
}
```
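In Python terms, those checks amount to roughly the following guard (an illustrative sketch; the helper name is made up and not part of flash-attn):

```python
import torch

def check_flash_attn_inputs(q: torch.Tensor) -> None:
    # mirrors the TORCH_CHECKs in mha_fwd: CUDA tensor, Ampere+ GPU, fp16/bf16 dtype
    if not q.is_cuda:
        raise RuntimeError("q must be on CUDA")
    major, _ = torch.cuda.get_device_capability(q.device)
    if major < 8:
        raise RuntimeError("FlashAttention only supports Ampere GPUs or newer.")
    if q.dtype not in (torch.float16, torch.bfloat16):
        raise RuntimeError(f"FlashAttention only support fp16 and bf16 data type, got {q.dtype}")
```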
I have checked the repo; the C++ templates would need to be updated to support more dtypes (I have experience with near-memory chip op libraries). Currently I have to do these unnecessary casts so that teams can use Flash Attention v2:
```python
if q.dtype == torch.float32:
    q = q.to(torch.float16, non_blocking=True)
    k = k.to(torch.float16, non_blocking=True)
    v = v.to(torch.float16, non_blocking=True)
elif q.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz):
    capability = torch.cuda.get_device_capability()
    if capability[0] <= 8:
        raise RuntimeError("Flash attention for FP8 (needs Hopper TE support) is only supported for compute capability >= 9.0")
    else:
        # TODO (yiakwy) : add FP8 support
        raise NotImplementedError

output = flash_attn_func(q, k, v, dropout_p=self.dropout.p, causal=is_causal)
output = revert_mold_flash_attn_input(output)

if output_attentions:
    raise Exception("Does not support output attention weights inside flash attention.")

if output.dtype != torch.float32:
    # TODO (yiakwy) : add support of fp16 and bf16
    # if the output dtype is not FP32 (by default flash attention generates FP16 output here), we need to cast it back
    output = output.to(torch.float32, non_blocking=True)
```
So we need to update the error message, right?
I confirm that flash-attn==2.5.6 doesn't work with torch==2.3.0a0+40ec155e58.nv24.3 (nightly) even though the inputs are indeed in torch.bfloat16 format! I rolled back to torch 2.2 stable, reinstalled flash-attn, and now it works.
System Info
Reproduction
The following script works as expected on 1 GPU, but when running on multiple GPUs with DP it fails with:

```
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
RuntimeError: FlashAttention only support fp16 and bf16 data type
```
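One way to narrow this down (a hedged debugging sketch, not from the issue) is to wrap the flash-attn entry point and log what each DP replica actually passes in, to confirm the inputs really are fp16/bf16 on every device:

```python
import torch
from flash_attn import flash_attn_varlen_func as _orig_varlen_func

def logged_varlen_func(q, k, v, *args, **kwargs):
    # print dtype/device/shape per replica right before the CUDA kernel runs
    print(f"device={q.device} q={q.dtype} k={k.dtype} v={v.dtype} shape={tuple(q.shape)}")
    return _orig_varlen_func(q, k, v, *args, **kwargs)

# Where to patch this in depends on the model code; with transformers one would
# replace the symbol imported by the modeling file (an assumption, check your version), e.g.:
#   import transformers.models.llama.modeling_llama as modeling_llama
#   modeling_llama.flash_attn_varlen_func = logged_varlen_func
```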