facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/
8.65k stars, 614 forks

`fmha.cutlass.FwOp` is 2x slower than `fmha.flash.FwOp` #1074

Open xenshinu opened 3 months ago

xenshinu commented 3 months ago

❓ Questions and Help

Hi, I did some simple profiling of the xformers CUTLASS implementation of attention versus FlashAttention. I assumed they were the same algorithm with different implementations on Ampere+, but their running times are very different: the cutlass op is almost 2x slower than the flash op. Both were run on an A40. The code is below:

```python
import time

import torch
import xformers.ops
from xformers.ops import fmha

device = "cuda"
compute_capability = torch.cuda.get_device_capability(device)

seqlen_range = [512, 1024, 2048, 4096, 8192, 12288, 16384, 20480, 24576, 28672, 32768, 65536]
for seqlen in seqlen_range:
    # Random fp16 inputs in (batch, seqlen, heads, head_dim) layout
    query = torch.randn(1, seqlen, 32, 128, device=device, dtype=torch.half)
    key = torch.randn(1, seqlen, 32, 128, device=device, dtype=torch.half)
    value = torch.randn(1, seqlen, 32, 128, device=device, dtype=torch.half)

    for i in range(5):
        # CUTLASS backend; synchronize so the timing covers the whole kernel
        torch.cuda.synchronize()
        start = time.perf_counter()
        me_output = xformers.ops.memory_efficient_attention_forward(query, key, value, op=fmha.cutlass.FwOp)
        torch.cuda.synchronize()
        end = time.perf_counter()

        me_attn = (end - start) * 1000  # perf_counter() is in seconds; convert to ms
        print(f"Time breakdown: \n"
              f"me_attention={me_attn:.4f} ms")

        # FlashAttention backend requires compute capability 8.0+ (Ampere or newer)
        if compute_capability[0] >= 8:
            torch.cuda.synchronize()
            start = time.perf_counter()
            fa_output = xformers.ops.memory_efficient_attention_forward(query, key, value, op=fmha.flash.FwOp)
            torch.cuda.synchronize()
            end = time.perf_counter()

            fa_attn = (end - start) * 1000  # seconds -> ms
            print(f"Time breakdown: \n"
                  f"fa_attention={fa_attn:.4f} ms")
```
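A sketch of a slightly more robust timing loop, assuming you want one number per backend rather than a print per call: warm-up calls are excluded and the median of several runs is reported in milliseconds. `benchmark_ms` is a hypothetical helper, not part of xformers or PyTorch (PyTorch does ship `torch.utils.benchmark.Timer` for this kind of measurement). For GPU work, pass `torch.cuda.synchronize` as `sync` so that queued kernels are included in each sample.

```python
import statistics
import time
from typing import Callable, Optional


def benchmark_ms(
    fn: Callable[[], object],
    warmup: int = 3,
    iters: int = 10,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Hypothetical helper: median wall-clock time of fn() in milliseconds.

    `sync` is called around each timed call; for CUDA work pass
    torch.cuda.synchronize so asynchronous kernels finish inside the window.
    """
    sync = sync or (lambda: None)
    # Warm-up calls absorb one-time costs such as lazy initialization and caching
    for _ in range(warmup):
        fn()
    sync()
    samples = []
    for _ in range(iters):
        sync()
        start = time.perf_counter()
        fn()
        sync()
        samples.append((time.perf_counter() - start) * 1000.0)  # seconds -> ms
    return statistics.median(samples)


# Hypothetical GPU usage (needs torch and xformers, with query/key/value as above):
# benchmark_ms(lambda: xformers.ops.memory_efficient_attention_forward(
#     query, key, value, op=fmha.cutlass.FwOp), sync=torch.cuda.synchronize)

if __name__ == "__main__":
    # CPU-only demo so the helper can be tried without a GPU
    print(f"{benchmark_ms(lambda: sum(range(100_000))):.3f} ms")
```

Reporting the median rather than the mean keeps one slow outlier (for example, a call that triggers an allocation) from skewing the comparison between `fmha.cutlass.FwOp` and `fmha.flash.FwOp`.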

I also tested different batch sizes, and the results were similar. BTW, I think flash-attn is also written with CUTLASS, so what is the difference (apart from the cutlass op also running on Pascal+)?

lw commented 3 months ago

I'm not an expert on this, but why do you assume that the two backends should perform exactly the same? I believe the reason we have multiple backends is precisely so that the most performant one can be picked in each situation. FlashAttention has been continuously optimized, so I'm not too surprised that it performs better.
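One way to picture "pick the most performant backend in each situation" is a priority-ordered dispatcher: try the preferred (usually fastest) backend first and fall back when it cannot handle the inputs. This is a hypothetical sketch; `Backend` and `pick_backend` are illustrative names, not xformers APIs. It only checks GPU compute capability, using the thresholds from this thread (flash on 8.0+, cutlass on Pascal+, i.e. 6.0+); a real dispatcher would also need to check things like dtype, head dimension and attention bias.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Backend:
    # Hypothetical backend description, not an xformers type
    name: str
    min_capability: Tuple[int, int]  # lowest (major, minor) compute capability supported


def pick_backend(backends: List[Backend], capability: Tuple[int, int]) -> Backend:
    """Return the first backend, in priority order, that supports this GPU."""
    for backend in backends:
        # Tuples compare lexicographically: (8, 6) >= (8, 0) and (7, 5) < (8, 0)
        if capability >= backend.min_capability:
            return backend
    raise ValueError(f"no backend supports compute capability {capability}")


# Priority order: prefer flash, fall back to cutlass
BACKENDS = [Backend("flash", (8, 0)), Backend("cutlass", (6, 0))]
```

With this ordering, an A40 (compute capability 8.6) would get `flash`, while a Turing card (7.5) would fall back to `cutlass`.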

xenshinu commented 3 months ago

Well, I saw this comment https://github.com/facebookresearch/xformers/issues/950#issuecomment-1864793941 saying that the cutlass.op is only "a bit slower", so I thought they were the same algorithm with different implementations. Thanks for your reply.

danthe3rd commented 3 months ago

They are indeed the same mathematical algorithm (in terms of mathematical operations), but the way the work is parallelized and scheduled is a bit different. Plus, implementation details matter a lot when optimizing CUDA kernels :)
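To make "same mathematical algorithm, different scheduling" concrete, here is a minimal NumPy sketch; it is not the code of either kernel, and the function names and block sizes are illustrative. `naive_attention` materializes the full score matrix, while `tiled_attention` walks over key blocks with a running max and denominator (online softmax), the blockwise idea that memory-efficient attention kernels build on. Both compute softmax(QK^T / sqrt(d)) V and agree up to floating-point rounding. The outer loop over query blocks has independent iterations, and choices such as tile sizes and how those blocks are mapped onto GPU threads are where implementations can differ in speed.

```python
import numpy as np


def naive_attention(q, k, v):
    # softmax(q k^T / sqrt(d)) v, materializing the full (n_q, n_k) score matrix
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v


def tiled_attention(q, k, v, block_q=16, block_k=16):
    # Same math, but keys are processed in blocks with an online softmax,
    # so only a (block_q, block_k) tile of scores exists at any time.
    n_q, d = q.shape
    out = np.empty((n_q, v.shape[-1]))
    scale = 1.0 / np.sqrt(d)
    for qs in range(0, n_q, block_q):  # query blocks are independent -> parallelizable
        qb = q[qs:qs + block_q]
        m = np.full(qb.shape[0], -np.inf)  # running row max
        l = np.zeros(qb.shape[0])  # running softmax denominator
        acc = np.zeros((qb.shape[0], v.shape[-1]))  # running unnormalized output
        for ks in range(0, k.shape[0], block_k):
            s = qb @ k[ks:ks + block_k].T * scale
            m_new = np.maximum(m, s.max(axis=-1))
            p = np.exp(s - m_new[:, None])
            correction = np.exp(m - m_new)  # rescale sums computed with the old max
            l = l * correction + p.sum(axis=-1)
            acc = acc * correction[:, None] + p @ v[ks:ks + block_k]
            m = m_new
        out[qs:qs + block_q] = acc / l[:, None]
    return out
```

For random inputs, `np.allclose(naive_attention(q, k, v), tiled_attention(q, k, v))` should hold for any block sizes, which is the sense in which two kernels can be "the same algorithm" while scheduling the work very differently.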