Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Correctness of `flash_attn_varlen_func` kernel with cuda graph. #1164

Open LiuXiaoxuanPKU opened 3 weeks ago

LiuXiaoxuanPKU commented 3 weeks ago

Thanks for the great repo! We are testing the correctness of flash_attn_varlen_func when enabling CUDA graph. This is the test we use: https://github.com/vllm-project/vllm/blob/66e832be41cd3f29bd2b37303ea5944efcb16204/tests/kernels/test_flash_attn.py#L234 We found that the value of the context length affects correctness, even when the shapes of all input parameters are the same. This matters because we don't know the context length at graph capture time. Any hints on solving the problem are highly appreciated.

tridao commented 3 weeks ago

What's "context length" here? Which variable?

LiuXiaoxuanPKU commented 3 weeks ago

It will affect the value max_seqlen_k. https://github.com/vllm-project/vllm/blob/66e832be41cd3f29bd2b37303ea5944efcb16204/tests/kernels/test_flash_attn.py#L258
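For concreteness, a minimal sketch (illustrative values, not the actual vLLM test) of how per-sequence context lengths typically feed the varlen arguments: the prefix sums land in the GPU tensor cu_seqlens_k, while max_seqlen_k is a plain Python int on the host.

```python
import torch
from itertools import accumulate

# Hypothetical per-sequence context lengths for a batch of 3.
context_lens = [17, 5, 32]

# GPU-side tensor: exclusive prefix sums, shape (batch + 1,), int32.
cu_seqlens_k = torch.tensor([0] + list(accumulate(context_lens)),
                            dtype=torch.int32, device="cuda")   # [0, 17, 22, 54]

# Host-side scalar: a plain Python int passed by value at kernel launch.
max_seqlen_k = max(context_lens)                                 # 32
```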

tridao commented 3 weeks ago

max_seqlen_k is a variable on the CPU. After the kernel is captured, changing this value has no effect.

tridao commented 3 weeks ago

It's similar to other variables on the CPU, such as softmax_scale. If the kernel is captured with softmax_scale = 1.0 and you then change softmax_scale to 2.0 and replay the kernel, it still behaves as if softmax_scale = 1.0.
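A rough sketch of that behavior (illustrative shapes and values, not taken from this issue): the graph records the kernel launch with whatever CPU arguments were in effect at capture time, so reassigning the Python variables afterwards and replaying changes nothing, while refilling the GPU tensors in place does take effect.

```python
import torch
from flash_attn import flash_attn_varlen_func

nheads, headdim = 4, 64
seqlens = [5, 7]
total = sum(seqlens)
q = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
cu_seqlens = torch.tensor([0, 5, 12], dtype=torch.int32, device="cuda")
max_seqlen = max(seqlens)      # plain Python int, lives on the host
softmax_scale = 1.0

# Warm-up on a side stream before capture, as recommended for torch.cuda.graph.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens,
                           max_seqlen, max_seqlen,
                           softmax_scale=softmax_scale, causal=True)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens,
                                 max_seqlen, max_seqlen,
                                 softmax_scale=softmax_scale, causal=True)

# Replay only re-launches the recorded kernels. Refilling GPU tensors
# (q, k, v, cu_seqlens) in place is picked up; reassigning CPU values
# (softmax_scale, max_seqlen) is not -- the graph still runs with the
# values it was captured with.
q.copy_(torch.randn_like(q))
softmax_scale = 2.0
max_seqlen = 999
g.replay()   # `out` is overwritten using the captured CPU arguments
```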

LiuXiaoxuanPKU commented 3 weeks ago

Thanks! Then it's a bit weird. We did check the input shapes of all the other variables (q, k, v, cu_seqlens_q, cu_seqlens_k, and block_table). What's the best way to debug it?

tridao commented 3 weeks ago

You're trying to change a CPU variable after capturing the CUDA graph; that's not supported by CUDA graphs. I haven't looked closely, but it looks like the kernel is behaving as expected in this case. Can you describe what behavior you expect?
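One possible workaround, consistent with the explanation above but not something confirmed in this thread (treat it as an assumption to verify): capture the graph with max_seqlen_k fixed to an upper bound on the context length that the captured batch shape can ever see, and let the actual per-sequence lengths flow through the GPU-side cu_seqlens_k tensor at replay time. The captured bound may cost some extra work for shorter sequences, but the lengths the kernel actually applies come from the GPU tensor.

```python
import torch
from flash_attn import flash_attn_varlen_func

nheads, headdim, batch = 4, 64, 2
max_seqlen_k_bound = 128     # hypothetical upper bound for this captured batch shape

# Static buffers reused across replays (decode-style: one query per sequence).
q = torch.zeros(batch, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.zeros(batch * max_seqlen_k_bound, nheads, headdim,
                device="cuda", dtype=torch.float16)
v = torch.zeros_like(k)
cu_seqlens_q = torch.arange(batch + 1, dtype=torch.int32, device="cuda")
cu_seqlens_k = torch.arange(batch + 1, dtype=torch.int32, device="cuda") * max_seqlen_k_bound

# (Warm-up on a side stream omitted here; see the sketch above.)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k,
                                 1, max_seqlen_k_bound, causal=True)

# Each step: copy the real packed q/k/v and the real prefix sums into the
# static buffers, then replay. The CPU-side max_seqlen_k stays at the bound.
# q.copy_(...); k.copy_(...); v.copy_(...) with the step's packed data
cu_seqlens_k.copy_(torch.tensor([0, 17, 49], dtype=torch.int32, device="cuda"))
g.replay()
```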