Open QinlongHuang opened 1 week ago
Hi,
`memory_efficient_attention` used to be faster than PyTorch's SDPA because xFormers was using Flash-Attention. Now SDPA is also using Flash-Attention, so it's normal to have the same speed.
Also, the dimensions need to be transposed between SDPA (format BHMK) and `memory_efficient_attention` (format BMHK).
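To make the layout difference concrete, here is a minimal sketch (not from this thread; the shapes, dtype, and tolerance are my own assumptions) that transposes BHMK tensors to BMHK before calling `memory_efficient_attention` and checks the result against SDPA:

```python
import torch
import torch.nn.functional as F
from xformers.ops import memory_efficient_attention

B, H, M, K = 4, 8, 128, 64  # illustrative sizes, not the benchmark shapes below
q = torch.randn(B, H, M, K, device="cuda", dtype=torch.float16)  # BHMK, as SDPA expects
k = torch.randn(B, H, M, K, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, M, K, device="cuda", dtype=torch.float16)

# SDPA consumes (batch, heads, seq, head_dim) directly.
out_sdpa = F.scaled_dot_product_attention(q, k, v)

# memory_efficient_attention expects (batch, seq, heads, head_dim),
# so transpose on the way in and transpose the output back for comparison.
out_xf = memory_efficient_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
out_xf = out_xf.transpose(1, 2)

# fp16 kernels differ slightly, hence the loose tolerance (an assumption).
print(torch.allclose(out_sdpa, out_xf, atol=1e-2, rtol=1e-2))
```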
Thank you so much for the quick reply!
So is there any gain in using xformers' `memory_efficient_attention` instead of PyTorch's SDPA now?
Also, it seems that `memory_efficient_attention` is not compatible w/ `torch.compile`, which can speed up training and use less memory.
Besides, you've mentioned that "the dimensions need to be transposed between SDPA (format BHMK) and memory_efficient_attention (format BMHK)". So I have to transpose the QKV to get the CORRECT results?
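On the `torch.compile` point, one way to check whether the op graph-breaks (again just a sketch under my own assumptions, not something from this thread) is to compile with `fullgraph=True` or inspect the trace with `torch._dynamo.explain`:

```python
import torch
from xformers.ops import memory_efficient_attention

def attn(q, k, v):
    # q, k, v in BMHK layout, as memory_efficient_attention expects
    return memory_efficient_attention(q, k, v)

q = k = v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)

# fullgraph=True makes torch.compile raise instead of silently falling back
# to eager when the traced function cannot be captured in a single graph.
compiled_attn = torch.compile(attn, fullgraph=True)
out = compiled_attn(q, k, v)

# Alternatively, report graph breaks without raising.
print(torch._dynamo.explain(attn)(q, k, v).graph_break_count)
```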
I did a toy test w/ the following snippet.
```python
import torch
import torch.nn.functional as F
from xformers.ops import memory_efficient_attention

batch_size, num_heads, seq_length, head_dim = 32, 128, 512, 256
q = torch.randn(batch_size, num_heads, seq_length, head_dim, device='cuda')  # BHMK
k = torch.randn(batch_size, num_heads, seq_length, head_dim, device='cuda')
v = torch.randn(batch_size, num_heads, seq_length, head_dim, device='cuda')

# PyTorch scaled_dot_product_attention (expects BHMK)
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(10):
    output_torch = F.scaled_dot_product_attention(q, k, v)
end.record()
torch.cuda.synchronize()
print(f"PyTorch scaled_dot_product_attention: {start.elapsed_time(end)} ms")

# xformers memory_efficient_attention (expects BMHK)
q = q.transpose(1, 2)  # BHMK -> BMHK
k = k.transpose(1, 2)
v = v.transpose(1, 2)
torch.cuda.synchronize()
start.record()
for _ in range(10):
    output_xformers = memory_efficient_attention(q, k, v)
end.record()
torch.cuda.synchronize()
print(f"xformers memory_efficient_attention: {start.elapsed_time(end)} ms")
```
And now I get similar speeds w/ these two implementations.
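As a side note, a more controlled way to time these kernels than hand-rolled CUDA events is `torch.utils.benchmark.Timer`, which handles warm-up and CUDA synchronization. This is only an illustrative variant of the same test, reusing the shapes and fp32 dtype from the snippet above:

```python
import torch
import torch.nn.functional as F
from torch.utils import benchmark
from xformers.ops import memory_efficient_attention

# Same shapes and dtype (fp32) as the snippet above.
B, H, M, K = 32, 128, 512, 256
q = torch.randn(B, H, M, K, device="cuda")
k = torch.randn(B, H, M, K, device="cuda")
v = torch.randn(B, H, M, K, device="cuda")

# Timer performs warm-up runs and synchronizes CUDA before reporting.
t_sdpa = benchmark.Timer(
    stmt="F.scaled_dot_product_attention(q, k, v)",
    globals={"F": F, "q": q, "k": k, "v": v},
)
t_xf = benchmark.Timer(
    stmt="mea(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))",
    globals={"mea": memory_efficient_attention, "q": q, "k": k, "v": v},
)
print(t_sdpa.timeit(50))
print(t_xf.timeit(50))
```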
❓ Questions and Help
I am new to xformers, and I want to speed up my Transformer models w/ it. But I found that xformers gives no speedup compared w/ `scaled_dot_product_attention` from PyTorch. Here is my code snippet for training a vanilla GPT-2. Is there anything wrong with how I use xformers?
Environment: Ubuntu 20.04, CUDA 11.8, NVIDIA RTX 4090, PyTorch 2.4.1, xformers 0.0.28.post1
When I trained a standard GPT-2 (~89M parameters) using `scaled_dot_product_attention`, I got ~9 it/s, but only ~7 it/s with `memory_efficient_attention`. And I cannot train a GPT-2-medium (~300M parameters) when using `memory_efficient_attention`, but I can train it w/ `scaled_dot_product_attention`. All exps are trained using fp16 and w/ `torch.compile`.
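For reference, here is a rough sketch of the kind of SDPA/`memory_efficient_attention` swap described above inside a GPT-2-style attention block. This is not the original training code; the module name, sizes, and structure are my own assumptions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from xformers.ops import memory_efficient_attention, LowerTriangularMask

class CausalSelfAttention(nn.Module):
    """Hypothetical GPT-2-style attention block; names and sizes are assumptions."""

    def __init__(self, d_model=768, n_heads=12, use_xformers=False):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.use_xformers = use_xformers

    def forward(self, x):  # x: (B, T, d_model)
        B, T, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        if self.use_xformers:
            # memory_efficient_attention wants (B, T, heads, head_dim)
            q, k, v = (t.reshape(B, T, self.n_heads, self.head_dim) for t in (q, k, v))
            y = memory_efficient_attention(q, k, v, attn_bias=LowerTriangularMask())
            y = y.reshape(B, T, -1)
        else:
            # SDPA wants (B, heads, T, head_dim)
            q, k, v = (t.reshape(B, T, self.n_heads, self.head_dim).transpose(1, 2)
                       for t in (q, k, v))
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
            y = y.transpose(1, 2).reshape(B, T, -1)
        return self.proj(y)

# Illustrative usage in fp16 on GPU:
block = CausalSelfAttention(use_xformers=True).cuda().half()
x = torch.randn(8, 256, 768, device="cuda", dtype=torch.float16)
y = block(x)  # (8, 256, 768)
```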