neel04 opened this issue 5 days ago
Update: I changed the torch script to use `FlopCounterMode`. Now the results are more realistic/accurate, but JAX still lags behind despite explicitly being forced to use cuDNN.
cc @kaixih @sbodenstein @dfm
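
For context, the measurement pattern the updated torch script relies on boils down to something like the sketch below (shapes are made up for illustration; it counts a plain matmul attention on CPU, which is also what the full script does before timing the GPU kernels):

```py
import torch
from torch.utils.flop_counter import FlopCounterMode

def naive_attn(q, k, v):
    # plain matmul/softmax attention, so the counter sees ordinary matmuls
    attn = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
    return attn @ v

q = k = v = torch.randn(2, 4, 256, 64)  # (B, H, T, C), fp32 on CPU
with FlopCounterMode(display=False) as fc:
    naive_attn(q, k, v)
print(fc.get_total_flops())  # analytic matmul FLOPs (softmax is not counted)
```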
I think your torch script might not work as expected, since the input format of `torch.nn.functional.scaled_dot_product_attention` is `[B, H, T, C]` (https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
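
(For concreteness, a minimal sketch of the layout in question; the shapes and names here are illustrative, not taken from the script:)

```py
import torch
import torch.nn.functional as F

B, T, H, C = 2, 256, 4, 64
q = torch.randn(B, T, H, C)                    # built as (batch, seq, heads, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

# F.scaled_dot_product_attention expects (B, H, T, C), so transpose first
q, k, v = (x.transpose(1, 2) for x in (q, k, v))
out = F.scaled_dot_product_attention(q, k, v)  # (B, H, T, C)
out = out.transpose(1, 2)                      # back to (B, T, H, C)
```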
This is a complicated benchmarking setup, with many things potentially going wrong. Can you simplify this to just measuring milliseconds, and also have a correctness test (that PyTorch and JAX give the same output for the same input).
@Rick0827 Thank you for pointing that out.
@sbodenstein I have updated both scripts to report times as well. However, I opted to skip correctness tests, because making the runs reproducible would require sacrificing performance, which I'm wary of touching. The variance between runs is very low, though, and we can average over multiple steps (`sx`), so this should be a non-issue.
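
For reference, a minimal parity check along the lines suggested above could look like the sketch below (small shapes, default XLA implementation so it also runs on CPU; on GPU one could pass `implementation='cudnn'` as in the script further down). This is illustrative only, not part of the original scripts:

```py
import numpy as np
import torch
import jax
import jax.numpy as jnp

B, T, H, C = 2, 128, 4, 64
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((B, T, H, C), dtype=np.float32) for _ in range(3))

# PyTorch SDPA wants (B, H, T, C)
to_t = lambda x: torch.from_numpy(x).transpose(1, 2)
out_t = torch.nn.functional.scaled_dot_product_attention(to_t(q), to_t(k), to_t(v))
out_t = out_t.transpose(1, 2).numpy()          # back to (B, T, H, C)

# jax.nn.dot_product_attention takes (B, T, H, C) directly
out_j = jax.nn.dot_product_attention(jnp.asarray(q), jnp.asarray(k), jnp.asarray(v))

print(np.max(np.abs(out_t - np.asarray(out_j))))  # expect a tiny difference (~1e-6 in fp32)
```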
On A100:
**Description**
I'm benchmarking naive FlashAttention in JAX vs. the Pallas version of FA3 vs. the new `dot_product_attention` interface with the `cudnn` backend.

Why the discrepancy? I'd have expected performance to touch 550-600 TFLOPS/s. I'm using a few XLA flags, as specified in the script below, but is there anything I'm missing? Or is this about the maximum XLA can deliver on H100s?
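
For context on that expectation, a quick back-of-the-envelope check using the same FLOP formula as the JAX script below (the 575 TFLOP/s figure is just the midpoint of the 550-600 range above, not a measured number):

```py
# Forward-pass FLOPs for one attention call, matching cross_attn_flops() in the
# JAX script: 2*B*H*T*T*C for Q@K^T, ~3*B*H*T*T for the softmax, 2*B*H*T*T*C for P@V.
B, T, H, C = 16, 2048, 32, 128
flops = 2 * B*H*T*T*C + 3 * B*H*T*T + 2 * B*T*T*H*C
print(f"{flops / 1e12:.2f} TFLOP per forward pass")
print(f"~{flops / 575e12 * 1e3:.2f} ms expected per call at 575 TFLOP/s")
```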
**Steps to reproduce**

The scripts below can be run with `uv`. I'm assuming the drivers are installed. If not, you can use the `pytorch/pytorch:2.4.0-cuda12.4.1-cudnn8-runtime` image on the GPU and run the preliminary `apt-get update` and `apt-get upgrade` to set everything up.

**JAX script**
```py
import os, sys

os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.9'
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_enable_triton_softmax_fusion=true '
    '--xla_gpu_enable_cudnn_fmha=true'
)

import math
import time
from tabulate import tabulate

import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import tree_map
from jax.experimental.pallas.ops.gpu.attention import mha as pallas_mha
from functools import partial


class Timer(object):
    def __init__(self, into=None):
        self.into = into

    def __enter__(self):
        self.start = time.time()

    def __exit__(self, type, value, traceback):
        if self.into is not None:
            self.into.append(time.time() - self.start)

    def elapsed(self):
        return time.time() - self.start


def cartesian(*lists):
    if lists:
        xs = lists[0]
        for x in xs:
            for rest in cartesian(*lists[1:]):
                yield (x,) + rest
    else:
        yield ()


def cross_attn_flops(B, T, TK, H, C):
    HC = H * C  # T = TK for self-attention
    flops_fwd = (
        2 * B*H*T*TK*C +  # S = Q@K
        3 * B*H*T*TK +    # P = softmax(S)
        2 * B*T*TK*H*C    # O = P@V
    )
    return flops_fwd


def attn_flops(B, T, H, C):
    return cross_attn_flops(B, T, T, H, C)


dtype = jnp.float16
print(f'Using dtype: {dtype}')


def convert(xs):
    return tree_map(lambda x: x.astype(dtype), xs)


def ref_fwd(q, k, v):
    # reference implementation
    [n, l, h, d] = q.shape
    [n, lk, hk, d] = k.shape
    softmax_scale = 1 / math.sqrt(d)
    S = jnp.einsum('nlhd,nLhd->nhlL', q, k)
    P = jax.nn.softmax(S * softmax_scale, axis=-1)
    o = jnp.einsum('nhlL,nLhd->nlhd', P, v)
    return o.astype(q.dtype)


def jax_dpa_fwd(q, k, v):
    output = jax.nn.dot_product_attention(
        query=q,
        key=k,
        value=v,
        implementation='cudnn'
    )
    return output


# ----
Bx = [8, 16]       # batch size
Tx = [1024, 2048]  # seqlen
Hx = [16, 32]      # number of heads
Cx = [64, 128]     # head dim
sx = [2]           # steps


def bench_attn(mha):
    @jax.jit
    def bench(q, k, v, steps: int):
        for i in steps:
            out = mha(q, k, v)
        return out

    times = []
    table = {}

    for B, T, H, C, s in cartesian(Bx, Tx, Hx, Cx, sx):
        q = jax.random.normal(jax.random.PRNGKey(0), [B, T, H, C], dtype=dtype)
        k = jax.random.normal(jax.random.PRNGKey(1), [B, T, H, C], dtype=dtype)
        v = jax.random.normal(jax.random.PRNGKey(2), [B, T, H, C], dtype=dtype)
        steps = jnp.arange(s)

        # Warmup
        for _ in range(2):
            bench(q, k, v, steps).block_until_ready()

        out = []
        for _ in range(8):
            with Timer(out):
                bench(q, k, v, steps).block_until_ready()

        t = sum(out[2:]) / len(out[2:])
        flop = attn_flops(B, T, H, C) * s
        tf = flop / t / 1e12 / s
        print(f'flops ({B} {T} {H} {C} / {s}): {tf}T', out)
        table[(B, T, H, C, s)] = tf
        times.append(t)

    return table, times


naive_flops, custom_time = bench_attn(ref_fwd)
pallas_flops, pallas_time = bench_attn(partial(pallas_mha, segment_ids=None))
jax_dpa_flops, dpa_time = bench_attn(jax_dpa_fwd)

table = []
for idx, (B, T, H, C, s) in enumerate(cartesian(Bx, Tx, Hx, Cx, sx)):
    n_flops, n_time = naive_flops[(B, T, H, C, s)], custom_time[idx]
    p_flops, p_time = pallas_flops[(B, T, H, C, s)], pallas_time[idx]
    j_flops, j_time = jax_dpa_flops[(B, T, H, C, s)], dpa_time[idx]
    table.append((B, T, H, C, n_flops, p_flops, j_flops, attn_flops(B, T, H, C),
                  n_time, p_time, j_time))

print(tabulate(
    table,
    headers=['B', 'T', 'H', 'C', 'TFlop/s (naive)', 'TFlop/s (pallas)',
             'TFlop/s (jax_dpa)', 'FLOPs', 'Naive time', 'Pallas time', 'DPA time'],
    floatfmt='.5f'
))
```
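
One way to confirm the cuDNN path is actually taken is to inspect the compiled HLO. A sketch is below; it needs a GPU build of jaxlib, and the exact custom-call name varies across jaxlib/XLA versions, so treat the search string as an assumption rather than a guaranteed match:

```py
import jax
import jax.numpy as jnp

def dpa(q, k, v):
    return jax.nn.dot_product_attention(q, k, v, implementation='cudnn')

B, T, H, C = 8, 1024, 16, 64
q = k = v = jnp.zeros((B, T, H, C), dtype=jnp.float16)

hlo = jax.jit(dpa).lower(q, k, v).compile().as_text()
# On a cuDNN-enabled GPU build, the fused attention appears as a custom call
# whose target mentions cudnn/fmha (exact name depends on the XLA version).
print([line for line in hlo.splitlines() if 'cudnn' in line.lower()])
```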
**PyTorch Benchmark script**
```py
import torch

torch._dynamo.config.cache_size_limit = 10000  # Increase cache size to 10,000

import time
from tabulate import tabulate
from typing import Optional
from itertools import product
from torch.utils.flop_counter import FlopCounterMode


class Timer:
    def __init__(self, into=None):
        self.into = into

    def __enter__(self):
        self.start = time.time()

    def __exit__(self, type, value, traceback):
        if self.into is not None:
            self.into.append(time.time() - self.start)


class OptimizedMHA(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        B, T, H, C = q.shape
        scale = 1.0 / (C ** 0.5)

        # Compute attention scores
        attn = torch.matmul(q, k.transpose(-2, -1)) * scale  # [B, H, T, T]
        attn = torch.softmax(attn, dim=-1)

        # Apply attention to values
        out = torch.matmul(attn, v)  # [B, H, T, C]
        out = out.transpose(1, 2)    # [B, T, H, C]
        return out


class TorchMHA(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.scaled_dot_product_attention(
            q, k, v, dropout_p=0.0, is_causal=False
        )


def get_flops(model, sample_inputs):
    """Get FLOPs using FlopCounterMode"""
    with FlopCounterMode(model) as flop_counter:
        _ = model(*sample_inputs)
    return flop_counter.get_total_flops()


def bench_implementation(model, q, k, v, warmup=10, steps=100):
    """Benchmark a specific implementation"""
    # Warmup
    for _ in range(warmup):
        with torch.no_grad():
            _ = model(q, k, v)
    torch.cuda.synchronize()

    # Benchmark timing
    times = []
    for _ in range(steps):
        with torch.no_grad(), Timer(times):
            _ = model(q, k, v)
            torch.cuda.synchronize()

    # Calculate statistics
    times = times[10:]  # Discard first 10 runs
    avg_time = sum(times) / len(times)
    return avg_time


def bench_attention():
    # Configuration
    device = torch.device("cuda")
    dtype = torch.float16

    # Test parameters
    Bx = [8, 16]       # batch size
    Tx = [1024, 2048]  # sequence length
    Hx = [16, 32]      # number of heads
    Cx = [64, 128]     # head dimension
    sx = [4]           # steps to run

    # Initialize models
    custom_model = OptimizedMHA().to(device)
    torch_model = TorchMHA().to(device)

    # Compile both models
    compiled_custom = torch.compile(
        custom_model,
        mode="max-autotune-no-cudagraphs",
        fullgraph=True,
    )
    compiled_torch = torch.compile(
        torch_model,
        mode="max-autotune-no-cudagraphs",
        fullgraph=True,
    )

    results = []

    # Run benchmarks for each configuration
    for B, T, H, C, s in product(Bx, Tx, Hx, Cx, sx):
        # Create input tensors
        q = torch.randn(B, T, H, C, device=device, dtype=dtype)
        k = torch.randn(B, T, H, C, device=device, dtype=dtype)
        v = torch.randn(B, T, H, C, device=device, dtype=dtype)

        q = q.transpose(1, 2)  # [B, H, T, C]
        k = k.transpose(1, 2)  # [B, H, T, C]
        v = v.transpose(1, 2)  # [B, H, T, C]

        # Get FLOPs using FlopCounterMode (on CPU with float32)
        model_cpu = OptimizedMHA()
        q_cpu = q.cpu().float()
        k_cpu = k.cpu().float()
        v_cpu = v.cpu().float()
        flops = get_flops(model_cpu, (q_cpu, k_cpu, v_cpu))

        # Benchmark both implementations
        custom_time = bench_implementation(compiled_custom, q, k, v)
        torch_time = bench_implementation(compiled_torch, q, k, v)

        # Calculate TFLOPs/s for both
        custom_tflops = flops / custom_time / 1e12
        torch_tflops = flops / torch_time / 1e12

        # Calculate speedup
        speedup = custom_time / torch_time  # >1 means torch is faster

        print(f"\nConfig (B={B}, T={T}, H={H}, C={C}):")
        print(f"Custom impl: {custom_tflops:.2f} TFlop/s")
        print(f"Torch impl: {torch_tflops:.2f} TFlop/s")
        print(f"Speedup (Torch vs Custom): {speedup:.2f}x")
        print(f"Measured FLOPs: {flops:,}")

        results.append((
            B, T, H, C,
            round(custom_tflops, 2),
            round(torch_tflops, 2),
            round(speedup, 2),
            flops,
            custom_time,
            torch_time
        ))

    # Print results table
    headers = [
        'Batch', 'SeqLen', 'Heads', 'HeadDim',
        'Naive MHA TFlop/s', 'SDPA TFlop/s', 'Advantage',
        'FLOPs', 'Custom Time', 'SDPA Time'
    ]
    print("\nResults:")
    print(tabulate(results, headers=headers, floatfmt='.5f'))


if __name__ == "__main__":
    bench_attention()
```
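
On the PyTorch side it may also be worth pinning the SDPA backend, so the comparison is known to be against the FlashAttention kernel rather than whatever the dispatcher picks. A sketch using `torch.nn.attention.sdpa_kernel` (available in recent PyTorch, including the 2.4 image mentioned above; requires CUDA):

```py
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(8, 16, 1024, 64, device="cuda", dtype=torch.float16)  # (B, H, T, C)
k, v = torch.randn_like(q), torch.randn_like(q)

# Restrict SDPA to the FlashAttention kernel; with only this backend enabled it
# raises for unsupported shapes/dtypes instead of silently falling back.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
```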
**System info (python version, jaxlib version, accelerator, etc.)**