Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

v2.6.3's flash_attn_varlen_func runs faster than v2.7.0.post2's flash_attn_varlen_func on H100 #1338

Open complexfilter opened 1 hour ago

complexfilter commented 1 hour ago

I found that flash_attn_varlen_func from the FA3 (hopper) interface runs faster in v2.6.3 than in v2.7.0.post2 on an H100.

Reproduction script:

import torch
import triton

# FA3 (Hopper) interface, imported from the repo's hopper/ directory
from hopper.flash_attn_interface import flash_attn_func, flash_attn_varlen_func

def get_tensors(batch_size, seq_len, head_size, dim):
    # head_size is the number of attention heads; dim is the per-head dimension
    torch.manual_seed(42)
    shape = (batch_size, seq_len, head_size, dim)
    q = torch.randn(shape, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    k = torch.randn(shape, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    v = torch.randn(shape, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    return q, k, v

if __name__ == "__main__":
    batch_size = 1
    seq_len = 33 + 40*34*60  # 81633 tokens
    head_size = 28           # number of heads
    dim = 128                # head dimension
    q, k, v = get_tensors(batch_size, seq_len, head_size, dim)

    # do_bench treats warmup and rep as time budgets in milliseconds, not iteration
    # counts; at ~150 ms per call this likely means a single timed run, which would
    # explain why the three quantiles printed below are identical.
    warmup = 20
    rep = 100

    # Fixed-length (dense) forward pass
    fn = lambda: flash_attn_func(q, k, v, softmax_scale=None, causal=False)
    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=[0.2, 0.5, 0.8])
    print(f"FA3 fwd with full attention: {ms[1]:.5f} ({ms[0]:.5f}, {ms[2]:.5f}) | ")

    # Varlen forward pass on the same data, packed as one sequence of seq_len tokens
    cu_seqlens_q = torch.tensor([0, seq_len], device=q.device, dtype=torch.int32)
    cu_seqlens_k = torch.tensor([0, seq_len], device=q.device, dtype=torch.int32)
    fn = lambda: flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q, cu_seqlens_k, seq_len, seq_len)
    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=[0.2, 0.5, 0.8])
    print(f"FA3 varlen fwd with full attention: {ms[1]:.5f} ({ms[0]:.5f}, {ms[2]:.5f}) | ")

    # Correctness check: element [0] of each call's return value is the attention output
    x  = flash_attn_func(q, k, v, softmax_scale=None, causal=False, deterministic=False)
    x_ = flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q, cu_seqlens_k, seq_len, seq_len, softmax_scale=None, causal=False, deterministic=False)

    # Relative Frobenius-norm difference between the varlen and dense outputs
    print(torch.norm(x_[0].unsqueeze(0) - x[0], p='fro') / torch.norm(x[0], p='fro'))
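
The cu_seqlens_q / cu_seqlens_k arguments in the script hold cumulative sequence lengths: with a single sequence they are just [0, seq_len], and the two trailing positional arguments are the maximum sequence lengths. When several variable-length sequences are packed along the token dimension, the offsets are the prefix sums of their lengths. A minimal pure-Python sketch; the helper names make_cu_seqlens and max_seqlen are hypothetical and not part of flash-attn:

```python
from itertools import accumulate

def make_cu_seqlens(seqlens):
    """Hypothetical helper: cumulative offsets [0, l0, l0+l1, ...] for packed sequences.

    Sequence i occupies rows cu[i]:cu[i+1] of the packed (total_tokens, heads, dim)
    tensor, and cu[-1] equals the total number of tokens.
    """
    if any(n < 0 for n in seqlens):
        raise ValueError("sequence lengths must be non-negative")
    return [0, *accumulate(seqlens)]

def max_seqlen(cu_seqlens):
    """Hypothetical helper: longest sequence, i.e. the value passed as max_seqlen_q/k."""
    return max((b - a for a, b in zip(cu_seqlens, cu_seqlens[1:])), default=0)

seq_len = 33 + 40 * 34 * 60
print(make_cu_seqlens([seq_len]))   # single sequence, as in the repro
print(make_cu_seqlens([3, 5, 2]))   # three sequences packed back to back
print(max_seqlen(make_cu_seqlens([3, 5, 2])))
```

On the GPU the resulting list would be wrapped with torch.tensor(..., dtype=torch.int32, device="cuda"), as the repro does for the single-sequence case.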

Results with v2.6.3 on H100 (times in ms):

FA3 fwd with full attention: 147.76518 (147.76518, 147.76518) | 
FA3 varlen fwd with full attention: 150.84227 (150.84227, 150.84227) | 
tensor(0., device='cuda:0', dtype=torch.bfloat16, grad_fn=<DivBackward0>)

Results with v2.7.0.post2 on H100 (times in ms):

FA3 fwd with full attention: 155.47057 (155.47057, 155.47057) | 
FA3 varlen fwd with full attention: 221.30321 (221.30321, 221.30321) | 
tensor(0., device='cuda:0', dtype=torch.bfloat16, grad_fn=<DivBackward0>)

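The varlen forward pass takes about 150.8 ms with v2.6.3 versus 221.3 ms with v2.7.0.post2 (roughly 47% slower), while the dense flash_attn_func only slows from 147.8 ms to 155.5 ms (about 5%). Both versions produce identical outputs for the dense and varlen paths.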
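The reported medians can be compared directly. The sketch below computes the slowdown ratios and also estimates how many runs triton.testing.do_bench performs. The run-count formula (time budget divided by a per-call estimate, at least one run) is our reading of Triton's implementation and should be treated as an assumption; timed_runs and slowdown are hypothetical helper names.

```python
# Median forward times (ms) reported in this issue
results = {
    "v2.6.3": {"dense": 147.76518, "varlen": 150.84227},
    "v2.7.0.post2": {"dense": 155.47057, "varlen": 221.30321},
}

def slowdown(new_ms, old_ms):
    """Fractional slowdown of new vs old (0.10 means 10% slower)."""
    return new_ms / old_ms - 1.0

def timed_runs(estimate_ms, warmup_ms=20, rep_ms=100):
    """Assumed do_bench sizing: each ms budget divided by the per-call estimate, min 1."""
    n_warmup = max(1, int(warmup_ms / estimate_ms))
    n_repeat = max(1, int(rep_ms / estimate_ms))
    return n_warmup, n_repeat

for kind in ("dense", "varlen"):
    old = results["v2.6.3"][kind]
    new = results["v2.7.0.post2"][kind]
    print(f"{kind}: {slowdown(new, old):+.1%}")

# Under the assumed formula, a ~150 ms call with rep=100 gets a single timed run,
# so the 0.2/0.5/0.8 quantiles all collapse to the same value.
print(timed_runs(150.0))
```

If more stable numbers are wanted, raising rep well above the per-call time (for example several seconds) would give do_bench enough runs for the quantiles to differ.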

tridao commented 1 hour ago

Please try compiling with CUDA 12.3

complexfilter commented 1 hour ago

> Please try compiling with CUDA 12.3

I believe my CUDA version is 12.4:

Fri Nov 15 16:50:13 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.90.07              Driver Version: 550.90.07      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:86:00.0 Off |                    0 |
| N/A   26C    P0             80W /  700W |       1MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

I'm not sure whether CUDA 12.4 is the cause.

tridao commented 1 hour ago

What matters is the version of nvcc, not the CUDA driver. You can install the CUDA toolkit (including nvcc) on top of whichever driver version you have.
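Since the compiler version is what matters, it helps to check nvcc directly rather than the "CUDA Version" field of nvidia-smi, which reports the newest CUDA version the driver supports. A minimal sketch that parses the "release X.Y" field of `nvcc --version`; the helper names are hypothetical, and it assumes the nvcc used for the build is the one on PATH:

```python
import re
import shutil
import subprocess

def parse_nvcc_release(text):
    """Extract (major, minor) from the 'release X.Y' field of `nvcc --version` output."""
    m = re.search(r"release (\d+)\.(\d+)", text)
    if m is None:
        raise ValueError("no 'release X.Y' field found in nvcc output")
    return int(m.group(1)), int(m.group(2))

def local_nvcc_release():
    """Run the nvcc found on PATH and return its (major, minor) release, or None if absent."""
    nvcc = shutil.which("nvcc")
    if nvcc is None:
        return None
    out = subprocess.run([nvcc, "--version"], capture_output=True, text=True, check=True).stdout
    return parse_nvcc_release(out)

if __name__ == "__main__":
    print(local_nvcc_release())
```

PyTorch's extension builder locates the toolkit through CUDA_HOME, so if this reports (12, 4), pointing CUDA_HOME at a CUDA 12.3 installation before rebuilding is one way to follow the suggestion above.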