Dao-AILab / flash-attention

Fast and memory-efficient exact attention

FA2's flash_attn_varlen_func is 300x slower than flash_attn_func #1125

Open ex3ndr opened 3 months ago

ex3ndr commented 3 months ago

Hello! I am benchmarking attention implementations and trying to use flash attention for my variable-length data, and for some reason the varlen path is much, much slower than any other implementation. I am testing on a 4090. No matter how much I warm up or how many times I retry, the results are always roughly the same.

Benchmark results:

flash (mask) 0.9587182998657227
flash (no mask) 0.0034220218658447266
xformers (mask) 0.013103961944580078
xformers (no mask) 0.005987882614135742
torch (mask) 0.009450435638427734
torch (no mask) 0.0041005611419677734

This is the code:

import torch
import torch.utils.benchmark as benchmark
import xformers.ops as xops
from xformers.ops import fmha
import time
from einops import rearrange, repeat, reduce, pack, unpack
from flash_attn import flash_attn_func, flash_attn_varlen_func

#
# Parameters
#

n_heads = 64                                 # head_dim = 128, so model dim = 128 * 64 = 8192
n_len = 10000                                # total number of packed tokens
q = torch.rand(n_len, 128 * 64, dtype=torch.float16, device="cuda")
k = torch.rand(n_len, 128 * 64, dtype=torch.float16, device="cuda")
v = torch.rand(n_len, 128 * 64, dtype=torch.float16, device="cuda")
lengths = [1000, 2520, 2520, 2520, 1440]     # per-sequence lengths, sum = n_len

#
# Implementations
# 

def xformers_attention(q, k, v, mask):
    q, k, v = map(lambda t: rearrange(t, 'n (h d) -> 1 n h d', h = n_heads), (q, k, v))
    y = xops.memory_efficient_attention(q, k, v, attn_bias = mask)
    y = rearrange(y, '1 n h d -> n (h d)')
    return y
def xformers_mask(lengths):
    return fmha.BlockDiagonalMask.from_seqlens(lengths)
def torch_attention(q, k, v, mask):
    q, k, v = map(lambda t: rearrange(t, 'n (h d) -> 1 h n d', h = n_heads), (q, k, v))
    y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask = mask)
    y = rearrange(y, '1 h n d -> n (h d)')
    return y
def torch_mask(lengths):
    return fmha.BlockDiagonalMask.from_seqlens(lengths).materialize(shape=(n_len, n_len)).cuda().to(torch.float16)
def flash_attention(q, k, v, mask):
    if mask is None:
        # Dense path: flash_attn_func expects (batch, seqlen, nheads, headdim).
        q, k, v = map(lambda t: rearrange(t, 'n (h d) -> 1 n h d', h = n_heads), (q, k, v))
        y = flash_attn_func(q, k, v)
        y = rearrange(y, '1 n h d -> n (h d)')
        return y
    else:
        # Varlen path: flash_attn_varlen_func expects packed (total_tokens, nheads, headdim)
        # plus cumulative sequence lengths and the maximum sequence length.
        q, k, v = map(lambda t: rearrange(t, 'n (h d) -> n h d', h = n_heads), (q, k, v))
        y = flash_attn_varlen_func(q, k, v, cu_seqlens_q=mask[1], cu_seqlens_k=mask[1], max_seqlen_q=mask[0], max_seqlen_k=mask[0])
        y = rearrange(y, 'n h d -> n (h d)')
        return y
def flash_mask(lengths):
    # Build (max_seqlen, cu_seqlens): cu_seqlens is the int32 prefix sum of the
    # per-sequence lengths, starting at 0.
    seq_lens = torch.tensor(lengths, dtype=torch.int32)
    max_seq_len = seq_lens.max().item()
    return (torch.tensor(max_seq_len, dtype=torch.int32).cuda(), torch.concat([torch.tensor([0]), seq_lens.cumsum(0)]).to(torch.int32).cuda())

#
# Check implementations
# 

torch_m = torch_mask(lengths)
flash_m = flash_mask(lengths)
xformers_m = xformers_mask(lengths)
a = xformers_attention(q, k, v, xformers_m)
b = torch_attention(q, k, v, torch_m)
c = flash_attention(q, k, v, flash_m)
print((a - b).abs().max())
print((a - c).abs().max())
print((b - c).abs().max())

#
# Benchmarking
# 

start = time.time()
for i in range(100):
    xformers_attention(q, k, v, xformers_m)
print("xformers (mask)", time.time() - start)
start = time.time()
for i in range(100):
    xformers_attention(q, k, v, None)
print("xformers (no mask)", time.time() - start)
start = time.time()
for i in range(100):
    torch_attention(q, k, v, torch_m)
print("torch (mask)", time.time() - start)
start = time.time()
for i in range(100):
    torch_attention(q, k, v, None)
print("torch (no mask)", time.time() - start)
start = time.time()
for i in range(100):
    flash_attention(q, k, v, flash_m)
print("flash (mask)", time.time() - start)
start = time.time()
for i in range(100):
    flash_attention(q, k, v, None)
print("flash (no mask)", time.time() - start)
ex3ndr commented 3 months ago

Changing to 16x16 head dimensions reduces the gap to 10x, but it is still very slow.
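
For reference, that change presumably corresponds to something like the following in the parameters block above; this is an assumption, since the exact values were not posted:

# Hypothetical 16x16 configuration (assumed from the comment above, not posted verbatim)
n_heads = 16
head_dim = 16
q = torch.rand(n_len, head_dim * n_heads, dtype=torch.float16, device="cuda")
k = torch.rand(n_len, head_dim * n_heads, dtype=torch.float16, device="cuda")
v = torch.rand(n_len, head_dim * n_heads, dtype=torch.float16, device="cuda")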

tridao commented 3 months ago

Please don't use time.time() to measure time. CUDA operations are async. You can use torch benchmark. https://pytorch.org/tutorials/recipes/recipes/benchmark.html
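
For anyone following along, a minimal sketch of what timing with torch.utils.benchmark might look like for the script above; the flash_attention, q, k, v, and flash_m names are taken from the posted code:

import torch.utils.benchmark as benchmark

# Timer performs warmup and synchronizes asynchronous CUDA calls,
# so the measurement reflects actual kernel time, unlike raw time.time().
t = benchmark.Timer(
    stmt="flash_attention(q, k, v, flash_m)",
    globals={"flash_attention": flash_attention, "q": q, "k": k, "v": v, "flash_m": flash_m},
)
print("flash (mask)", t.timeit(100).mean)  # mean seconds per call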

ex3ndr commented 3 months ago

@tridao Thank you for catching that. After the fix it is still 4x slower than flash_attn_func:

xformers (mask) 0.00034342713200021533
xformers (no mask) 0.0013367030000081285
torch (mask) 0.0034441131959902123
torch (no mask) 0.0013596494959783741
flash (mask) 0.00034348745294846597
flash (no mask) 0.0013394619610044174

zhangjun commented 2 months ago

@ex3ndr You should add warmup runs like this. https://github.com/triton-lang/triton/blob/fd0fa8305c8626dd77cf588336ccdceabe7d8230/python/triton/testing.py#L144
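
Roughly, the pattern from that helper is "run a number of warmup iterations, then synchronize and time with CUDA events"; below is a sketch adapted to the flash calls from the script above (not the exact do_bench code):

import torch

def bench(fn, warmup=25, rep=100):
    # Discard warmup iterations so one-time setup and caching costs don't skew the timing.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(rep):
        fn()
    end.record()
    torch.cuda.synchronize()  # wait for all timed kernels to finish
    return start.elapsed_time(end) / rep  # milliseconds per call

print("flash (mask)", bench(lambda: flash_attention(q, k, v, flash_m)))
print("flash (no mask)", bench(lambda: flash_attention(q, k, v, None)))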

ex3ndr commented 2 months ago

@zhangjun Thanks! But I am running this code in a notebook, and repeating the cell execution yields similar results.