jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[Pallas][Flash Attention] Cost analysis in Flash Attention kernel causing compilation OOM #24539

Open lsy323 opened 1 week ago

lsy323 commented 1 week ago

Description

With the 0.4.35 release, the flash attention kernel hits a compilation OOM for long-sequence-length inputs. It fails while compiling the reference attention implementation used for cost analysis, which was added in this commit: https://github.com/jax-ml/jax/commit/4c218fbf3b8431a5f75cdf20942d5d62433a8657.

This logic causes an OOM whenever device HBM is too small for the naive attention implementation, even though the flash attention kernel itself fits.
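For context, here is a sketch of the failing pattern (illustrative only; the actual _fwd_cost_estimate in flash_attention.py differs in detail). The cost estimate is obtained by lowering and compiling an unblocked reference attention, and that compilation has to budget HBM for the full attention intermediates, which is where the RESOURCE_EXHAUSTED error below comes from:

import jax
import jax.numpy as jnp

def reference_attention(q, k, v):
    # Unblocked attention: the [batch, heads, q_len, kv_len] logits/weights
    # plus fp32 intermediates scale quadratically with sequence length.
    logits = jnp.einsum("bhqd,bhkd->bhqk", q, k).astype(jnp.float32)
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bhkd->bhqd", weights.astype(q.dtype), v)

def cost_estimate_via_compile(q, k, v):
    # Compiling forces XLA to plan HBM for the reference program, so the
    # compile itself can fail even if the Pallas kernel would fit.
    compiled = jax.jit(reference_attention).lower(q, k, v).compile()
    return compiled.cost_analysis()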

Here is a small repro script for this issue (on a v5e TPU with 16GB of HBM):

import jax
import jax.numpy as jnp

from jax.experimental.pallas.ops.tpu import flash_attention

kv_seq_len = 16 * 1024
batch_size = 18
n_heads = 32
head_dim = 256

q_seq_len = kv_seq_len
kv_shape = (batch_size, n_heads, kv_seq_len, head_dim)
q_shape = (batch_size, n_heads, q_seq_len, head_dim)
dtype = jnp.bfloat16

q = jnp.ones(q_shape, dtype=dtype)
k = jnp.ones(kv_shape, dtype=dtype)
v = jnp.ones(kv_shape, dtype=dtype)

@jax.jit
def run_flash_attention(q, k, v):
    return flash_attention.flash_attention(q, k, v)

compiled = jax.jit(run_flash_attention).lower(q, k, v).compile()

The error message is:

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/lsiyuan/jax-flashattn-oom/repro.py", line 24, in <module>
    compiled = jax.jit(run_flash_attention).lower(q, k, v).compile()
  File "/home/lsiyuan/jax-flashattn-oom/repro.py", line 22, in run_flash_attention
    return flash_attention.flash_attention(q, k, v)
  File "/home/lsiyuan/miniconda3/envs/torch310/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py", line 199, in flash_attention
    return _flash_attention(
  File "/home/lsiyuan/miniconda3/envs/torch310/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py", line 217, in _flash_attention
    return _flash_attention_impl(
  File "/home/lsiyuan/miniconda3/envs/torch310/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py", line 791, in _flash_attention_impl
    cost_estimate=_fwd_cost_estimate(
  File "/home/lsiyuan/miniconda3/envs/torch310/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/flash_attention.py", line 589, in _fwd_cost_estimate
    .compile()
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 22.50G of 15.75G hbm. Exceeded hbm capacity by 6.75G.

Total hbm usage >= 22.75G:
    reserved        258.00M 
    program           9.00G 
    arguments        13.50G 

Output size 4.50G; shares 0B with arguments.

Program hbm requirement 9.00G:
    HLO temp          9.00G (100.0% utilization: Unpadded (9.00G) Padded (9.00G), 0.0% fragmentation (0B))

  Largest program allocations in hbm:

  1. Size: 9.00G
     Shape: f32[18,32,16384,256]{3,2,1,0:T(8,128)}
     Unpadded size: 9.00G
     XLA label: name = custom-call(Arg_0.1, Arg_1.2, Arg_2.3), custom_call_target="tpu_custom_call", operand_layout_constraints={bf16[18,32,16384,256]{3,2,1,0}, bf16[18,32,16384,256]{3,2,1,0}, bf16[18,32,16384,256]{3,2,1,0}}
     Allocation type: HLO temp
     ==========================

It fails with 0.4.35 but passes with the 20240829 nightly.

cc @WoosukKwon, who surfaced and root-caused the issue in the first place.

System info (python version, jaxlib version, accelerator, etc.)

Version with the compilation OOM issue

jax:    0.4.35
jaxlib: 0.4.35
numpy:  1.26.4
python: 3.10.14 (main, May  6 2024, 19:42:50) [GCC 11.2.0]
device info: TPU v5 lite-8, 8 local devices
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-81baa7fa-w-0', release='5.19.0-1022-gcp', version='#24~22.04.1-Ubuntu SMP Sun Apr 23 09:51:08 UTC 2023', machine='x86_64')

Version without the compilation OOM issue

jax:    0.4.32.dev20240829
jaxlib: 0.4.32.dev20240829
numpy:  1.26.4
python: 3.10.14
device info: TPU v5 lite-8, 8 local devices
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-81baa7fa-w-0', release='5.19.0-1022-gcp', version='#24~22.04.1-Ubuntu SMP Sun Apr 23 09:51:08 UTC 2023', machine='x86_64')

justinjfu commented 1 week ago

https://github.com/jax-ml/jax/pull/24608 should fix this by calculating the cost estimate directly instead of by compiling a reference function.

In the meantime, you can also temporarily remove the cost estimate; it won't affect the correctness of the code.
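For anyone curious what "calculating the cost estimate directly" can look like, here is a rough sketch using analytical FLOP/byte counts and pallas.CostEstimate. The function name and the exact accounting are assumptions for illustration, not the code in #24608:

import numpy as np
import jax.numpy as jnp
from jax.experimental import pallas as pl

def fwd_cost_estimate(q_shape, kv_shape, dtype=jnp.bfloat16):
    # Hypothetical helper: estimate forward-pass attention cost analytically,
    # so no reference program ever has to be compiled.
    b, h, q_len, d = q_shape
    kv_len = kv_shape[2]
    # Two batched matmuls dominate: Q @ K^T and P @ V (2 FLOPs per MAC).
    flops = 2 * 2 * b * h * q_len * kv_len * d
    # Roughly one exp per attention logit for the softmax.
    transcendentals = b * h * q_len * kv_len
    # HBM traffic: read Q, K, V and write the output in the input dtype.
    itemsize = np.dtype(dtype).itemsize
    bytes_accessed = itemsize * (
        b * h * q_len * d         # Q
        + 2 * b * h * kv_len * d  # K and V
        + b * h * q_len * d       # output
    )
    return pl.CostEstimate(
        flops=flops,
        transcendentals=transcendentals,
        bytes_accessed=bytes_accessed,
    )

An estimate like this would then be supplied where the traceback above shows the cost_estimate= keyword being built, with no reference compile involved.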