FlagOpen / FlagAttention

A collection of memory-efficient attention operators implemented in the Triton language.

Flash Attention 2: Save mask application if block is not in boundary #3

Closed hypnopump closed 10 months ago

hypnopump commented 10 months ago

Skips mask application if the block is not a boundary block.

iclementine commented 10 months ago

Thank you for the contribution!

Did you run the tests and benchmarks? And did you find any improvement with this change?

I previously added some fine-grained control over masking, under both static and runtime conditions.

I found that static conditioning on masking improves performance, while runtime conditioning does not. (In an earlier implementation, I checked the CTA id to see whether it was the last CTA in the m dimension, and the iteration id in the loop to see whether it was the last tile in the n dimension. That implementation reduced performance.)

if tl.program_id(0) is the last one:
    apply masking in m dimension

for i in loop:
    if i is the last iteration:
        apply masking in n dimension
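The runtime-conditioned pattern above can be sketched with a plain-Python/NumPy analogue (the names `tile_sum`, `BLOCK_M`, and `BLOCK_N` are hypothetical; a real Triton kernel would use `tl.program_id` and an explicit `mask=` on `tl.load`):

```python
import numpy as np

BLOCK_M, BLOCK_N = 4, 4  # hypothetical tile sizes

def tile_sum(x):
    """Sum a 2-D array tile by tile; only boundary tiles take the
    masked path (the runtime-conditioned pattern sketched above)."""
    m, n = x.shape
    num_m = (m + BLOCK_M - 1) // BLOCK_M
    num_n = (n + BLOCK_N - 1) // BLOCK_N
    total = 0.0
    for pid in range(num_m):            # plays the role of tl.program_id(0)
        for i in range(num_n):          # loop over tiles in the n dimension
            if pid == num_m - 1 or i == num_n - 1:
                # boundary tile: masked path. NumPy slicing clips the
                # bounds automatically; a Triton load would need mask=.
                tile = x[pid * BLOCK_M:min((pid + 1) * BLOCK_M, m),
                         i * BLOCK_N:min((i + 1) * BLOCK_N, n)]
            else:
                # interior tile: full, unmasked load
                tile = x[pid * BLOCK_M:(pid + 1) * BLOCK_M,
                         i * BLOCK_N:(i + 1) * BLOCK_N]
            total += tile.sum()
    return total
```

Note that the two `== num - 1` comparisons are exactly the per-CTA runtime checks that, in my measurements, ate up the gain from skipping the mask on interior tiles.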

We can also think about a simple pointwise computation kernel in CUDA.

__global__
void saxpy(int n, float a, float *x, float *y)
{
  int i = blockIdx.x*blockDim.x + threadIdx.x;
  if (i < n) y[i] = a*x[i] + y[i];  // per-element bound check: one comparison per thread
}

We could apply finer control over masking, say, "only the last CTA should apply masking" or "only the last warp should apply masking". But determining whether a thread is in the last CTA or the last warp requires some extra computation at runtime, so this may not improve performance. I need some performance data to confirm that this would help.
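As a rough sketch of that trade-off (plain Python with hypothetical names, standing in for a real kernel): interior blocks can skip the bound check entirely, but each block must first decide at runtime whether it is the last one.

```python
import numpy as np

BLOCK = 128  # hypothetical block size

def saxpy_blocked(n, a, x, y):
    """Blocked saxpy analogue of "only the last CTA applies masking":
    interior blocks do a full unchecked update, and only the last
    (possibly partial) block clips to n. The `b == num_blocks - 1`
    comparison is the extra runtime computation mentioned above."""
    num_blocks = (n + BLOCK - 1) // BLOCK
    for b in range(num_blocks):
        lo = b * BLOCK
        if b == num_blocks - 1:
            hi = n               # masked (partial) tail block
        else:
            hi = lo + BLOCK      # full, unchecked block
        y[lo:hi] = a * x[lo:hi] + y[lo:hi]
    return y
```

Whether the saved bound checks outweigh the added "am I the last block?" comparisons is exactly what the benchmark has to decide.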

iclementine commented 10 months ago

As for style, when the condition is a tl.constexpr I prefer the verbose layout that uses separate blocks in if-else statements rather than if-else expressions. It reminds me that a static condition (a condition on a tl.constexpr) is like template parameters or #if #else #endif macros in C++: it takes effect at compile time rather than at runtime. It may look ugly, but it leaves a visual trace that this is not a usual if-else.

I also found some bugs in the Triton compiler: it may not treat static conditions, or mixed static-and-dynamic conditions, correctly.

For example:

  1. A and B are both tl.constexpr, but Triton does not necessarily treat A and not B as a static condition.
  2. A is a tl.constexpr while B is not. The Triton compiler cannot correctly short-circuit a condition like A and not B at compile time.

hypnopump commented 10 months ago

Yes, I think you're right: Triton really does not like conditionals evaluated at execution time... I will keep looking into this.

iclementine commented 10 months ago

I've tested your commit on A100 with divisible_m and divisible_n, which is included in the benchmark. The performance is about the same as the current implementation.

But when testing with non-divisible m and non-divisible n, the code fails to compile. You can reproduce this by adding N_CTX -= 10 at https://github.com/FlagOpen/FlagAttention/blob/53ce09c40fc23a2b45a4b7e42b982e77c086afb3/benchmark/flash_benchmark.py#L49 .


Maybe I should add more unit test cases to ensure it runs with non-divisible m or non-divisible n.

hypnopump commented 10 months ago

Ah yes, I'm closing this. It seems Triton tries to analyze things inside dynamic if-elses; the only way to get a bool value the compiler "respects" seems to be passing it to the whole kernel grid, as DIVISIBLE_M, DIVISIBLE_N.
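For illustration, that launch-time specialization can be sketched in plain Python (hypothetical names; in the real kernels DIVISIBLE_M / DIVISIBLE_N would be tl.constexpr arguments, so the dead branch is compiled away):

```python
import numpy as np

def block_sums(x, BLOCK_M=4):
    """Host-side analogue of passing DIVISIBLE_M to the whole grid:
    divisibility is decided once at launch and behaves like a constant,
    so no block re-derives it at runtime."""
    m = len(x)
    DIVISIBLE_M = (m % BLOCK_M == 0)          # decided once, on the host
    num_blocks = (m + BLOCK_M - 1) // BLOCK_M
    out = []
    for pid in range(num_blocks):             # one "CTA" per block
        lo = pid * BLOCK_M
        if DIVISIBLE_M:
            # unmasked variant: every block is full
            out.append(x[lo:lo + BLOCK_M].sum())
        else:
            # masked variant: clip the upper bound on the tail block
            out.append(x[lo:min(lo + BLOCK_M, m)].sum())
    return out
```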