Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

where is flash decoding second stage (reduce) code? #1248

Open · liuqi123123 opened this issue 2 months ago

liuqi123123 commented 2 months ago

According to https://pytorch.org/blog/flash-decoding/, flash decoding has two stages; the second stage is "reduce and rescale the contribution of each split". But I can't find a reduce kernel after the compute_attn_1rowblock_splitkv kernel. Where is it?

tridao commented 2 months ago

https://github.com/Dao-AILab/flash-attention/blob/53a4f341634fcbc96bb999a3c804c192ea14f2ea/csrc/flash_attn/src/flash_fwd_kernel.h#L1108
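
For context, a minimal sketch of what that combine step computes (a PyTorch reference, not the actual CUDA kernel at the linked line): each split produces a partial output and a partial log-sum-exp over its KV chunk, and the second stage rescales and sums the contributions.

```python
import torch

def combine_splits(out_partial, lse_partial):
    """Reference "reduce and rescale" step for split-KV attention (a sketch, not the CUDA kernel).

    out_partial: [num_splits, seqlen_q, nheads, headdim], partial output of each split
    lse_partial: [num_splits, seqlen_q, nheads], log-sum-exp of each split's scores
    """
    # Global log-sum-exp across all splits.
    lse = torch.logsumexp(lse_partial, dim=0)                 # [seqlen_q, nheads]
    # Each split normalized by its own denominator exp(lse_i), so rescale
    # its contribution by exp(lse_i - lse) and sum over splits.
    scale = torch.exp(lse_partial - lse).unsqueeze(-1)        # [num_splits, seqlen_q, nheads, 1]
    return (scale * out_partial).sum(dim=0)                   # [seqlen_q, nheads, headdim]
```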

SimpleTheoryOfTypes commented 3 weeks ago

What’s the best way to trigger the flash-decoding path when using flash_fwd_splitkv_kernel(...)? Is it correct to set num_splits = 0 and let the heuristics decide automatically?

For flash-decoding, is the num_splits_heuristic function the recommended way to determine the optimal number of splits? I tried hardcoding num_splits to 2, 4, and 8, but saw worse results on an A100 (batch size 8, 48 q heads, new seqlen between 1 and 10). Even though the heuristic picks num_splits == 1 as the best choice, it seems that combining multiple q heads and new-token queries in one GEMM is better at maximizing tensor-core utilization in my case? Thanks a lot in advance for your insights!

By the way, the flash-decoding release notes mention a minimal example, but the link still leads to a "coming soon" page: https://github.com/Dao-AILab/flash-attention/tree/main/examples/inference. Not sure if the link is still valid.
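
For reference, from the Python side the split-KV (flash-decoding) path is normally reached through flash_attn_with_kvcache, which exposes a num_splits argument (0 lets the heuristic decide, 1 disables splitting). A minimal usage sketch, assuming that interface; the exact signature may vary between releases:

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim = 8, 48, 128
cache_seqlen, new_tokens = 4096, 1

q = torch.randn(batch, new_tokens, nheads, headdim, device="cuda", dtype=torch.float16)
k_cache = torch.randn(batch, cache_seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.randn_like(k_cache)
cache_seqlens = torch.full((batch,), cache_seqlen, dtype=torch.int32, device="cuda")

# num_splits=0: defer to the built-in heuristic; pass an explicit value to force a split count.
out = flash_attn_with_kvcache(q, k_cache, v_cache,
                              cache_seqlens=cache_seqlens, num_splits=0)
```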

tridao commented 3 weeks ago

If num_splits = 0 we use a heuristic to decide if we should split (and how many splits). If batch = 8 and 48 q heads, there are 8 x 48 = 384 pieces of parallel work, more than the number of SMs on A100 (108). So there's no reason to split. Usually split is needed if there's not enough parallel work to assign to all the SMs.
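
A toy sketch of that reasoning (this is not the actual num_splits_heuristic in the C++ code, which also weighs per-split efficiency; it only illustrates the occupancy argument):

```python
def rough_num_splits(batch, nheads, num_sms=108, max_splits=128):
    """Toy occupancy-based split choice: split the KV sequence only when
    batch * nheads thread blocks cannot keep all SMs busy."""
    parallel_work = batch * nheads               # one thread block per (batch, head) pair
    if parallel_work >= num_sms:
        return 1                                 # SMs already covered; splitting only adds combine overhead
    # Otherwise pick roughly enough splits to cover all SMs.
    return min((num_sms + parallel_work - 1) // parallel_work, max_splits)

print(rough_num_splits(batch=8, nheads=48))      # 8 * 48 = 384 blocks >= 108 SMs -> 1 (no split)
print(rough_num_splits(batch=1, nheads=16))      # 16 blocks -> ~7 splits
```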

SimpleTheoryOfTypes commented 3 weeks ago

Is it feasible to vectorize the S=QK^T and SV GEMMs along the batch dimension in flash decoding? For example, during decoding, the query q has a shape of [b, 1, 48, 128], and a KV tile has a shape of [b, 64, 48, 128], where 48 represents the number of attention heads, 128 is the head dimension, and 64 is kBlockN, with only 1 token being decoded at a time.

In the current implementation, a separate flash::gemm is run for each batch element, resulting in GEMM shapes like:

q: 1 x 128 K: 64 x 128

Here, flash::gemm computes a kBlockM=64 x 128 (q) by kBlockN=64 x 128 (K) GEMM, but only the first row of the result is used, while the remaining 63 rows are discarded (partition utilization 1/64 ≈ 1.6%).

Would it be feasible to perform the following using a single flash::gemm for the entire batch in flash decoding, i.e.:

q: (b, 1) x 128 K: (b, 64) x 128

This approach could potentially improve tensor core partition utilization by a factor of b, as it would allow us to keep b rows of the output tensor instead of just 1.

If flash decoding can already do this batch-dim vectorization, how do I enable it? Thanks!

tridao commented 3 weeks ago

No, that doesn't increase tensor core util. The operation is mem bound anyway (you can measure that its speed is close to memcpy), so it doesn't matter that we're doing extra compute.
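
A rough back-of-the-envelope check of "mem bound" for one decode step (numbers are approximate and illustrative, not measured):

```python
# Approximate arithmetic intensity of one fp16 decode step (seqlen_q = 1).
batch, nheads, headdim, seqlen_k = 8, 48, 128, 4096

# Bytes: the full K and V cache is streamed from HBM, 2 bytes per element each.
bytes_moved = 2 * 2 * batch * seqlen_k * nheads * headdim

# FLOPs: Q@K^T and P@V, each ~2 * headdim FLOPs (headdim multiply-adds) per (query, key) pair.
flops = 2 * 2 * batch * seqlen_k * nheads * headdim

intensity = flops / bytes_moved          # ~1 FLOP per byte read
ridge = 312e12 / 2.0e12                  # A100: ~312 TFLOPS fp16 tensor cores / ~2 TB/s HBM
print(f"~{intensity:.0f} FLOP/B vs ridge point ~{ridge:.0f} FLOP/B -> heavily memory-bound")
```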

SimpleTheoryOfTypes commented 1 week ago

Given that, with small batch sizes, the attention kernel during decoding is memory-bound, why would maximizing SM utilization by creating more parallel work along the sequence dimension still lead to improved latency?

Flash decoding does help a lot with small batch sizes (<5) during decoding; I just wanted to verify my understanding. Flash decoding's main optimization appears to be improving compute-unit utilization, which seems at odds with the fact that the attention kernel during decoding is memory-bound. Is it because the scheduling is otherwise suboptimal?

tridao commented 1 week ago

Mem bound here means most of the time is spent waiting for memory to be loaded from global memory. You want more thread blocks issuing load instructions. If batch size = 1, seqlen_q = 1, nheads = 16, then you only have 16 thread blocks issuing loads (out of 108 or 132 SMs). So you're not saturating memory bandwidth. You want to parallelize along the seqlen_k dimension so that more thread blocks are issuing loads and you saturate memory bandwidth.
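
To make the block count concrete (a toy calculation based on the grid layout described above: one thread block per (batch, head) pair, multiplied by the number of KV splits):

```python
def blocks_issuing_loads(batch, nheads, num_splits=1):
    # Each (batch, head, split) combination gets its own thread block streaming a slice of K/V.
    return batch * nheads * num_splits

num_sms = 108                                     # A100
print(blocks_issuing_loads(1, 16))                # 16 blocks -> most SMs idle, HBM under-utilized
print(blocks_issuing_loads(1, 16, num_splits=8))  # 128 blocks -> enough to cover all 108 SMs
```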

SimpleTheoryOfTypes commented 1 week ago

Thanks a lot for the explanation! That makes sense: flash decoding also improves memory-bandwidth utilization by issuing more LD/ST instructions in parallel.