Closed hypnopump closed 10 months ago
Thank you for your contribution!
Did you run the tests and the benchmark? And did you find any improvement with this change?
I used to add some fine-grained control over masking, by both static and runtime conditions: `DIVISIBLE_M` and `DIVISIBLE_N`. If we know that m and n are divisible by `BLOCK_M` and `BLOCK_N`, respectively, we can skip masking by using Triton's `constexpr` as the condition. This removes masking at compile time, without runtime cost.

When `DIVISIBLE_M` or `DIVISIBLE_N` is not true, only the last tile requires masking, both for loading and for some other computation. I find that static conditioning on masking improves performance, while runtime conditioning does not. (In my earlier implementation, I checked the CTA id to see whether it was the last CTA in the m dimension, and the iteration id in the loop to see whether it was the last tile in the n dimension. However, that implementation reduced performance.)
```
if tl.program_id(0) is the last one:
    apply masking in m dimension
for i in loop:
    if i is the last iteration:
        apply masking in n dimension
```
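To make the contrast concrete, here is a minimal NumPy sketch (an illustration only; `tiled_copy` and the 1-D copy are stand-ins for the real attention tiles, not the repo's kernels) of how a statically known divisibility flag lets every tile skip the mask, while the non-divisible case needs a mask only on the last tile:

```python
import numpy as np

def tiled_copy(x, block):
    """Copy x tile by tile, masking only when the length is not divisible."""
    n = len(x)
    out = np.empty_like(x)
    divisible = n % block == 0  # in Triton this would be a tl.constexpr flag
    num_tiles = (n + block - 1) // block
    for t in range(num_tiles):
        start = t * block
        if divisible:
            # static case: no tile ever needs a mask
            out[start:start + block] = x[start:start + block]
        else:
            # general case: the mask only has an effect on the last tile
            offs = start + np.arange(block)
            mask = offs < n
            out[offs[mask]] = x[offs[mask]]
    return out
```

In the real kernel the branch on `divisible` is resolved by the compiler, so the divisible path contains no masking instructions at all.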
We can also think about a simple pointwise computation kernel in CUDA:
```cuda
__global__
void saxpy(int n, float a, float *x, float *y)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] = a * x[i] + y[i];
}
```
We could apply even finer control over masking here, say, "only the last CTA should apply masking" or "only the last warp should apply masking". But determining whether a thread belongs to the last CTA or the last warp requires some extra computation at runtime, so this may not improve performance. I would need some performance data to confirm that it helps.
As to the style, when the condition is a `tl.constexpr`, I prefer the dummy layout that uses separate blocks in if-else statements instead of if-else expressions, to remind myself that a static condition (a condition on a `tl.constexpr`) is like a template parameter or `#if #else #endif` macros in C++. It is compile-time-effective rather than runtime-effective. It may look ugly, but it leaves a visual trace that this is not a usual if-else.
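For example, the statement form might look like this (a Triton-style sketch; `x_ptrs` and `offs_m` are made-up names, not the repo's actual kernel):

```
if DIVISIBLE_M:
    # static branch: pruned at compile time, since DIVISIBLE_M is a tl.constexpr
    x = tl.load(x_ptrs)
else:
    mask_m = offs_m < M
    x = tl.load(x_ptrs, mask=mask_m[:, None], other=0.0)
```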
Also, I found some bugs in the Triton compiler: it may not treat static conditions, or mixed static-and-dynamic conditions, correctly. For example:

- `A` and `B` are both `tl.constexpr`, but Triton will not necessarily treat `A and not B` as a static condition.
- `A` is a `tl.constexpr` while `B` is not. The Triton compiler cannot correctly perform logical short-circuiting like `A and not B` when compiling.

Yes, I think you're right; Triton really does not like conditionals evaluated at execution time... will keep looking into this.
I've tested your commit on an A100 with `divisible_m` and `divisible_n`, which is included in the benchmark. The performance is about the same as the current implementation.

But when testing with non-divisible m and non-divisible n, the code fails to compile. You can reproduce this by adding an `N_CTX -= 10` at https://github.com/FlagOpen/FlagAttention/blob/53ce09c40fc23a2b45a4b7e42b982e77c086afb3/benchmark/flash_benchmark.py#L49 .
Maybe I should add more unit test cases to ensure it runs with non-divisible m or non-divisible n.
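Such cases could be sketched like this (a hypothetical helper; the `BLOCK_M`, `BLOCK_N` values and `needs_masking` are illustrative, not the repo's actual configuration):

```python
# Example block sizes; the repo's autotuned configs may differ.
BLOCK_M, BLOCK_N = 128, 64

def needs_masking(m, n, block_m=BLOCK_M, block_n=BLOCK_N):
    """Return which dimensions need boundary masking for the given shape."""
    return (m % block_m != 0, n % block_n != 0)

# A test matrix should cover all four combinations, e.g.:
shapes = [(1024, 1024), (1024, 1014), (1000, 1024), (1000, 1014)]
```

Parametrizing the unit tests over such a shape list would exercise both the masked and unmasked compilation paths.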
Ah yes, I'm closing this. It seems Triton tries to analyze things inside dynamic if-elses; the only way to have a bool value the compiler "respects" seems to be passing it to the whole kernel grid, as `DIVISIBLE_M` and `DIVISIBLE_N` do.
Skips mask application if the block is not the boundary block.