pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention

V100 GPUs supported? #21

Open boren-ms opened 1 month ago

boren-ms commented 1 month ago

Do we support the old V100 GPUs with FlexAttention?

drisspg commented 1 month ago

FlexAttention relies entirely on Triton for kernel support, so its support is fundamentally limited by what Triton supports. I think V100 is supported by Triton in theory, but anecdotally I know that this support (in terms of performance) is not great.

SimJeg commented 1 week ago

Hello @drisspg,

I ran the following speed benchmark on V100 GPUs and found that flex_attention is 10x slower than the SDPA and eager implementations. Am I doing something wrong?

pytorch-triton            3.0.0+dedb7bdf33
torch                     2.5.0.dev20240901+cu124
# pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu124

import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention.flex_attention import flex_attention

from time import time

device = "cuda"
dtype = torch.float16
bsz, n_heads, n_tokens, n_dims = 16, 32, 1024, 128
n_repeats = 100

# Three implementations to compare: naive eager attention, SDPA, and (uncompiled) flex_attention
attention_dict = {
    "eager": lambda query, key, value: torch.matmul(torch.nn.functional.softmax(torch.matmul(query, key.transpose(-2, -1)) / (n_dims ** 0.5), dim=-1), value),
    "sdpa": scaled_dot_product_attention,
    "flex": flex_attention,
}

time_dict = {}

with torch.no_grad():
    for i in range(n_repeats):
        # Fresh random Q/K/V each iteration, unpacked along dim 0
        query, key, value = torch.randn(3, bsz, n_heads, n_tokens, n_dims).to(device, dtype)
        torch.cuda.synchronize()

        for name, fn in attention_dict.items():
            start = time()
            output = fn(query, key, value)
            torch.cuda.synchronize()
            end = time()
            time_dict[name] = time_dict.get(name, 0) + end - start

for name, time_taken in time_dict.items():
    print(f"{name}: {1000*time_taken/n_repeats:.2f}ms")
eager: 14.34ms
sdpa: 10.79ms
flex: 107.65ms

Note that with float32, I get output(flex) == output(eager), while output(sdpa) == output(eager) ± 1e-8.

drisspg commented 1 week ago

@SimJeg I think that this is unrelated to this issue. FlexAttention is not expected to be performant in eager mode because it falls back to a decomposed implementation. The fused Triton kernel is only enabled by compiling the function: torch.compile(flex_attention).
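
For reference, a minimal sketch of what that looks like; the variable names and shapes below are illustrative, matching the benchmark above.

import torch
from torch.nn.attention.flex_attention import flex_attention

# Compile once; the fused Triton kernel only exists behind torch.compile.
flex_attention_compiled = torch.compile(flex_attention)

# Illustrative shapes matching the benchmark above.
q = torch.randn(16, 32, 1024, 128, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# The first call triggers compilation (slow); later calls reuse the compiled
# kernel, so warm up before timing.
out = flex_attention_compiled(q, k, v)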

SimJeg commented 1 week ago

Thanks for your answer. Unrelated question: do I need to recompile flex_attention every time I change block_mask (e.g. if I use a block mask that depends on the input keys and queries)?

drisspg commented 1 week ago

Kinda depends. Do you have an example of how you are creating the BlockMask, and with what mask_mod?

If your mask_mod is dependent on the data in another buffer, or you change the sequence lengths, you will need to recreate the BlockMask.
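
For illustration, a minimal sketch of the two cases. The mask_mod functions causal_mask and length_mask below are hypothetical examples (not from the thread), and the shapes match the benchmark above.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Case 1: mask_mod depends only on indices (e.g. causal). The BlockMask only
# needs to be recreated if the sequence lengths change.
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=1024, KV_LEN=1024)

# Case 2: mask_mod closes over a data buffer (here, per-batch valid lengths).
# If the buffer's contents change, the BlockMask must be recreated from it.
lengths = torch.full((16,), 512, device="cuda", dtype=torch.int32)

def length_mask(b, h, q_idx, kv_idx):
    return kv_idx < lengths[b]

data_block_mask = create_block_mask(length_mask, B=16, H=None, Q_LEN=1024, KV_LEN=1024)

# As noted above, what needs to be recreated is the BlockMask itself; a new
# mask with the same shapes can typically be passed to the already-compiled
# function without recompiling it.
flex_compiled = torch.compile(flex_attention)
q = k = v = torch.randn(16, 32, 1024, 128, device="cuda", dtype=torch.float16)
out = flex_compiled(q, k, v, block_mask=block_mask)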