Optimize TPU Flash Attention (400x speed-up on 32k long context)
Use the splash attention lazy mask instead of a jnp mask, which is O(T^2).
A jnp mask takes O(T^2) memory, which largely negates the benefit of
reducing HBM traffic with flash attention. The splash attention lazy mask
instead generates causal mask blocks on the fly inside the kernel.
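For illustration, a minimal sketch of the two mask constructions, assuming the
jax.experimental.pallas.ops.tpu.splash_attention mask API (the sequence length
and comments are illustrative, not the PR's actual wiring):

```python
import numpy as np
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib

seq_len = 32 * 1024  # 32k long context

# ASIS: materialize a dense T x T boolean array and wrap it in NumpyMask.
# The array alone is O(T^2): 32768 * 32768 bytes ~= 1 GiB, which largely
# cancels the HBM savings that flash attention provides.
numpy_mask = mask_lib.NumpyMask(np.tril(np.ones((seq_len, seq_len), dtype=np.bool_)))

# Proposed: CausalMask stores only its shape (and offset) and generates mask
# blocks lazily inside the kernel, so fully-masked blocks can be skipped and
# no O(T^2) array is ever materialized.
causal_mask = mask_lib.CausalMask(shape=(seq_len, seq_len))
```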
In addition, Pallas supports CPU simulation (interpret=True), so we run the
same Pallas kernel on CPU, which makes the code easier to debug.
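A generic toy Pallas kernel showing the interpret mode (illustrative only, not
the attention kernel itself):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
    # The same kernel body runs compiled on TPU or in the Pallas interpreter.
    o_ref[...] = x_ref[...] + 1.0

def add_one(x, *, interpret=False):
    # interpret=True executes the kernel in the Pallas interpreter, so the
    # exact same kernel code can be unit-tested on a CPU-only host.
    return pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=interpret,
    )(x)

x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
out = add_one(x, interpret=(jax.default_backend() == "cpu"))
```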
Benchmark: on TPU v5p; the benchmark name encodes model_dim/heads/kv_heads/seq_len.
NumpyMask (as-is, before this PR)
----------------------------------------------------------------------------------------
Benchmark Time CPU Iterations
----------------------------------------------------------------------------------------
FlashAttentionBenchmark/256/2/2/512 1.71 ms 1.09 ms 592 (4.43M)
FlashAttentionBenchmark/2048/16/2/1024 4.44 ms 1.21 ms 483 (28.62M)
FlashAttentionBenchmark/4096/16/2/1024 8.61 ms 1.36 ms 302 (53.27M)
FlashAttentionBenchmark/4096/16/2/4096 3264 ms 1537 ms 1 (197.38M)
FlashAttentionBenchmark/4096/16/2/8192 7426 ms 5603 ms 1 (389.54M)
FlashAttentionBenchmark/4096/16/2/32768 94427 ms 92256 ms 1 (1.50G)
CausalMask (proposed in this PR): This PR saves both memory and computation. At
32k context, it delivers a 400x speed-up and a 3x HBM saving.
----------------------------------------------------------------------------------------
Benchmark Time CPU Iterations
----------------------------------------------------------------------------------------
FlashAttentionBenchmark/256/2/2/512 1.55 ms 1.01 ms 578 (3.43M)
FlashAttentionBenchmark/2048/16/2/1024 4.21 ms 1.11 ms 490 (13.57M)
FlashAttentionBenchmark/4096/16/2/1024 6.50 ms 1.17 ms 493 (24.22M)
FlashAttentionBenchmark/4096/16/2/4096 16.8 ms 1.38 ms 228 (84.33M)
FlashAttentionBenchmark/4096/16/2/8192 28.8 ms 1.58 ms 217 (164.50M)
FlashAttentionBenchmark/4096/16/2/32768 230 ms 6.36 ms 16 (644.60M)
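For reference, a minimal end-to-end sketch of plugging the lazy CausalMask into
the splash attention kernel, assuming the jax.experimental splash attention API
(shapes, dtypes, and the unsharded single-device setup are illustrative; GQA
head handling is omitted):

```python
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel as splash
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask as mask_lib

batch, num_heads, seq_len, head_dim = 1, 16, 4096, 128

# One lazy causal mask per head; nothing O(T^2) is materialized.
mha_mask = mask_lib.MultiHeadMask(
    masks=[mask_lib.CausalMask(shape=(seq_len, seq_len)) for _ in range(num_heads)]
)

# Build the splash attention kernel (single device, unsharded for brevity;
# block sizes can be tuned via splash.BlockSizes).
kernel = splash.make_splash_mha(mask=mha_mask, head_shards=1, q_seq_shards=1)

q = jnp.zeros((batch, num_heads, seq_len, head_dim), jnp.bfloat16)
k = jnp.zeros((batch, num_heads, seq_len, head_dim), jnp.bfloat16)
v = jnp.zeros((batch, num_heads, seq_len, head_dim), jnp.bfloat16)

# The kernel acts on one example ([heads, seq, head_dim]); vmap over batch.
out = jax.vmap(kernel)(q, k, v)
```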