Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Port to (shifted) WindowAttention? #14

Open void-main opened 2 years ago

void-main commented 2 years ago

Hey authors, great repo for speeding up training of attention-based models. I wonder how the code could be ported to support (shifted) WindowAttention?

To my knowledge, (S)WindowAttention differs from traditional attention in the following ways:

  1. SWAttention has a relative position bias term inside the softmax: Softmax(QK^T/sqrt(d) + Bias)V (see the reference sketch after this list);
  2. The mask pattern is different;
  3. The head dims are different;
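
For concreteness, here is a minimal plain-PyTorch sketch of the forward math in item 1 (the shapes and the `rel_pos_bias` argument are my own illustration, not tied to this repo's API):

```python
import torch

def window_attention_ref(q, k, v, rel_pos_bias):
    """Reference (S)WindowAttention math: Softmax(QK^T / sqrt(d) + Bias) V.

    q, k, v:       (num_windows * B, num_heads, N, head_dim), N = window_size ** 2
    rel_pos_bias:  (num_heads, N, N), broadcast over the window/batch dimension
    """
    d = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / d ** 0.5     # (B', h, N, N)
    scores = scores + rel_pos_bias.unsqueeze(0)     # the extra bias term inside softmax
    return scores.softmax(dim=-1) @ v               # (B', h, N, head_dim)

# toy shapes: 4 windows, 3 heads, 7x7 window -> N = 49, head_dim = 32
q = torch.randn(4, 3, 49, 32); k = torch.randn_like(q); v = torch.randn_like(q)
out = window_attention_ref(q, k, v, torch.randn(3, 49, 49))   # (4, 3, 49, 32)
```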

Based on these differences, here are the code changes I think would be needed:

  1. Add a new Tensor to store the Bias term for SWAttention in the fwd pass, and use dSoftmax as dBias in the bwd pass (see the autograd check after this list); in addition, we need the corresponding iterator / gmem, smem loaders;
  2. Add a new Tensor to store the mask buffer in fwd pass;
  3. As for head dims: why is launch_params.params.d limited to 16, 32, 64, 128? Could we extend this to other dim sizes?
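
On point 1, my understanding of the backward side (just a toy autograd check with made-up shapes, nothing from the actual kernels) is that dBias is simply the gradient of the pre-softmax scores, i.e. dSoftmax, reduced over whatever dimensions the bias was broadcast across:

```python
import torch

B, h, N, d = 2, 3, 8, 16
q, k, v = (torch.randn(B, h, N, d, requires_grad=True) for _ in range(3))
bias = torch.randn(h, N, N, requires_grad=True)      # shared across the batch dim

s = q @ k.transpose(-2, -1) / d ** 0.5 + bias        # pre-softmax scores
s.retain_grad()                                      # keep dS for the comparison below
(s.softmax(dim=-1) @ v).sum().backward()

# dBias is dS reduced over the dimensions the bias was broadcast over
assert torch.allclose(bias.grad, s.grad.sum(dim=0))
```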

Could you please offer some guidance on how to port to WindowAttention? Thanks.

tridao commented 2 years ago

Thanks for your interest! Does shifted window attention refer to the attention used in SwinTransformer? I haven't profiled it so I'd first figure out where the bottlenecks are.

  1. You're right, we'd need the bias term in the fwd, and dBias for backward. The dSoftmax is already computed but never stored to HBM as we want to reduce HBM reads/writes for performance.
  2. I'm not familiar with what the masks look like. Depending on their sizes, you might want to reshape things before calling attention (e.g. as SwinTransformer does; see the sketch after this list), instead of calling attention with a mask.
  3. We currently support head dimensions 16, 32, 64, 128 as you mentioned. This is because (a) we use tensor cores for the matrix multiplies, so the dimensions need to be divisible by 16, and (b) for maximum performance we customize the smem loaders for each head dimension. Ideally I'd want to make the code more flexible and support more dimensions. I'm looking to port some of it to use Cutlass, which will hopefully take care of the smem reading/writing.
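
To illustrate point 2: the kind of reshape I have in mind is the usual window partition (a sketch with made-up shapes, similar in spirit to SwinTransformer's window_partition, not code from this repo), so each window becomes its own batch element and attention within a window needs no mask:

```python
import torch

def window_partition(x, window_size):
    """Reshape (B, H, W, C) feature maps into per-window sequences of length
    window_size**2, returning (B * num_windows, window_size**2, C)."""
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    x = x.permute(0, 1, 3, 2, 4, 5)                    # (B, nH, nW, ws, ws, C)
    return x.reshape(-1, window_size * window_size, C)

x = torch.randn(2, 56, 56, 96)       # (B, H, W, C)
windows = window_partition(x, 7)     # (2 * 64, 49, 96): run attention per window, no mask
```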

void-main commented 2 years ago

Hey @tridao, thanks for your reply!

I have a few follow-up questions (sorry for so many):

> 1. You're right, we'd need the bias term in the fwd, and dBias for backward. The dSoftmax is already computed but never stored to HBM as we want to reduce HBM reads/writes for performance.

Since the code is a bit hard to follow (in a good way), could you give me some hints on how to add the bias to the forward pass? To be exact, I believe I should add the bias after this line; however, how do I calculate the correct offset into the bias? What are the semantics of Mma_tile_p and Mma_tile_o? Should I use Mma_tile_o?

As for dSoftmax, do you think here would be a good place to copy dSoftmax into dBias?

Also, I vaguely sense that there are patterns in the code (gmem_tile, smem_tile, mma_tile), but it's hard for me to connect the dots. Could you give me an example of how data flows through these pieces during the QK^T part of the forward pass, so that I can try to understand the code myself?
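
To make the question concrete, here is the rough mental model I currently have of the per-tile forward computation (plain PyTorch over slices, purely conceptual; it ignores how the gmem_tile / smem_tile / mma_tile classes actually move data, and the bias argument is my own addition). Please correct me where the mapping is wrong:

```python
import torch

def tiled_attention_fwd(q, k, v, bias, tile_m=16, tile_n=16):
    """Conceptual single-head forward pass: for each tile of Q rows, stream over
    tiles of K/V columns, computing QK^T + bias, an online softmax, and a running
    output, so the full N x N score matrix is never materialized in gmem."""
    N, d = q.shape
    out = torch.empty_like(q)
    for i in range(0, N, tile_m):                        # "load a Q tile from gmem to smem"
        q_tile = q[i:i + tile_m]
        row_max = torch.full((q_tile.shape[0],), -float("inf"))
        row_sum = torch.zeros(q_tile.shape[0])
        acc = torch.zeros(q_tile.shape[0], d)
        for j in range(0, N, tile_n):                    # stream K/V tiles
            k_tile, v_tile = k[j:j + tile_n], v[j:j + tile_n]
            s = q_tile @ k_tile.T / d ** 0.5 + bias[i:i + tile_m, j:j + tile_n]  # "mma" tile
            new_max = torch.maximum(row_max, s.max(dim=-1).values)
            scale = torch.exp(row_max - new_max)         # rescale previous partial results
            p = torch.exp(s - new_max[:, None])
            row_sum = row_sum * scale + p.sum(dim=-1)
            acc = acc * scale[:, None] + p @ v_tile
            row_max = new_max
        out[i:i + tile_m] = acc / row_sum[:, None]       # write the finished O tile back
    return out

# matches the unfused reference on toy inputs
q, k, v = (torch.randn(64, 32) for _ in range(3))
bias = torch.randn(64, 64)
ref = (q @ k.T / 32 ** 0.5 + bias).softmax(dim=-1) @ v
assert torch.allclose(tiled_attention_fwd(q, k, v, bias), ref, atol=1e-5)
```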

Big Thanks!