apple / axlearn

An Extensible Deep Learning Library
Apache License 2.0

Optimize TPU Flash Attention (400x speed-up on 32k long context) #845

Status: Open. Opened by ds-hwang 1 week ago.

Use the splash attention lazy mask instead of a materialized jnp mask.

Materializing the mask as a jnp array costs O(T^2) memory for sequence length T, which almost negates flash attention's main benefit of reducing HBM traffic. The splash attention lazy mask instead generates the causal mask on demand, so the full T^2 mask is never materialized.
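To make the O(T^2) point concrete, here is a minimal pure-Python sketch contrasting a materialized causal mask with a lazy one. It is not the splash attention implementation: `materialized_causal_mask`, `LazyCausalMask`, and `block_kind` are hypothetical names. The tile classification is an assumption about how a block-based kernel can use a lazy causal mask to skip fully masked tiles and avoid per-element masking on fully visible ones; it illustrates why compute can drop as well as memory.

```python
def materialized_causal_mask(q_len, kv_len):
    # Stores every (i, j) entry: q_len * kv_len booleans, i.e. O(T^2) for T = q_len = kv_len.
    return [[j <= i for j in range(kv_len)] for i in range(q_len)]


class LazyCausalMask:
    """Hypothetical lazy causal mask: keeps only the shape, computes entries on demand."""

    def __init__(self, q_len, kv_len):
        self.shape = (q_len, kv_len)

    def __getitem__(self, idx):
        i, j = idx  # query position i may attend to key position j only if j <= i
        return j <= i

    def block_kind(self, q_block, kv_block, block_size):
        """Classify a block_size x block_size tile (assumes square, 0-aligned tiles)."""
        q_first = q_block * block_size
        q_last = q_first + block_size - 1
        kv_first = kv_block * block_size
        kv_last = kv_first + block_size - 1
        if kv_first > q_last:
            return "skip"      # every key is after every query: tile is fully masked
        if kv_last <= q_first:
            return "full"      # every key is at or before every query: no masking needed
        return "partial"       # tile straddles the diagonal: mask element-wise


T, B = 8, 2
dense = materialized_causal_mask(T, T)   # 64 stored booleans
lazy = LazyCausalMask(T, T)              # stores only the shape
assert all(dense[i][j] == lazy[i, j] for i in range(T) for j in range(T))

kinds = [lazy.block_kind(qb, kb, B) for qb in range(T // B) for kb in range(T // B)]
print(kinds.count("skip"), kinds.count("partial"), kinds.count("full"))  # prints: 6 4 6
```

With n = T / B tiles per side, about half of the n^2 tiles are skipped outright, which is why a causal kernel can save computation and not only memory.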

In addition, Pallas can run kernels on CPU in interpret mode (interpret=True), so the same Pallas kernel is now used on CPU as well, which makes the code easier to debug.
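A minimal sketch of Pallas interpret mode, using a simple elementwise-add kernel in the style of the Pallas quickstart rather than the splash attention kernel. The backend check that selects `interpret`, and the names `add_kernel` and `add`, are illustrative assumptions, not necessarily how this PR switches to CPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # Kernel body: read both input blocks and write their elementwise sum.
    o_ref[...] = x_ref[...] + y_ref[...]


def add(x, y):
    # Hypothetical switch: on a CPU-only host, run the same kernel source in
    # interpret mode so it can be stepped through and debugged.
    interpret = jax.default_backend() == "cpu"
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=interpret,
    )(x, y)


x = jnp.arange(8, dtype=jnp.float32)
print(add(x, x))
```

Interpret mode is much slower than a compiled kernel, so it is a debugging aid rather than a substitute for running on TPU.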

NumpyMask (current implementation, before this PR):

----------------------------------------------------------------------------------------
Benchmark                                              Time             CPU   Iterations
----------------------------------------------------------------------------------------
FlashAttentionBenchmark/256/2/2/512           1.71 ms         1.09 ms          592   (4.43M)
FlashAttentionBenchmark/2048/16/2/1024        4.44 ms         1.21 ms          483  (28.62M)
FlashAttentionBenchmark/4096/16/2/1024        8.61 ms         1.36 ms          302  (53.27M)
FlashAttentionBenchmark/4096/16/2/4096        3264 ms         1537 ms            1 (197.38M)
FlashAttentionBenchmark/4096/16/2/8192        7426 ms         5603 ms            1 (389.54M)
FlashAttentionBenchmark/4096/16/2/32768      94427 ms        92256 ms            1   (1.50G)

CausalMask (proposed in this PR): saves both memory and computation. For long contexts the speed-up reaches 400x (94427 ms to 230 ms on FlashAttentionBenchmark/4096/16/2/32768), and HBM usage drops by about 3x.

----------------------------------------------------------------------------------------
Benchmark                                              Time             CPU   Iterations
----------------------------------------------------------------------------------------
FlashAttentionBenchmark/256/2/2/512           1.55 ms         1.01 ms          578   (3.43M)
FlashAttentionBenchmark/2048/16/2/1024        4.21 ms         1.11 ms          490  (13.57M)
FlashAttentionBenchmark/4096/16/2/1024        6.50 ms         1.17 ms          493  (24.22M)
FlashAttentionBenchmark/4096/16/2/4096        16.8 ms         1.38 ms          228  (84.33M)
FlashAttentionBenchmark/4096/16/2/8192        28.8 ms         1.58 ms          217 (164.50M)
FlashAttentionBenchmark/4096/16/2/32768        230 ms         6.36 ms           16 (644.60M)
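As a quick check of the headline figure, the speed-ups can be recomputed from the Time column of the 4096/16/2/* rows in the two tables. The sketch assumes, as the "32k long context" in the title suggests, that the last benchmark parameter is the sequence length; the dictionary names are illustrative.

```python
# Wall-clock "Time" column (ms) from the two tables, keyed by the last
# benchmark parameter of the 4096/16/2/* rows (assumed to be sequence length).
numpy_mask_ms = {1024: 8.61, 4096: 3264.0, 8192: 7426.0, 32768: 94427.0}
causal_mask_ms = {1024: 6.50, 4096: 16.8, 8192: 28.8, 32768: 230.0}

# Speed-up of CausalMask over NumpyMask at each sequence length.
speedups = {n: numpy_mask_ms[n] / causal_mask_ms[n] for n in numpy_mask_ms}
for n, s in sorted(speedups.items()):
    print(f"{n:>6}: {s:7.1f}x")
```

The 32768 row comes out at roughly 410x, consistent with the 400x in the title, while the gap at 1024 is only about 1.3x: the materialized mask is cheap at short lengths and dominates only as T grows.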