Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Potential algorithmic improvements for backward pass #1172

Open philipturner opened 2 months ago

philipturner commented 2 months ago

I removed the FP32 atomics, allowing the dQ accumulation to be done more naturally (a second backward kernel traveling in the same direction as forward). This looks easier to code, and was faster in production, despite having more computations (9 GEMMs in FWD + BWD, the order in which work is dispatched during model training). Hardware these days likes higher arithmetic intensity.

Also searched a giant combinatorial space of block dimensions, and whether each operand is cached or spilled. It's more complex than GEMM with only A/B/C matrices. ~10 different operands (Q, K, V, O, log-sum-exp, dO, dV, dK, dQ). I had to battle the compiler, which wanted to use up all the registers and make occupancy plummet. A lot of competing variables to consider. Therefore, I treated the register pressure issue as a combinatorial problem.


[Image: Screenshot 2024-08-23 at 6 38 05 PM]

Could you double-check my reasoning and perhaps fill in missing data for NV hardware? I want a second opinion, from someone with similar expertise regarding the FlashAttention algorithm.

https://github.com/philipturner/metal-flash-attention

tridao commented 2 months ago

This is great. Yes the backward pass is a massive pain. Having 2 separate kernels for dK, dV and then dQ also makes sense. That's how the Triton implementation does it. https://github.com/triton-lang/triton/blob/e0613c64b9bb872e4fa47dbce15ffa6770a1032b/python/tutorials/06-fused-attention.py#L309

The utilization mostly depends on the ratio of matmul FLOPS (e.g. tensor cores) and vector FLOPS (e.g. exponential for softmax). The higher this ratio, the harder it is to get high util. Idk much about Apple chips, but going from A100 -> H100, the matmul FLOPS double (per streaming multiprocessor per clock cycle), but FLOPS for exponential stays the same (per SM per clock cycle). As an example, A100 has 312 TFLOPS of FP16 matmul and 2.44 TFLOPS for exponential. The H100 has 989 TFLOPS of FP16 matmul but only 3.86 TFLOPS for exponential.
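To make the ratio concrete, here is a minimal Swift sketch that divides the matmul figure by the exponential figure for each chip. The `ChipThroughput` type is made up for illustration; the four TFLOPS values are the ones quoted above.

```swift
// Ratio of matmul throughput to exponential throughput, using the
// figures quoted in the comment above (all in TFLOPS).
// `ChipThroughput` is a hypothetical helper type, not part of any library.
struct ChipThroughput {
  var name: String
  var matmulTFLOPS: Double
  var exponentialTFLOPS: Double

  // How many matmul FLOPs the chip can retire per exponential FLOP.
  var matmulToExponentialRatio: Double {
    matmulTFLOPS / exponentialTFLOPS
  }
}

let chips = [
  ChipThroughput(name: "A100", matmulTFLOPS: 312, exponentialTFLOPS: 2.44),
  ChipThroughput(name: "H100", matmulTFLOPS: 989, exponentialTFLOPS: 3.86),
]

for chip in chips {
  print(chip.name, chip.matmulToExponentialRatio)
}
```

The ratio comes out at roughly 128 for the A100 and roughly 256 for the H100, so it about doubles between the two generations.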

Do you know the numbers for Apple chips (I couldn't find it from a quick look)?

philipturner commented 2 months ago

You can issue 32 complex instructions per core-cycle on Apple chips. In comparison, you can issue 128 FMA instructions per core-cycle on Apple chips. If you issue a mixed workload where 3 of every 4 instructions is FMA, then both pipelines can be saturated simultaneously. Here are numbers across numerous architectures, for consumer GPUs. All of them appear to have 128 ALUs in the core, for the main FFMA32/IADD32 pipeline.

The table below is the result of extensive research across numerous sources, from Chips and Cheese benchmarks to Intel's hardware documentation. It normalizes for how different vendors define a "GPU core". For example, AMD's dual compute unit (WGP) should be reported as two GPU cores. It is similar in memory I/O bus width to that of two CPU cores. What I consider a "GPU core" has the same memory I/O bus width as an Apple Firestorm CPU or Intel AVX-2 CPU.

https://github.com/philipturner/metal-benchmarks?tab=readme-ov-file#operations-per-second

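| Architecture | EXP2 | IMUL32 | IADD32 | FMUL32 | FFMA32 | FFMA16 (Tensor) |
| --- | --- | --- | --- | --- | --- | --- |
| Apple (M1, M3) | 32 | 32 | 128 | 128 | 128 | 128 |
| AMD (RDNA 3) | 32 | 16 | 64 | 128 | 128 | 256 |
| Intel (Gen9) | TBD | TBD | TBD | 64 | 64 | 128 |
| Nvidia (Ada) | 16 | 64 | 64 | 128 | 128 | 256 |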

The 16-bit FFMAs are obfuscated in hardware-specific SIMD group matrix instructions. But there are dedicated ALUs on certain chips, with double the throughput of FFMA32. These happen to be Nvidia's tensor cores on consumer chips. To make an accurate analysis, I'd have to map this model of instructions per GPU core-cycle to Nvidia server hardware.
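For comparison with the matmul-to-exponential ratio discussed earlier, the same ratio can be read off the table. This is a rough sketch: the `CoreThroughput` type is hypothetical, the numbers are copied from the table, and Intel (Gen9) is left out because its EXP2 entry is still TBD.

```swift
// Per-core-cycle throughput copied from the table above.
// `CoreThroughput` is a hypothetical helper type for this illustration.
struct CoreThroughput {
  var architecture: String
  var exp2: Int
  var ffma32: Int
  var ffma16: Int

  // FMA instructions that can issue per EXP2 instruction.
  var ffma32PerExp2: Int { ffma32 / exp2 }
  var ffma16PerExp2: Int { ffma16 / exp2 }
}

let rows = [
  CoreThroughput(architecture: "Apple (M1, M3)", exp2: 32, ffma32: 128, ffma16: 128),
  CoreThroughput(architecture: "AMD (RDNA 3)", exp2: 32, ffma32: 128, ffma16: 256),
  CoreThroughput(architecture: "Nvidia (Ada)", exp2: 16, ffma32: 128, ffma16: 256),
]

for row in rows {
  print(row.architecture, row.ffma32PerExp2, row.ffma16PerExp2)
}
```

With these inputs the FFMA32 and FFMA16 ratios are 4 and 4 for Apple, 4 and 8 for AMD, and 8 and 16 for Nvidia (Ada).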

philipturner commented 2 months ago

A few insights after thinking some more. The way you do dQ accumulation is very restrictive. It encourages square block sizes, to balance the bandwidth of paging/reading along two orthogonal matrix dimensions. In comparison, my approach makes no distinction between the dimensions. There is always just one where the paging overhead is amortized ("traversal"). That opens up a much larger combinatorial space for block sizes. With my approach, there is a better chance you will find a kernel with high performance. Try 32-64 x 160-320.

Second, each backward kernel has a higher ratio of FFMA instructions to complex instructions, than forward. For example, all kernels have 5 softmax instructions, one of which is exponentiation. There is no reason that bottleneck should prevent these kernels from having higher ALU utilization. Only the fact that we're issuing more compute work, and assuming your 7 GEMMs have 100% utilization (which they do not, by a long shot). Perhaps the instructions that read / page stuff to memory count as "vector instructions". They could be a new bottleneck.

tridao commented 2 months ago

That's an interesting perspective. Both approaches have tradeoffs (more recompute vs more memory access).

I did think about this approach (7 GEMMs) when we rewrote for Hopper. The 5 GEMMs approach ended up at around 550 TFLOPS (for hdim 128) for the backward pass alone. For the 7 GEMMs approach to get the same speed, the GEMMs would have to be at 770 TFLOPS (since it does 7/5 = 1.4x more compute). It would be hard to get to 770 TFLOPS (the forward with 2 GEMMs can get to around 700 TFLOPS, and pure matmul with 1 GEMM can get to 750-770 TFLOPS).
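The arithmetic behind this estimate, written out as a small Swift sketch. The function name `requiredTFLOPS` is made up for illustration; the 550 TFLOPS figure and the GEMM counts of 5 and 7 are the ones given above.

```swift
// Throughput the GEMMs must sustain for an alternative backward pass to
// match a baseline, when the alternative performs more GEMMs.
// `requiredTFLOPS` is a hypothetical helper, not part of either repository.
func requiredTFLOPS(
  baselineTFLOPS: Double,
  baselineGEMMs: Int,
  alternativeGEMMs: Int
) -> Double {
  baselineTFLOPS * Double(alternativeGEMMs) / Double(baselineGEMMs)
}

let target = requiredTFLOPS(
  baselineTFLOPS: 550, baselineGEMMs: 5, alternativeGEMMs: 7)
print(target)  // 770.0
```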

For a different chip, with different ratio of compute / memory, the tradeoff could look different.

philipturner commented 2 months ago

How does it compare with super large head dimensions? I don't even see benchmarks for D=256. This implies that training (gradient computation) cannot be evaluated with FlashAttention, because the accumulator consumes too many registers and the compiler spills them to RAM. People are shifting to models with larger head dimension nowadays.

tridao commented 2 months ago

The numbers for D=256 aren't there yet (as you said, registers are a big headache). Most models use hdim 128.

philipturner commented 2 months ago

There are some important models (FLUX I believe?) which have hdim 256. And on Apple hardware, with less registers, you have issues at hdim 128 that Nvidia would face at hdim 256.

At large head dimensions, you need to intentionally spill the accumulator to RAM. This shifts where the memory bottlenecks are in the kernel. It encourages highly non-square, lopsided blocks, such as 64x128 -> 32x256. The major issue with atomic dQ is that it forces you to use square block sizes. So perhaps it reaches 77% ALU utilization, and 7/5x more computations. 7 GEMMs may be slower or faster than the implementation currently here. But it does open up the flexibility to optimize for more variation in head dimension.

tridao commented 2 months ago

Btw I'm not sure doing dQ in a separate kernel will save registers. In the kernel with 4 GEMMs, we'd still need registers to hold the accumulator for all the 4 GEMMs. For the 5th gemm (dQ), the accumulator for dQ can use the same registers as those of QK^T and dO V^T accumulators.

philipturner commented 2 months ago

Does the FlashAttention repo, as is, have a block parameter you can just tweak so it reads 256 elements in memory to support D=256? Or is the head dimension hardcoded (e.g. if I wanted a model with D=127 and the elements are laid out in memory 127 elements apart, I could not do that).

tridao commented 2 months ago

Head dimensions are hard-coded. Just like gemm tile sizes are hard-coded (usually 128 x 128 or 128 x 256 or 256 x 128).

philipturner commented 2 months ago

> Btw I'm not sure doing dQ in a separate kernel will save registers. In the kernel with 4 GEMMs, we'd still need registers to hold the accumulator for all the 4 GEMMs. For the 5th gemm (dQ), the accumulator for dQ can use the same registers as those of QK^T and dO V^T accumulators.

With this approach (both the separation into two different kernels and lopsided block sizes aligned to the traversal dimension), you can reduce the register allocation to nil as the head dimension grows infinitely large.

tridao commented 2 months ago

For Hopper at least some of the block sizes need to be divisible by 64 (since the wgmma instruction needs M=64). So you can't really decrease M that much (we already use tile size 64 x 128 for hdim 128 bwd).

philipturner commented 2 months ago

> Head dimensions are hard-coded. Just like gemm tile sizes are hard-coded (usually 128 x 128 or 128 x 256 or 256 x 128)

I need a little more flexibility than that. From the start, I've targeted weird and typically not benchmarked problems. Such as D=40 in Stable Diffusion v1, which was a factor of 4 slower than D=32 in Metal Performance Shaders. The code generation handles edge cases where D or sequence dimension doesn't align with block size. For example D=199 where D_block=32.
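To illustrate the D=199, D_block=32 edge case, here is a minimal sketch of splitting a dimension into blocks with a narrower final block. The `blockRanges` function is hypothetical and only shows the bookkeeping; it is not how the code generator in metal-flash-attention is written.

```swift
// Split a dimension into fixed-size blocks, with a narrower final block
// when the dimension is not a multiple of the block size.
// `blockRanges` is a hypothetical helper for illustration only.
func blockRanges(dimension: Int, blockSize: Int) -> [Range<Int>] {
  precondition(dimension > 0 && blockSize > 0)
  return stride(from: 0, to: dimension, by: blockSize).map { start in
    start..<min(start + blockSize, dimension)
  }
}

let ranges = blockRanges(dimension: 199, blockSize: 32)
print(ranges.count)        // 7
print(ranges.last!.count)  // 7
```

Six full blocks cover elements 0 through 191, and the seventh block covers only the remaining 7 elements, which is the part a kernel has to guard against reading or writing out of bounds.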

DanFu09 commented 2 months ago

Just popping in quickly here, Flux has head dimension 128 :)

philipturner commented 2 months ago

> For Hopper at least some of the block sizes need to be divisible by 64 (since the wgmma instruction needs M=64). So you can't really decrease M that much (we already use tile size 64 x 128 for hdim 128 bwd).

On Apple, the SIMD group matrix instruction is scoped at M=8. So I can access block sizes such as 48x48 (the optimal one for GEMM, which is strangely not a power of 2). And 32x80 is the best for attention on M1 architecture. Also sometimes played around with the D block dimension being 24.
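A rough way to see how much the instruction granularity changes the search space: count the block edge lengths that are multiples of M. The function name and the 16 to 128 range are assumptions chosen for illustration, and on Hopper only some of the block dimensions carry the M=64 constraint.

```swift
// Candidate block edge lengths between `lower` and `upper` that are
// multiples of the matrix instruction's M dimension.
// `candidateBlockEdges` and the 16...128 range are hypothetical.
func candidateBlockEdges(instructionM: Int, lower: Int, upper: Int) -> [Int] {
  (lower...upper).filter { $0 % instructionM == 0 }
}

let apple = candidateBlockEdges(instructionM: 8, lower: 16, upper: 128)
let hopper = candidateBlockEdges(instructionM: 64, lower: 16, upper: 128)
print(apple.count, hopper.count)  // 15 2
```

Edge lengths such as 24, 48, and 80 appear in the M=8 list but not in the M=64 list.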

philipturner commented 2 months ago

> Just popping in quickly here, Flux has head dimension 128 :)

Must be another model. But it's important to have code that can work with large head dimensions. @liuliu perhaps can say which model he's targeting with D=256.

philipturner commented 2 months ago

Would this test even compile with the FlashAttention repo? Forget executing fast, would I even get correct results running this test suite on a CUDA device?

https://github.com/philipturner/metal-flash-attention/blob/749e742ac27d01a63f88b9e25a566e70dfac5632/Tests/FlashAttentionTests/Attention/SquareAttentionTest.swift#L5-L26

tridao commented 2 months ago

It would run and get correct results. For hdim not divisible by 8 we explicitly pad (which would make it slower). This is mostly so that we can use TMA (which requires alignment by 16 bytes). If one really wants, one can forgo TMA and use standard copying, but that's more code to write, so we don't do that.

For hdim not multiples of 32 or 64, we implicitly pad. There are kernels for hdim 64, 96, 128, 192, 256. Any hdim not in those 5 will get rounded.
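A small sketch of the two padding rules as described here. Both function names are hypothetical, and the sketch assumes "rounded" means rounded up to the next head dimension that has a kernel; it does not reproduce the dispatch logic in the FlashAttention source.

```swift
// Head dimensions with dedicated kernels, as listed in the comment above.
let kernelHeadDimensions = [64, 96, 128, 192, 256]

// Smallest kernel head dimension that can hold `headDimension`, or nil if
// none is large enough. Assumes "rounded" means rounded up.
// Hypothetical helper, not the library's own dispatch code.
func kernelHeadDimension(for headDimension: Int) -> Int? {
  kernelHeadDimensions.first(where: { $0 >= headDimension })
}

// Head dimension after explicit padding to a multiple of 8.
func explicitlyPadded(_ headDimension: Int) -> Int {
  (headDimension + 7) / 8 * 8
}

for d in [40, 127, 199] {
  print(d, explicitlyPadded(d), kernelHeadDimension(for: d) ?? -1)
}
```

Under these assumptions, D=40 needs no explicit padding and lands on the hdim 64 kernel, D=127 is padded to 128 and uses the hdim 128 kernel, and D=199 is padded to 200 and uses the hdim 256 kernel.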

philipturner commented 2 months ago

Apple does implicit padding for a base alignment of 2 bytes. The SIMD async copy instruction can pad matrix dimensions up to multiples of 8. From your explanation, the TMA hardware would not be compatible with D=127? The client would have to make a second memory allocation, with row-stride D=128, with the extra element zero-padded. Then, enter that allocation into FlashAttention. Seems like twice as much memory consumption as is necessary.

> There are kernels for hdim 64, 96, 128, 192, 256.

Does this mean it’s technically possible to get benchmarks for A100/H100 training performance with D=256? Specifically, these missing datapoints.


[Image: Screenshot 2024-08-24 at 10 08 41 AM]

tridao commented 2 months ago

TMA requires addresses to be multiples of 16 bytes, so if you have D=127 TMA won't work. This is standard in GEMM as well, where e.g. Nvidia recommends that dimensions be multiples of 16 bytes (https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc). There's limited time for open-source work, and we've focused more on common dimensions that most models use.
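The alignment rule can be checked with a few lines of arithmetic. This sketch assumes 2-byte elements (FP16 or BF16) and a densely packed row, and both function names are made up for illustration.

```swift
// Whether a row of `headDimension` elements, each `bytesPerElement` wide,
// spans a multiple of `alignment` bytes. Defaults assume FP16/BF16 elements
// and the 16-byte requirement mentioned above. Hypothetical helper.
func rowIsAligned(
  headDimension: Int, bytesPerElement: Int = 2, alignment: Int = 16
) -> Bool {
  (headDimension * bytesPerElement) % alignment == 0
}

// Bytes per row if the row stride is padded up to the alignment.
func paddedRowBytes(
  headDimension: Int, bytesPerElement: Int = 2, alignment: Int = 16
) -> Int {
  let rowBytes = headDimension * bytesPerElement
  return (rowBytes + alignment - 1) / alignment * alignment
}

print(rowIsAligned(headDimension: 127))    // false
print(rowIsAligned(headDimension: 128))    // true
print(paddedRowBytes(headDimension: 127))  // 256
```

With 2-byte elements, 16 bytes corresponds to 8 elements, so a D=127 row (254 bytes) has to be padded to 256 bytes before every row start is 16-byte aligned.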

Yes, FA2 hdim 256 is supported. FA3 has fwd for hdim 256 but bwd for hdim 256 is not out yet.

Ultimately the choice of 5 GEMM or 7 GEMM bwd is empirical: the right thing is to implement both and pick what's faster. For hdim128 on H100 I'm convinced that 5 GEMM is faster (by the reasoning above where the 7 GEMM approach would have to get to 770 TFLOPS, which is already not easy for 1 GEMM). For different hdim, different chip, I could see 7 GEMM being faster.

Note that the 7 GEMM approach has the advantage that it gives deterministic results. That might be one reason the Triton implementation went with it.

philipturner commented 2 months ago

This reasoning makes sense. There are not a lot of models out there with hdim 256. On Apple hardware, I had less register memory, and FP32 atomics were not supported in hardware. The 5 GEMM approach would be extremely complex to implement, and unnecessarily constrained the block sizes to be squares. I would love to see benchmarks of A100 / H100 doing FlashAttention with hdim 256, just to see how well my implementation compares to yours.

I designed my codebase to use code generation and benchmarking across the entire spectrum of problem sizes. It requires little effort for me to test hdim 256, 257, 384, 385, etc. However, it seems the main repo hard-codes the head dimensions, and changing hdim requires a lot of effort. This is a common practice, for example PyTorch FlexAttention only supports matrix dimensions divisible by 128. It is much easier to write kernels restricted to integer powers of 2 (they can be handwritten). The odd / misaligned matrices are more challenging, that's why I needed code generation.

I'll have to leave this investigation unanswered for the time being, move on to other things. Thanks for the insights though.

liuliu commented 2 months ago

> There are some important models (FLUX I believe?) which have hdim 256. And on Apple hardware, with less registers, you have issues at hdim 128 that Nvidia would face at hdim 256.

Just to reference this, I meant the AuraFlow model which uses hdim 256 (and trained on H100). FLUX.1 uses hdim 128.

philipturner commented 2 months ago

I changed the numbers to say 0% for the missing data points, because D=256 gradient cannot even be evaluated in the main repository. I wonder how the AuraFlow model was trained.

tridao commented 2 months ago

No, FA2 has bwd for D=256.

https://github.com/Dao-AILab/flash-attention/blob/32792d37ec66902e5d82e149971daacbee8b55d7/csrc/flash_attn/src/flash_bwd_launch_template.h#L301

philipturner commented 2 months ago

Would you be able to execute backward for D=256 on H100 or A100? I don't have access to one of those GPUs. If it's only for A100, I'm okay with that data point.

tridao commented 2 months ago

Yes it runs on both A100 and H100 (not using new H100 features). Util on A100 is around the same as hdim64. Util on H100 is not great (obv), like 250-270 TFLOPS.

philipturner commented 2 months ago

Would you mind providing the specific and quantitative benchmark data? As in exact GFLOPS at N=16384, up to four significant figures.

tridao commented 2 months ago

Oh these are just numbers off the top of my head. Would need to find an A100 :D I mostly work w H100.

philipturner commented 2 months ago

I’ll take whatever I can get. I want to examine how Flash2 scales with increasing head dimension on the H100. That’s sufficient to get the conclusion I’m after.

philipturner commented 2 months ago

This thing, right here, but on the H100 + Flash2 instead of the M1 Max + MFA. This code issues 5 attentions sequentially, just to gather a single data point. It then does that procedure 5 times, for a total of 25 attentions computed. Perhaps you'll need to batch a bunch of attentions in parallel to fully saturate the parallelism of H100. I think that's what the bh or bq parameters (head count, batch count) do, except interleaving different elements of the batch in memory, confusing things.

  // Benchmark performance.
  var maxGINSTRS: Int = .zero
  for _ in 0..<5 {
    let dispatchCount: Int = 5
    let latencySeconds = executeCommandBuffer(dispatchCount: dispatchCount)

    // Determine the amount of work done.
    //
    // WARNING: Change this code to match the kernel you're profiling.
    var operations: Int
    switch benchmarkedKernel {
    case .forward:
      operations = 2 * headDimension + 5
    case .backwardQuery:
      operations = 3 * headDimension + 5
    case .backwardKeyValue:
      operations = 4 * headDimension + 5
    }
    operations *= (sequenceDimension * sequenceDimension)
    operations *= dispatchCount

    // Divide the work by the latency, resulting in throughput.
    let instrs = Double(operations) / Double(latencySeconds)
    let ginstrs = Int(instrs / 1e9)

    // Accumulate the sample from this trial.
    maxGINSTRS = max(maxGINSTRS, ginstrs)
  }
  return maxGINSTRS

Then, automate a test over 38 points in the spectrum of D dimensions.

    //    var D_array: [Int] = []
    //
    //    var D_cursor = 0
    //    while D_cursor < 96 {
    //      D_cursor += 4
    //      D_array.append(D_cursor)
    //    }
    //    while D_cursor < 160 {
    //      D_cursor += 8
    //      D_array.append(D_cursor)
    //    }
    //    while D_cursor < 256 {
    //      D_cursor += 16
    //      D_array.append(D_cursor)
    //    }
    //    while D_cursor < 384 {
    //      D_cursor += 32
    //      D_array.append(D_cursor)
    //    }
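For anyone reproducing the sweep, here is a runnable version of the commented-out loops, folded into one function. The function name and the `limit` parameter are additions for illustration. As written, the four loops produce 42 values; stopping at D=256 produces 38.

```swift
// Head dimensions to benchmark: dense sampling at small D, coarser at
// large D. `headDimensionSweep` and `limit` are hypothetical additions.
func headDimensionSweep(limit: Int = 384) -> [Int] {
  // (upper bound, step) pairs, matching the commented-out loops above.
  let segments = [(96, 4), (160, 8), (256, 16), (384, 32)]
  var result: [Int] = []
  var cursor = 0
  for (upperBound, step) in segments {
    while cursor < min(upperBound, limit) {
      cursor += step
      result.append(cursor)
    }
  }
  return result
}

print(headDimensionSweep().count)            // 42
print(headDimensionSweep(limit: 256).count)  // 38
```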

And finally, a beautiful graph like this. Modified to suit the differences between the MFA and Flash2 repos.

[Image: Screenshot 2024-08-24 at 9 09 21 PM]