google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Improper `ComputeCapability` check for `cudnn_dot_product_attention`. #22546

Closed MasterSkepticista closed 1 month ago

MasterSkepticista commented 1 month ago

Description

We check for CC here: https://github.com/google/jax/blob/9632a2d1a86496cb1bca7bacdacef3bf554b5153/jax/_src/cudnn/fused_attention_stablehlo.py#L990

But the check (L316) fails if compute_cap is an integer strictly between 80 and 90, e.g. 86: https://github.com/google/jax/blob/9632a2d1a86496cb1bca7bacdacef3bf554b5153/jax/_src/cudnn/fused_attention_stablehlo.py#L315-L317

The intention is to allow all GPUs with compute capability within the range.

if compute_cap not in cc:  # 86 is not a member of (80, 90), so this raises. It shouldn't.
  raise RuntimeError(...)

I disabled the CC check on my platform (compute capability 8.6), and cuDNN gives the expected speedup. A correct check could be:

assert len(cc) == 2, "Provide a (low, high) range"
lo, hi = cc
if compute_cap not in range(lo, hi + 1):
  raise RuntimeError(...)
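
For illustration, a self-contained sketch of why the current membership test rejects 86 while a range-based test accepts it (the values below are just examples, not the actual arguments used in the handler):

compute_cap = 86   # e.g. an RTX A6000 (Ampere), used here only as an example
cc = (80, 90)

# Current check: tuple membership, so only exactly 80 or 90 pass.
print(compute_cap in cc)                 # False -> the check raises

# Range-based check: every capability in [80, 90] passes.
lo, hi = cc
print(compute_cap in range(lo, hi + 1))  # True -> the check passes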

Happy to do a PR.

System info (python version, jaxlib version, accelerator, etc.)

System information

jax:    0.4.31.dev20240720
jaxlib: 0.4.30
numpy:  1.26.4
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='sagittarius', release='6.5.0-35-generic', version='#35~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue May  7 09:00:52 UTC 2', machine='x86_64')

<truncated>
GPU: RTX-A6000 (Ampere)
Compute Capability: 8.6
Driver: 555.42.06
CUDA: 12.5
cuDNN: 9.2
monatis commented 1 month ago

Hi @MasterSkepticista, I find this quite interesting. To my knowledge, cuDNN's FMHA implementation is only supported on CC 8.0 and 9.0 (I confirmed this with a developer from NVIDIA). Was support for CC 8.6 added in a new version? How much speedup did you gain with this, and how did you test it?

MasterSkepticista commented 1 month ago

Flash attention is supported on all Ampere and Hopper GPUs, I think. The PyTorch version also runs faster on CC 8.6 (it does not fail their internal check).

Benchmark:

import functools
import jax
import jax.numpy as jnp

b, t, h, d = 1, 1024, 12, 64
key = jax.random.key(42)
q = k = v = jax.random.uniform(key, (b, t, h, d), jnp.bfloat16)
xla_attn = jax.jit(functools.partial(jax.nn.dot_product_attention, is_causal=True, implementation="xla"))
flash_attn = jax.jit(functools.partial(jax.nn.dot_product_attention, is_causal=True, implementation="cudnn"))

# Warmup.
out_xla = xla_attn(q, k, v)
out_flash = flash_attn(q, k, v)
assert jnp.allclose(out_xla, out_flash, atol=1e-2)
print(jnp.abs(out_xla - out_flash).max())
# 0.00390625 (bfloat16 epsilon?)

%timeit -n 100 xla_attn(q, k, v).block_until_ready()
# 422 μs ± 5.16 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit -n 100 flash_attn(q, k, v).block_until_ready()
# 101 μs ± 4.56 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
hawkinsp commented 1 month ago

From @kaixih, my understanding is that the check is correct. 8.6 is not supported.

MasterSkepticista commented 1 month ago

Hi @hawkinsp, I believe sm86 and sm89 are supported under certain size constraints (for ref: https://github.com/Dao-AILab/flash-attention/issues/138#issuecomment-1647004084, sdp_utils.cpp#L268). Flash attention on sm86 works in PyTorch without any warnings.

Perhaps we can put a check on num_heads or head_dim in the code and allow both sm86 and sm89 to benefit from this speedup.
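
For illustration, a rough sketch of what such a gate could look like; the function name and the head_dim limit below are placeholders, not actual cuDNN constraints:

def check_compute_capability(compute_cap: int, head_dim: int) -> None:
  # Hypothetical sketch only: sm80/sm90 pass unconditionally, while sm86/sm89
  # pass under a placeholder size limit that would need to be confirmed
  # against cuDNN and PyTorch's sdp_utils checks.
  if compute_cap in (80, 90):
    return
  if compute_cap in (86, 89):
    if head_dim > 128:  # placeholder limit, not an official constraint
      raise RuntimeError(
          f"head_dim={head_dim} is not supported for cuDNN SDPA on sm{compute_cap}.")
    return
  raise RuntimeError(f"Unsupported compute capability: {compute_cap}.")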

I would be happy to contribute a PR.

monatis commented 1 month ago

I reproduced the speedup you posted, but the cuDNN implementation is independent of DaoAILab's implementation. I also confirmed this with an NVIDIA engineer: they said it is very unlikely that sm86 and sm89 will be supported in the future, because compute capabilities with a non-zero minor version have smaller shared memory, and cuDNN's implementation is optimized for the larger shared memory of sm80 and sm90. Something else must be happening here.

monatis commented 1 month ago

I figured out that Torch also indicates that the cuDNN FMHA implementation is supported only on sm80 and sm90 (line 465). However, they also package DaoAILab's implementation as part of PyTorch and have a fallback chain: cuDNN FMHA -> DaoAILab's flash attention v2 -> memory-efficient attention implementation.

MasterSkepticista commented 1 month ago

I can check which kernel runs in PyTorch on sm86.
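
For instance, something along these lines (a sketch assuming the torch.backends.cuda.sdp_kernel context manager available in recent PyTorch) forces the flash backend alone, so the call errors out instead of silently falling back if flash is unsupported on this GPU:

import torch
import torch.nn.functional as F

# Shapes follow PyTorch's (batch, num_heads, seq_len, head_dim) convention.
q = k = v = torch.randn(1, 12, 1024, 64, device="cuda", dtype=torch.bfloat16)

# Allow only the flash backend; if it cannot run on this GPU,
# scaled_dot_product_attention raises instead of falling back.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False):
  out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 12, 1024, 64])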

But I think there is merit in allowing JAX cuDNN SDPA to run on sm86 and sm89 with a warning. In the worst case the smaller shared memory leads to cache misses, which actually does not seem to be the case given the microbenchmarks and the fact that training converges to the same loss values as the XLA implementation.

We could also limit the head_dim and num_heads, as PyTorch does here for sm86 and sm89.

What do you think?

monatis commented 1 month ago

I checked the lowering for `flash_attn` with `flash_attn.lower(q, k, v).compile().as_text()`, and the custom call target is really `__cudnn$fmhaSoftmax`. Pretty interesting.
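
For anyone who wants to reproduce this, a snippet along these lines (reusing flash_attn, q, k, v from the benchmark above) checks for the custom call target in the compiled HLO:

hlo = flash_attn.lower(q, k, v).compile().as_text()
print("__cudnn$fmhaSoftmax" in hlo)  # True when the cuDNN fused attention kernel is used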

kaixih commented 1 month ago

@Cjkkkk Can you comment on the version restrictions of cudnn flash attn?

Cjkkkk commented 1 month ago

Thanks for bringing this up. I think the sm86/sm89 constraint was for the non-flash version of attention, which has since been removed from the JAX cuDNN SDPA API. I am confirming with the NVIDIA cuDNN team whether we can relax the constraint for sm86/sm89 for flash attention.

Cjkkkk commented 1 month ago

Created a PR to include sm86/sm89 as well: https://github.com/google/jax/pull/22605

MasterSkepticista commented 1 month ago

Will close, now that #22605 is merged.