Open boren-ms opened 1 month ago
Do we support the old V100 GPUs with FlexAttention?
FlexAttention relies entirely on Triton for kernel support, so its support is fundamentally limited by what Triton supports. I think that, in theory, V100 is supported by Triton, but anecdotally I know that this support (in terms of perf) is not great.
Hello @drisspg,
I ran the following speed benchmark on V100 GPUs and found that flex_attention
is roughly 10x slower than both SDPA and the eager implementation. Am I doing something wrong?
pytorch-triton 3.0.0+dedb7bdf33
torch 2.5.0.dev20240901+cu124
# pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu124
import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention.flex_attention import flex_attention
from time import time

device = "cuda"
dtype = torch.float16
bsz, n_heads, n_tokens, n_dims = 16, 32, 1024, 128
n_repeats = 100

# Three implementations of plain (unmasked) attention to compare.
attention_dict = {
    "eager": lambda query, key, value: torch.matmul(
        torch.nn.functional.softmax(
            torch.matmul(query, key.transpose(-2, -1)) / (n_dims ** 0.5), dim=-1
        ),
        value,
    ),
    "sdpa": scaled_dot_product_attention,
    "flex": flex_attention,
}

time_dict = {}
with torch.no_grad():
    for i in range(n_repeats):
        # Fresh random inputs each repeat; unpacking splits along dim 0.
        query, key, value = torch.randn(3, bsz, n_heads, n_tokens, n_dims).to(device, dtype)
        torch.cuda.synchronize()
        for name, fn in attention_dict.items():
            start = time()
            output = fn(query, key, value)
            torch.cuda.synchronize()  # wait for the kernel before stopping the timer
            end = time()
            time_dict[name] = time_dict.get(name, 0) + end - start

for name, time_taken in time_dict.items():
    print(f"{name}: {1000 * time_taken / n_repeats:.2f}ms")
eager: 14.34ms
sdpa: 10.79ms
flex: 107.65ms
Note that with float32, I get output(flex) == output(eager) exactly, while output(sdpa) matches output(eager) only to within about 1e-8.
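For completeness, a small sketch of how that comparison can be made explicit, reusing the benchmark setup above (the tolerances are illustrative, not the exact values observed):

# Compare implementations in float32; rtol/atol here are illustrative.
query, key, value = torch.randn(3, bsz, n_heads, n_tokens, n_dims, device=device)
out_eager = attention_dict["eager"](query, key, value)
out_flex = flex_attention(query, key, value)
torch.testing.assert_close(out_flex, out_eager, rtol=1e-5, atol=1e-8)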
@SimJeg I think that this is unrelated to this issue. FlexAttention is not expected to be performant in eager mode because it falls back to a decomposed implementation. The fused Triton kernel is only enabled by compiling the function: torch.compile(flex_attention).
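For example, a minimal sketch of the compiled path (variable names here are my own):

import torch
from torch.nn.attention.flex_attention import flex_attention

# Compiling flex_attention is what generates the fused Triton kernel;
# calling it uncompiled falls back to the slow decomposed implementation.
flex_attention_compiled = torch.compile(flex_attention)

q = k = v = torch.randn(16, 32, 1024, 128, device="cuda", dtype=torch.float16)
out = flex_attention_compiled(q, k, v)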
Thanks for your answer. Unrelated question: do I need to recompile flex_attention every time I change the block_mask (e.g. if I use a block mask that depends on the input keys and queries)?
It somewhat depends; do you have an example of how you are creating the BlockMask, and with what mask_mod?
If your mask_mod depends on the data in another buffer, or if you change the sequence lengths, you will need to recreate the BlockMask.
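A minimal sketch of that pattern, using the create_block_mask API (the causal mask_mod is just an illustrative example):

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# mask_mod returns True for (q_idx, kv_idx) pairs that are allowed to attend.
def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# The BlockMask is built for fixed Q_LEN/KV_LEN; if the sequence lengths change,
# or the mask_mod reads a buffer whose contents changed, rebuild it with
# create_block_mask before calling flex_attention again.
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=1024, KV_LEN=1024, device="cuda")

flex_attention_compiled = torch.compile(flex_attention)
q = k = v = torch.randn(16, 32, 1024, 128, device="cuda", dtype=torch.float16)
out = flex_attention_compiled(q, k, v, block_mask=block_mask)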