Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Some questions about the flash-attention 2 paper #346

Status: Open. YangQun1 opened this issue 1 year ago.

YangQun1 commented 1 year ago

Q1: The paper says: "In FlashAttention-2, we instead split Q across 4 warps." But in LLM inference, num_queries is usually 1 except in the first forward pass. How is a single query split across 4 warps?

Q2: The paper benchmarks FlashAttention-2 and other attention implementations at different values of seqlen. Does seqlen here mean num_queries or num_keys?
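To make Q1 concrete, here is a small Python sketch of how the rows of one Q block could map to warps when Q is split across 4 warps. It assumes, for illustration only, a 128-row block divided evenly so that each warp owns 32 consecutive rows; the constants and the helper `valid_rows_per_warp` are hypothetical and are not taken from the FlashAttention-2 code.

```python
# Illustrative assumptions (not read from the FlashAttention-2 source):
# one Q block has 128 rows and is split evenly across 4 warps.
BLOCK_M = 128
NUM_WARPS = 4
ROWS_PER_WARP = BLOCK_M // NUM_WARPS  # 32 rows per warp under these assumptions


def valid_rows_per_warp(num_queries: int) -> list[int]:
    """Count how many real (non-padding) query rows each warp owns
    in the first Q block when only num_queries rows exist."""
    counts = []
    for warp in range(NUM_WARPS):
        start = warp * ROWS_PER_WARP
        end = start + ROWS_PER_WARP
        # Rows with index >= num_queries are padding, so clamp at zero.
        counts.append(max(0, min(end, num_queries) - start))
    return counts


print(valid_rows_per_warp(128))  # full block (e.g. prefill): every warp has real rows
print(valid_rows_per_warp(1))    # decoding: only warp 0 has a real row
```

Under these assumptions, a single query (the decoding case) leaves three of the four warps holding only padding rows.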

tridao commented 1 year ago

Q1: Indeed, for inference you want to split K/V, not Q. The current code still works as is (we implicitly pad Q to a length of at least 128), but it could be much more efficient. We have something planned to optimize for inference; hopefully it will be out soon.
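The idea of splitting K/V instead of Q can be sketched in plain Python for a single query vector: each chunk of keys/values is processed independently (on a GPU, e.g. by a different thread block), keeping its own max score and sum of exponentials, and the partial results are then rescaled to a common max and combined. This is a minimal illustration of the general split-K/V technique, assuming one head and no masking; the function names are hypothetical and this is not the kernel FlashAttention ships.

```python
import math
import random


def reference_attention(q, keys, values):
    """Plain softmax attention for one query over all keys (no splitting)."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(len(values[0]))]


def partial_attention(q, keys, values):
    """Attend over one K/V chunk. Returns the unnormalized output, the chunk's
    max score, and the sum of exp(score - max) needed for later rescaling."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    out = [sum(w * v[i] for w, v in zip(weights, values))
           for i in range(len(values[0]))]
    return out, m, sum(weights)


def split_kv_attention(q, keys, values, num_splits):
    """Split K/V into up to num_splits chunks, process each independently, then merge."""
    chunk = math.ceil(len(keys) / num_splits)
    partials = [partial_attention(q, keys[s:s + chunk], values[s:s + chunk])
                for s in range(0, len(keys), chunk)]
    # Rescale every chunk to the global max so the exponentials are comparable.
    m_global = max(m for _, m, _ in partials)
    acc = [0.0] * len(values[0])
    denom = 0.0
    for out, m, l in partials:
        factor = math.exp(m - m_global)
        denom += l * factor
        acc = [a + o * factor for a, o in zip(acc, out)]
    return [a / denom for a in acc]


# Usage: one query, 10 keys/values, head dimension 4 (random test data).
rng = random.Random(0)
d, n = 4, 10
q = [rng.uniform(-1, 1) for _ in range(d)]
keys = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(n)]
values = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(n)]
ref = reference_attention(q, keys, values)
for splits in (1, 2, 3, 10):
    out = split_kv_attention(q, keys, values, splits)
    print(splits, max(abs(a - b) for a, b in zip(out, ref)))  # max abs error vs. reference
```

Because each chunk only needs its local max and exp-sum, the chunks can be processed in parallel, which recovers parallelism that a length-1 Q cannot provide.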

Q2: We benchmarked with Q and K/V having the same length (as is the case in, e.g., GPT-style model training).
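When Q and K/V have the same length, a common way to turn benchmark runtime into FLOPs/s is to count the two matmuls (Q K^T and P V) at 2 FLOPs per multiply-add. The sketch below, whose function name is hypothetical, keeps seqlen_q and seqlen_k separate so that both cases discussed in this thread are visible; halving the count for a causal mask is an approximation.

```python
def attention_fwd_flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False):
    """Approximate forward-pass FLOPs of attention (hypothetical helper).

    Q K^T and P V each cost 2 * seqlen_q * seqlen_k * headdim FLOPs per head,
    counting a multiply-add as 2 FLOPs, for a total factor of 4.
    """
    flops = 4 * batch * nheads * seqlen_q * seqlen_k * headdim
    # A causal mask skips roughly half of the score matrix (approximation).
    return flops // 2 if causal else flops


# Training-style benchmark: Q and K/V have the same length.
print(attention_fwd_flops(1, 16, 2048, 2048, 128))
# Decoding step: a single query against a 2048-token K/V cache.
print(attention_fwd_flops(1, 16, 1, 2048, 128))
```

With seqlen_q == seqlen_k == seqlen this is 4 · seqlen² · headdim · nheads per sequence; with seqlen_q = 1 the FLOP count drops by a factor of seqlen_k while the amount of K/V that must be read stays the same.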

YangQun1 commented 1 year ago

@tridao I see, thank you. I look forward to your optimization for inference :)

WindowsXp-Beta commented 1 year ago

Q1: ... we implicitly pad Q to have length at least 128 ...

I am wondering what the smallest supported sequence length without padding is. Could it be the tensor core's M size, e.g. 16 in the shape (16, 16, 16)? My reasoning: suppose there is only one warp in a thread block; then, along the M dimension, the tile needs to be at least as large as the tensor core's M size in order to use the tensor core.
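The padding cost behind this question is simple arithmetic: if the query rows are padded up to a multiple of some tile height, the fraction of useful rows is seqlen_q divided by the padded length. The sketch below compares a tile height of 16 (the M size of the (16, 16, 16) tensor-core shape mentioned above) with the 128-row padding mentioned earlier in the thread; both values are illustrative assumptions, and the helpers are hypothetical.

```python
def padded_len(seqlen_q: int, tile_m: int) -> int:
    """Round seqlen_q (assumed >= 1) up to the next multiple of tile_m."""
    return -(-seqlen_q // tile_m) * tile_m  # ceiling division


def row_utilization(seqlen_q: int, tile_m: int) -> float:
    """Fraction of padded query rows that hold real queries."""
    return seqlen_q / padded_len(seqlen_q, tile_m)


# Compare a tensor-core-sized tile (16) with a 128-row block for one query.
for tile_m in (16, 128):
    print(tile_m, padded_len(1, tile_m), row_utilization(1, tile_m))
```

Even with 16-row padding, a single decoding query fills only 1/16 of the tile.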