jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[GPU] FlashAttention performance lags behind PyTorch #24934

Open neel04 opened 5 days ago

neel04 commented 5 days ago

Description

I'm benchmarking a naive attention implementation in JAX against the Pallas FlashAttention kernel and the new `jax.nn.dot_product_attention` interface with the cuDNN backend.

[Screenshots of benchmark results not preserved]

Why the discrepancy? I expected performance to reach roughly 550-600 TFLOP/s. I'm setting a few XLA flags (see the script below). Is there anything I'm missing, or is this about the maximum XLA can deliver on an H100?

Steps to reproduce

  1. Recreate the environment using uv. I'm assuming the NVIDIA drivers are already installed. If not, you can use the pytorch/pytorch:2.4.0-cuda12.4.1-cudnn8-runtime image on the GPU and run `apt-get update` and `apt-get upgrade` first to set everything up.

```bash
pip3 install uv
uv venv 'main_env' --python 3.11
source main_env/bin/activate

uv pip install -U "jax[cuda12]"
uv pip install -q einops tqdm jaxtyping optax optuna equinox rich
uv pip install -q nvitop pdbpp tabulate
```

  2. Run either script:
**JAX script**

```py
import os, sys

os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.9'
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_enable_triton_softmax_fusion=true '
    '--xla_gpu_enable_cudnn_fmha=true'
)

import math
import time
from tabulate import tabulate

import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import tree_map
from jax.experimental.pallas.ops.gpu.attention import mha as pallas_mha
from functools import partial


class Timer(object):
    def __init__(self, into=None):
        self.into = into

    def __enter__(self):
        self.start = time.time()

    def __exit__(self, type, value, traceback):
        if self.into is not None:
            self.into.append(time.time() - self.start)

    def elapsed(self):
        return time.time() - self.start


def cartesian(*lists):
    if lists:
        xs = lists[0]
        for x in xs:
            for rest in cartesian(*lists[1:]):
                yield (x,) + rest
    else:
        yield ()


def cross_attn_flops(B, T, TK, H, C):
    HC = H*C
    # T = TK for self-attention
    flops_fwd = (
        2 * B*H*T*TK*C +   # S = Q@K
        3 * B*H*T*TK +     # P = softmax(S)
        2 * B*T*TK*H*C     # O = P@V
    )
    return flops_fwd


def attn_flops(B, T, H, C):
    return cross_attn_flops(B, T, T, H, C)


dtype = jnp.float16
print(f'Using dtype: {dtype}')


def convert(xs):
    return tree_map(lambda x: x.astype(dtype), xs)


def ref_fwd(q, k, v):
    # reference implementation
    [n, l, h, d] = q.shape
    [n, lk, hk, d] = k.shape
    softmax_scale = 1 / math.sqrt(d)
    S = jnp.einsum('nlhd,nLhd->nhlL', q, k)
    P = jax.nn.softmax(S*softmax_scale, axis=-1)
    o = jnp.einsum('nhlL,nLhd->nlhd', P, v)
    return o.astype(q.dtype)


def jax_dpa_fwd(q, k, v):
    output = jax.nn.dot_product_attention(
        query=q, key=k, value=v, implementation='cudnn'
    )
    return output


# ----
Bx = [8, 16]       # batch size
Tx = [1024, 2048]  # seqlen
Hx = [16, 32]      # number of heads
Cx = [64, 128]     # head dim
sx = [2]           # steps


def bench_attn(mha):
    @jax.jit
    def bench(q, k, v, steps: int):
        for i in steps:
            out = mha(q, k, v)
        return out

    times = []
    table = {}
    for B, T, H, C, s in cartesian(Bx, Tx, Hx, Cx, sx):
        q = jax.random.normal(jax.random.PRNGKey(0), [B, T, H, C], dtype=dtype)
        k = jax.random.normal(jax.random.PRNGKey(1), [B, T, H, C], dtype=dtype)
        v = jax.random.normal(jax.random.PRNGKey(2), [B, T, H, C], dtype=dtype)
        steps = jnp.arange(s)
        # Warmup
        for _ in range(2):
            bench(q, k, v, steps).block_until_ready()
        out = []
        for _ in range(8):
            with Timer(out):
                bench(q, k, v, steps).block_until_ready()
        t = sum(out[2:])/len(out[2:])
        flop = attn_flops(B, T, H, C) * s
        tf = flop / t / 1e12 / s
        print(f'flops ({B} {T} {H} {C} / {s}): {tf}T', out)
        table[(B, T, H, C, s)] = tf
        times.append(t)
    return table, times


naive_flops, custom_time = bench_attn(ref_fwd)
pallas_flops, pallas_time = bench_attn(partial(pallas_mha, segment_ids=None))
jax_dpa_flops, dpa_time = bench_attn(jax_dpa_fwd)

table = []
for idx, (B, T, H, C, s) in enumerate(cartesian(Bx, Tx, Hx, Cx, sx)):
    n_flops, n_time = naive_flops[(B, T, H, C, s)], custom_time[idx]
    p_flops, p_time = pallas_flops[(B, T, H, C, s)], pallas_time[idx]
    j_flops, j_time = jax_dpa_flops[(B, T, H, C, s)], dpa_time[idx]
    table.append((B, T, H, C, n_flops, p_flops, j_flops, attn_flops(B, T, H, C), n_time, p_time, j_time))

print(tabulate(
    table,
    headers=['B', 'T', 'H', 'C', 'TFlop/s (naive)', 'TFlop/s (pallas)', 'TFlop/s (jax_dpa)',
             'FLOPs', 'Naive time', 'Pallas time', 'DPA time'],
    floatfmt='.5f',
))
```
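In the JAX script, `flop` is multiplied by the step count `s` and `tf` is then divided by `s` again, so the reported number is the FLOPs of one attention call divided by the time of the whole jitted `bench` call. If all `s` calls actually execute, this understates throughput by a factor of `s`. If XLA deduplicates the identical calls (they have identical inputs), it is the correct per-call figure. The pure-Python sketch below reproduces that arithmetic and also shows the per-call time budget implied by a 550 TFLOP/s target. The helper names and the 5 ms timing are made up for illustration.

```python
# Pure-Python sketch of the FLOP/throughput arithmetic used in the JAX script.

def attn_flops(B, T, H, C):
    # Same convention as the script: S = Q@K^T and O = P@V cost 2*B*H*T*T*C each,
    # plus ~3 FLOPs per element of the softmax.
    return 2 * B * H * T * T * C + 3 * B * H * T * T + 2 * B * T * T * H * C


def reported_tflops(B, T, H, C, s, t):
    # What the script reports: flop = attn_flops * s, then tf = flop / t / 1e12 / s.
    flop = attn_flops(B, T, H, C) * s
    return flop / t / 1e12 / s


def tflops_if_all_steps_ran(B, T, H, C, s, t):
    # Throughput if all s attention calls really executed within the timed interval t.
    return attn_flops(B, T, H, C) * s / t / 1e12


def time_budget_ms(B, T, H, C, target_tflops):
    # Wall time one forward call may take while sustaining target_tflops.
    return attn_flops(B, T, H, C) / (target_tflops * 1e12) * 1e3


if __name__ == "__main__":
    B, T, H, C, s, t = 16, 2048, 32, 128, 2, 0.005  # t is a hypothetical timing in seconds
    print(reported_tflops(B, T, H, C, s, t))
    print(tflops_if_all_steps_ran(B, T, H, C, s, t))  # exactly s times larger
    print(time_budget_ms(B, T, H, C, 550))            # about 2.01 ms per call
```

For the largest configuration, hitting 550 TFLOP/s means one forward call has to finish in roughly 2 ms, which gives a concrete millisecond target to compare against.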

**PyTorch Benchmark script**

```py
import torch
torch._dynamo.config.cache_size_limit = 10000  # Increase cache size to 10,000

import time
from tabulate import tabulate
from typing import Optional
from itertools import product
from torch.utils.flop_counter import FlopCounterMode


class Timer:
    def __init__(self, into=None):
        self.into = into

    def __enter__(self):
        self.start = time.time()

    def __exit__(self, type, value, traceback):
        if self.into is not None:
            self.into.append(time.time() - self.start)


class OptimizedMHA(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # Inputs arrive already transposed to [B, H, T, C] by the caller
        B, H, T, C = q.shape
        scale = 1.0 / (C ** 0.5)
        # Compute attention scores
        attn = torch.matmul(q, k.transpose(-2, -1)) * scale  # [B, H, T, T]
        attn = torch.softmax(attn, dim=-1)
        # Apply attention to values
        out = torch.matmul(attn, v)  # [B, H, T, C]
        out = out.transpose(1, 2)    # [B, T, H, C]
        return out


class TorchMHA(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # scaled_dot_product_attention expects [B, H, T, C] inputs
        return torch.nn.functional.scaled_dot_product_attention(
            q, k, v, dropout_p=0.0, is_causal=False
        )


def get_flops(model, sample_inputs):
    """Get FLOPs using FlopCounterMode"""
    with FlopCounterMode(model) as flop_counter:
        _ = model(*sample_inputs)
    return flop_counter.get_total_flops()


def bench_implementation(model, q, k, v, warmup=10, steps=100):
    """Benchmark a specific implementation"""
    # Warmup
    for _ in range(warmup):
        with torch.no_grad():
            _ = model(q, k, v)
    torch.cuda.synchronize()

    # Benchmark timing
    times = []
    for _ in range(steps):
        with torch.no_grad(), Timer(times):
            _ = model(q, k, v)
            torch.cuda.synchronize()  # wait for the GPU so Timer measures kernel time

    # Calculate statistics
    times = times[10:]  # Discard first 10 runs
    avg_time = sum(times) / len(times)
    return avg_time


def bench_attention():
    # Configuration
    device = torch.device("cuda")
    dtype = torch.float16

    # Test parameters
    Bx = [8, 16]       # batch size
    Tx = [1024, 2048]  # sequence length
    Hx = [16, 32]      # number of heads
    Cx = [64, 128]     # head dimension
    sx = [4]           # steps to run

    # Initialize models
    custom_model = OptimizedMHA().to(device)
    torch_model = TorchMHA().to(device)

    # Compile both models
    compiled_custom = torch.compile(
        custom_model, mode="max-autotune-no-cudagraphs", fullgraph=True,
    )
    compiled_torch = torch.compile(
        torch_model, mode="max-autotune-no-cudagraphs", fullgraph=True,
    )

    results = []
    # Run benchmarks for each configuration
    for B, T, H, C, s in product(Bx, Tx, Hx, Cx, sx):
        # Create input tensors
        q = torch.randn(B, T, H, C, device=device, dtype=dtype)
        k = torch.randn(B, T, H, C, device=device, dtype=dtype)
        v = torch.randn(B, T, H, C, device=device, dtype=dtype)
        q = q.transpose(1, 2)  # [B, H, T, C]
        k = k.transpose(1, 2)  # [B, H, T, C]
        v = v.transpose(1, 2)  # [B, H, T, C]

        # Get FLOPs using FlopCounterMode (on CPU with float32)
        model_cpu = OptimizedMHA()
        q_cpu = q.cpu().float()
        k_cpu = k.cpu().float()
        v_cpu = v.cpu().float()
        flops = get_flops(model_cpu, (q_cpu, k_cpu, v_cpu))

        # Benchmark both implementations
        custom_time = bench_implementation(compiled_custom, q, k, v)
        torch_time = bench_implementation(compiled_torch, q, k, v)

        # Calculate TFLOP/s for both
        custom_tflops = flops / custom_time / 1e12
        torch_tflops = flops / torch_time / 1e12

        # Calculate speedup
        speedup = custom_time / torch_time  # >1 means torch is faster

        print(f"\nConfig (B={B}, T={T}, H={H}, C={C}):")
        print(f"Custom impl: {custom_tflops:.2f} TFlop/s")
        print(f"Torch impl: {torch_tflops:.2f} TFlop/s")
        print(f"Speedup (Torch vs Custom): {speedup:.2f}x")
        print(f"Measured FLOPs: {flops:,}")

        results.append((
            B, T, H, C,
            round(custom_tflops, 2), round(torch_tflops, 2), round(speedup, 2),
            flops, custom_time, torch_time,
        ))

    # Print results table
    headers = [
        'Batch', 'SeqLen', 'Heads', 'HeadDim',
        'Naive MHA TFlop/s', 'SDPA TFlop/s', 'Advantage',
        'FLOPs', 'Custom Time', 'SDPA Time',
    ]
    print("\nResults:")
    print(tabulate(results, headers=headers, floatfmt='.5f'))


if __name__ == "__main__":
    bench_attention()
```
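The two scripts count FLOPs slightly differently. The JAX script adds about 3 FLOPs per softmax element, while `FlopCounterMode` (as far as I can tell) only counts the two batched matmuls in `OptimizedMHA` and not the softmax or the scaling multiply. Assuming that, the sketch below shows the two counts differ only by a factor of 1 + 3/(4C), so the TFLOP/s columns of the two scripts should be comparable to within about 1%. The function names are hypothetical.

```python
def script_flops(B, T, H, C):
    # JAX script convention: two matmuls plus ~3 FLOPs per softmax element.
    return 4 * B * H * T * T * C + 3 * B * H * T * T


def matmul_only_flops(B, T, H, C):
    # Assumed FlopCounterMode result for OptimizedMHA: only Q @ K^T and P @ V,
    # each costing 2*B*H*T*T*C, are counted.
    return 4 * B * H * T * T * C


for C in (64, 128):
    ratio = script_flops(8, 1024, 16, C) / matmul_only_flops(8, 1024, 16, C)
    # The ratio simplifies to 1 + 3 / (4 * C): ~1.2% for C=64, ~0.6% for C=128.
    print(C, ratio, 1 + 3 / (4 * C))
```

Since the difference is this small, the gap between the JAX and PyTorch numbers is unlikely to be a FLOP-accounting artifact.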


System info (python version, jaxlib version, accelerator, etc.)

```
>>> import jax; jax.print_environment_info()
jax:    0.4.35
jaxlib: 0.4.34
numpy:  2.1.3
python: 3.11.10 (main, Oct 16 2024, 04:38:48) [Clang 18.1.8 ]
device info: NVIDIA H100 PCIe-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='nifty-orthodox-whale', release='6.8.0-40-generic', version='#40~22.04.3-Ubuntu SMP PREEMPT_DYNAMIC Tue Jul 30 17:30:19 UTC 2', machine='x86_64')
```

```
$ nvidia-smi
Sun Nov 17 02:04:44 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.06             Driver Version: 535.183.06   CUDA Version: 12.4     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA H100 PCIe               On  | 00000000:00:07.0 Off |                    0 |
| N/A   31C    P0              54W / 350W |    467MiB / 81559MiB |      2%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
```
neel04 commented 5 days ago

Update: I changed the PyTorch script to count FLOPs with FlopCounterMode. The results are now more realistic, but JAX still lags behind even when it is explicitly forced to use cuDNN.

cc @kaixih @sbodenstein @dfm

Rick0827 commented 4 days ago

I think your PyTorch script might not work as expected, since `torch.nn.functional.scaled_dot_product_attention` expects its inputs in [B, H, T, C] layout: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
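To illustrate why the layout matters: if a [B, T, H, C] tensor is passed where [B, H, T, C] is expected, the roles of T and H swap, so attention is computed over heads instead of sequence positions and the score matrix is H×H instead of T×T. The result is wrong and far cheaper to compute, which would inflate any TFLOP/s figure derived from a T×T FLOP count. A minimal NumPy sketch; `attention_bhtc` is a hypothetical stand-in for an SDPA-style call, not the PyTorch implementation.

```python
import numpy as np


def attention_bhtc(q, k, v):
    # Hypothetical stand-in for an SDPA-style call that expects [B, H, T, C]:
    # attention runs over axis -2 (sequence), with head_dim on axis -1.
    scale = 1.0 / np.sqrt(q.shape[-1])
    scores = (q @ np.swapaxes(k, -1, -2)) * scale   # [..., T, T] if the layout is right
    p = np.exp(scores - scores.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v, scores.shape


def to_bhtc(x):
    # [B, T, H, C] (JAX layout) -> [B, H, T, C] (SDPA layout)
    return x.transpose(0, 2, 1, 3)


rng = np.random.default_rng(0)
B, T, H, C = 2, 128, 4, 16
q, k, v = (rng.standard_normal((B, T, H, C)) for _ in range(3))

out_ok, scores_ok = attention_bhtc(to_bhtc(q), to_bhtc(k), to_bhtc(v))
out_bad, scores_bad = attention_bhtc(q, k, v)   # layout not converted

print(scores_ok)   # (2, 4, 128, 128): T x T scores per head
print(scores_bad)  # (2, 128, 4, 4): H x H scores per position
```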

sbodenstein commented 4 days ago

This is a complicated benchmarking setup, with many things that could go wrong. Can you simplify it to just measuring milliseconds, and also add a correctness test (checking that PyTorch and JAX give the same output for the same input)?
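A correctness check like this only needs the outputs to agree within a tolerance; it does not require bitwise-identical kernels. A minimal NumPy sketch, assuming both outputs are first converted to NumPy arrays (e.g. `np.asarray(jax_out)` and `torch_out.cpu().numpy()`). Here a naive reference is compared against a hypothetical blockwise, FlashAttention-style implementation that uses an online softmax; the function names, shapes, and tolerances are illustrative assumptions.

```python
import numpy as np


def naive_attention(q, k, v):
    # q, k, v: [B, T, H, C], the layout used by the JAX script
    scale = 1.0 / np.sqrt(q.shape[-1])
    s = np.einsum("bthc,bshc->bhts", q, k) * scale   # [B, H, T, T]
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)               # softmax over keys
    return np.einsum("bhts,bshc->bthc", p, v)        # [B, T, H, C]


def blockwise_attention(q, k, v, block_k=64):
    # FlashAttention-style: stream over key/value blocks, keeping a running
    # row max (m), running softmax denominator (l) and unnormalized output (acc).
    B, T, H, C = q.shape
    scale = 1.0 / np.sqrt(C)
    m = np.full((B, H, T, 1), -np.inf)
    l = np.zeros((B, H, T, 1))
    acc = np.zeros((B, H, T, C))
    for start in range(0, k.shape[1], block_k):
        kb, vb = k[:, start:start + block_k], v[:, start:start + block_k]
        s = np.einsum("bthc,bshc->bhts", q, kb) * scale
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        p = np.exp(s - m_new)
        correction = np.exp(m - m_new)               # rescale earlier partial sums
        l = l * correction + p.sum(axis=-1, keepdims=True)
        acc = acc * correction + np.einsum("bhts,bshc->bhtc", p, vb)
        m = m_new
    return (acc / l).transpose(0, 2, 1, 3)           # back to [B, T, H, C]


def check_close(out, ref, rtol, atol):
    # Tolerance-based comparison that also reports the worst absolute error.
    max_err = float(np.max(np.abs(out - ref)))
    return bool(np.allclose(out, ref, rtol=rtol, atol=atol)), max_err


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((2, 256, 4, 32)) for _ in range(3))
ok, max_err = check_close(blockwise_attention(q, k, v), naive_attention(q, k, v),
                          rtol=1e-6, atol=1e-8)
print(ok, max_err)
```

This example runs in float64, so the tolerances are tight. For float16 outputs from the GPU kernels, a noticeably looser tolerance would be needed.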

neel04 commented 3 days ago

@Rick0827 Thank you for pointing that out.

@sbodenstein I have updated both scripts to report times as well. However, I opted to skip correctness tests, because reproducibility requires sacrificing performance, which I'd rather not touch.

The run-to-run variance is very low, however, and we can average over multiple steps (`sx`), so this should be a non-issue.
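To back the low-variance claim with numbers, the timings can be reported in milliseconds together with their spread. A minimal standard-library sketch; `time_ms` and `summarize` are hypothetical helper names, and `fn` is assumed to block until the GPU work finishes (e.g. by ending with `.block_until_ready()` in JAX or `torch.cuda.synchronize()` in PyTorch). It uses `time.perf_counter`, which is intended for measuring short intervals.

```python
import statistics
import time


def time_ms(fn, warmup=3, iters=20):
    # fn must block until device work is done, otherwise only dispatch is timed.
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1e3)
    return samples


def summarize(samples_ms):
    # The median is robust to occasional outliers; the coefficient of variation
    # (stdev / mean) quantifies run-to-run variance.
    mean = statistics.mean(samples_ms)
    stdev = statistics.stdev(samples_ms) if len(samples_ms) > 1 else 0.0
    return {
        "median_ms": statistics.median(samples_ms),
        "mean_ms": mean,
        "stdev_ms": stdev,
        "cv": stdev / mean if mean else 0.0,
    }


if __name__ == "__main__":
    # Stand-in CPU workload (an assumption, not the attention kernel itself).
    print(summarize(time_ms(lambda: sum(range(100_000)))))
```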

On A100:

[Screenshots of A100 benchmark results not preserved]