facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

`memory_efficient_attention` is slower than `scaled_dot_product_attention` of PyTorch? #1107

Open QinlongHuang opened 1 week ago

QinlongHuang commented 1 week ago

❓ Questions and Help

I am new to xformers, and I want to speed up my Transformer models w/ it. But I found that xformers gives no speedup compared w/ scaled_dot_product_attention from PyTorch. Here is my code snippet for training a vanilla GPT-2. Is there anything wrong with how I use xformers?

import torch
import torch.nn.functional as F

try:
    from xformers.ops import memory_efficient_attention, LowerTriangularMask
except ImportError:
    memory_efficient_attention = None

# in __init__:
self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")

# in forward(), q/k/v are [batch, heads, seq_len, head_dim]:
if memory_efficient_attention is not None:
    y = memory_efficient_attention(
        q, k, v,
        p=self.dropout if self.training else 0,
        attn_bias=LowerTriangularMask(),
    )
elif self.flash:
    y = F.scaled_dot_product_attention(
        q, k, v,
        dropout_p=self.dropout if self.training else 0,
        is_causal=True,
    )

Environment: Ubuntu 20.04, CUDA 11.8, NVIDIA RTX 4090, PyTorch 2.4.1, xformers 0.0.28.post1

python -m xformers.info

xFormers 0.0.28.post1
memory_efficient_attention.ckF:                    unavailable
memory_efficient_attention.ckB:                    unavailable
memory_efficient_attention.ck_decoderF:            unavailable
memory_efficient_attention.ck_splitKF:             unavailable
memory_efficient_attention.cutlassF-pt:            available
memory_efficient_attention.cutlassB-pt:            available
memory_efficient_attention.fa2F@v2.5.6-pt:         available
memory_efficient_attention.fa2B@v2.5.6-pt:         available
memory_efficient_attention.fa3F@0.0.0:             unavailable
memory_efficient_attention.fa3B@0.0.0:             unavailable
memory_efficient_attention.triton_splitKF:         available
indexing.scaled_index_addF:                        available
indexing.scaled_index_addB:                        available
indexing.index_select:                             available
sequence_parallel_fused.write_values:              available
sequence_parallel_fused.wait_values:               available
sequence_parallel_fused.cuda_memset_32b_async:     available
sp24.sparse24_sparsify_both_ways:                  available
sp24.sparse24_apply:                               available
sp24.sparse24_apply_dense_output:                  available
sp24._sparse24_gemm:                               available
sp24._cslt_sparse_mm@0.4.0:                        available
swiglu.dual_gemm_silu:                             available
swiglu.gemm_fused_operand_sum:                     available
swiglu.fused.p.cpp:                                available
is_triton_available:                               True
pytorch.version:                                   2.4.1+cu118
pytorch.cuda:                                      available
gpu.compute_capability:                            8.9
gpu.name:                                          NVIDIA GeForce RTX 4090
dcgm_profiler:                                     unavailable
build.info:                                        available
build.cuda_version:                                1108
build.hip_version:                                 None
build.python_version:                              3.9.20
build.torch_version:                               2.4.1+cu118
build.env.TORCH_CUDA_ARCH_LIST:                    6.0+PTX 7.0 7.5 8.0+PTX
build.env.PYTORCH_ROCM_ARCH:                       None
build.env.XFORMERS_BUILD_TYPE:                     Release
build.env.XFORMERS_ENABLE_DEBUG_ASSERTIONS:        None
build.env.NVCC_FLAGS:                              -allow-unsupported-compiler
build.env.XFORMERS_PACKAGE_FROM:                   wheel-v0.0.28.post1
build.nvcc_version:                                11.8.89
source.privacy:                                    open source

When I trained a standard GPT-2 (~89M parameters) using scaled_dot_product_attention, I got ~9 it/s, but only ~7 it/s w/ memory_efficient_attention.

And I cannot train a GPT-2-medium (~300M parameters) when using memory_efficient_attention, but I can w/ scaled_dot_product_attention.

All experiments are trained using fp16 and w/ torch.compile.

danthe3rd commented 1 week ago

Hi, memory_efficient_attention used to be faster than PyTorch's SDPA because xFormers was using Flash-Attention. Now SDPA is also using Flash-Attention, so it's normal to have the same speed. Also the dimensions need to be transposed between SDPA (format BHMK) and memory_efficient_attention (format BMHK).
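In code terms, a minimal sketch of the layout conversion (assuming q, k, v are in SDPA's [B, H, M, K] layout, as in the snippet above, and reusing its dropout handling):

# q, k, v: [B, H, M, K] (batch, heads, seq_len, head_dim), i.e. SDPA's layout
# memory_efficient_attention expects [B, M, H, K], so transpose in and out
y = memory_efficient_attention(
    q.transpose(1, 2),                    # -> [B, M, H, K]
    k.transpose(1, 2),
    v.transpose(1, 2),
    p=self.dropout if self.training else 0,
    attn_bias=LowerTriangularMask(),      # causal, equivalent to is_causal=True
)
y = y.transpose(1, 2)                     # back to [B, H, M, K] to match SDPA's output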

QinlongHuang commented 1 week ago

Thank you so much for the quick reply!

So is there any gain in using xformers' memory_efficient_attention instead of PyTorch's SDPA now?

And it seems that memory_efficient_attention is not compatible w/ torch.compile, which can speed up training and reduce memory usage.

Besides, you've mentioned that "the dimensions need to be transposed between SDPA (format BHMK) and memory_efficient_attention (format BMHK)". So do I have to transpose Q, K, and V to get the CORRECT results?

I did a toy test w/ the following snippet.

import torch
import torch.nn.functional as F
from xformers.ops import memory_efficient_attention

batch_size, num_heads, seq_length, head_dim = 32, 128, 512, 256
q = torch.randn(batch_size, num_heads, seq_length, head_dim, device='cuda')  # BHMK
k = torch.randn(batch_size, num_heads, seq_length, head_dim, device='cuda')
v = torch.randn(batch_size, num_heads, seq_length, head_dim, device='cuda')

# PyTorch scaled_dot_product_attention
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()

for i in range(10):
    output_torch = F.scaled_dot_product_attention(q, k, v)

end.record()
torch.cuda.synchronize()
print(f"PyTorch scaled_dot_product_attention: {start.elapsed_time(end)} ms")

# xformers memory_efficient_attention
q = q.transpose(1, 2)  # BMHK
k = k.transpose(1, 2)
v = v.transpose(1, 2)
torch.cuda.synchronize()
start.record()

for i in range(10):
    output_xformers = memory_efficient_attention(q, k, v)

end.record()
torch.cuda.synchronize()
print(f"xformers memory_efficient_attention: {start.elapsed_time(end)} ms")

And now I get similar speed w/ these two implementations.
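
A sanity check along these lines, reusing the outputs from the snippet above, can also confirm that the two paths agree numerically once the layouts are aligned (the tolerances here are just illustrative for fp32):

# output_torch is BHMK, output_xformers is BMHK, so align the layouts before comparing
print(torch.allclose(output_torch, output_xformers.transpose(1, 2), rtol=1e-3, atol=1e-3))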