Open YangQun1 opened 1 year ago
Q1: Indeed, for inference you want to split K/V and not Q. The current code still works as is (we implicitly pad Q to have length at least 128), but it could be much better. We have something planned for this to optimize for inference, hopefully will be out soon.
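A minimal sketch of the decode case under discussion, assuming the `flash_attn_func` interface from the `flash_attn` package; the batch/head/head-dim sizes are illustrative, not anything from this thread:

```python
import torch
from flash_attn import flash_attn_func

batch, nheads, headdim = 4, 32, 128
seqlen_q = 1        # a single new query token during decoding
seqlen_kv = 2048    # cached keys/values

q = torch.randn(batch, seqlen_q, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, seqlen_kv, nheads, headdim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, seqlen_kv, nheads, headdim, device="cuda", dtype=torch.float16)

# This call already works: per the answer above, the kernel implicitly pads Q up to
# its block size (at least 128 rows), so most of the Q tile is wasted work when
# seqlen_q == 1 -- hence the planned inference-specific optimization.
out = flash_attn_func(q, k, v, causal=True)  # out: (batch, 1, nheads, headdim)
```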
Q2: We benchmarked with Q and K/V having the same length (as is the case in, e.g., GPT-style model training).
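A rough timing sketch of that benchmark setting (seqlen_q == seqlen_k, as in GPT-style training); the sizes and iteration counts here are assumptions, not the paper's actual benchmark harness:

```python
import torch
from flash_attn import flash_attn_func

batch, nheads, headdim, seqlen = 8, 16, 64, 4096

def make_tensor():
    return torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)

q, k, v = make_tensor(), make_tensor(), make_tensor()

# Warm up, then time with CUDA events so asynchronous kernel launches are measured correctly.
for _ in range(10):
    flash_attn_func(q, k, v, causal=True)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    flash_attn_func(q, k, v, causal=True)
end.record()
torch.cuda.synchronize()
print(f"avg forward time: {start.elapsed_time(end) / 100:.3f} ms")
```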
@tridao I see, thank you. I look forward to your optimization for inference :)
Q1: ... we implicitly pad Q to have length at least 128 ...
I am wondering what the smallest supported sequence length is without padding. Could it be the tensor core's M size, e.g. 16 for an m16n16k16 MMA shape? My reasoning: if there is only one warp in a thread block, then along the M dimension it still needs at least the tensor core tile size in order to use tensor cores. See the small sketch below.
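A small sketch of the padding overhead implied by different minimum Q-tile sizes; the two tile sizes (16 = one m16n16k16 MMA row tile, 128 = the Q block mentioned above) come from this thread, while the round-up rule is just a guess at the behavior, not the kernel's actual logic:

```python
import math

def padded_rows(seqlen_q: int, tile_m: int) -> int:
    """Round seqlen_q up to a multiple of the hypothesized minimum Q tile."""
    return math.ceil(seqlen_q / tile_m) * tile_m

for tile_m in (16, 128):
    rows = padded_rows(1, tile_m)  # decoding: a single query token
    print(f"tile_m={tile_m:4d}: 1 query row is padded to {rows} rows ({rows - 1} wasted)")
```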
Q1: The paper mentions: "In FlashAttention-2, we instead split Q across 4 warps." But in LLM inference scenarios, num_queries is usually 1 except for the first forward pass. How is Q split across 4 warps in that case?
Q2: The paper benchmarks FlashAttention-2 and other attention implementations at different seqlen values. Does seqlen here refer to num_queries or num_keys?