chengjunlu opened this issue 6 months ago
The Q is loaded once outside the SCF body, but the convert layout op inside the SCF loop body is loop-invariant and should be hoisted out of the SCF body as well. TritonGPUReorderInstructions can optimize this.
A potential issue in the TTGIR is that it uses registers aggressively: Q stays resident in registers through the whole SCF body. We could use SLM as scratch space to hold Q instead of keeping it in registers. This has already been achieved in the Triton pass pipeline.
The convert layout operation can be optimized by the public Triton passes TritonGPUReduceDataDuplication and TritonGPUReorderInstructions.
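For context, here is a minimal sketch of the kind of kernel that produces the IR below. It is hypothetical and heavily simplified (the softmax is dropped, names and strides are illustrative); it is not the actual triton.ops.flash_attention source.

```python
import triton
import triton.language as tl


# Hypothetical simplified forward-attention kernel: q is loaded and pre-scaled
# once before the loop, but the conversion of q to the dot-operand layout is
# emitted next to tl.dot, inside the loop body.
@triton.jit
def _attn_fwd_sketch(Q, K, V, Out, sm_scale,
                     stride_qm, stride_qk,
                     stride_kk, stride_kn,
                     stride_vn, stride_vk,
                     N_CTX,
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                     BLOCK_D: tl.constexpr):
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_D)

    # Q is loaded and scaled once, outside the loop
    # (tt.load + arith.extf/mulf/truncf in the TTGIR below).
    q_ptrs = Q + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
    q = (tl.load(q_ptrs) * sm_scale).to(tl.float16)      # [BLOCK_M, BLOCK_D]

    acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32)
    for start_n in range(0, N_CTX, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)
        k_ptrs = K + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn
        k = tl.load(k_ptrs)                               # [BLOCK_D, BLOCK_N]
        # tl.dot needs q in the dot-operand (A) layout, so the
        # triton_gpu.convert_layout for q lands here, inside the loop,
        # even though q itself is loop-invariant.
        qk = tl.dot(q, k)                                 # [BLOCK_M, BLOCK_N]
        v_ptrs = V + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk
        v = tl.load(v_ptrs)                               # [BLOCK_N, BLOCK_D]
        acc += tl.dot(qk.to(tl.float16), v)

    out_ptrs = Out + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
    tl.store(out_ptrs, acc.to(tl.float16))
```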
The original input:
%26 = tt.load %25 : tensor<128x16x!tt.ptr<f16>, #blocked> loc(#loc18)
%27 = tt.splat %14 : f32 -> tensor<128x16xf32, #blocked> loc(#loc19)
%28 = arith.extf %26 : tensor<128x16xf16, #blocked> to tensor<128x16xf32, #blocked> loc(#loc19)
%29 = arith.mulf %28, %27 : tensor<128x16xf32, #blocked> loc(#loc19)
%30 = arith.truncf %29 : tensor<128x16xf32, #blocked> to tensor<128x16xf16, #blocked> loc(#loc20)
%31:5 = scf.for %arg22 = %c0_i32 to %arg20 step %c64_i32 iter_args(%arg23 = %cst_3, %arg24 = %cst_2, %arg25 = %cst_1, %arg26 = %5, %arg27 = %5) -> (tensor<128x16xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, i64, i64) : i32 {
...
%94 = triton_gpu.convert_layout %30 : tensor<128x16xf16, #blocked> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> loc(#loc20)
After TritonGPUReduceDataDuplication:
%26 = tt.load %25 : tensor<128x16x!tt.ptr<f16>, #blocked> loc(#loc18)
%27 = tt.splat %14 : f32 -> tensor<128x16xf32, #blocked> loc(#loc19)
%28 = arith.extf %26 : tensor<128x16xf16, #blocked> to tensor<128x16xf32, #blocked> loc(#loc19)
%29 = arith.mulf %28, %27 : tensor<128x16xf32, #blocked> loc(#loc19)
%30 = arith.truncf %29 : tensor<128x16xf32, #blocked> to tensor<128x16xf16, #blocked> loc(#loc20)
%31:5 = scf.for %arg22 = %c0_i32 to %arg20 step %c64_i32 iter_args(%arg23 = %cst_3, %arg24 = %cst_2, %arg25 = %cst_1, %arg26 = %5, %arg27 = %5) -> (tensor<128x16xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, i64, i64) : i32 {
...
%96 = triton_gpu.local_alloc %30 : (tensor<128x16xf16, #blocked>) -> !tt.memdesc<128x16xf16, #shared1> loc(#loc20)
%97 = triton_gpu.local_load %96 : !tt.memdesc<128x16xf16, #shared1> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> loc(#loc20)
After TritonGPUReorderInstructions:
%26 = tt.load %25 : tensor<128x16x!tt.ptr<f16>, #blocked> loc(#loc18)
%27 = tt.splat %14 : f32 -> tensor<128x16xf32, #blocked> loc(#loc19)
%28 = arith.extf %26 : tensor<128x16xf16, #blocked> to tensor<128x16xf32, #blocked> loc(#loc19)
%29 = arith.mulf %28, %27 : tensor<128x16xf32, #blocked> loc(#loc19)
%30 = arith.truncf %29 : tensor<128x16xf32, #blocked> to tensor<128x16xf16, #blocked> loc(#loc20)
%31 = triton_gpu.local_alloc %30 : (tensor<128x16xf16, #blocked>) -> !tt.memdesc<128x16xf16, #shared> loc(#loc20)
%32:5 = scf.for %arg22 = %c0_i32 to %arg20 step %c64_i32 iter_args(%arg23 = %cst_3, %arg24 = %cst_2, %arg25 = %cst_1, %arg26 = %5, %arg27 = %5) -> (tensor<128x16xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, i64, i64) : i32 {
...
%95 = triton_gpu.local_load %31 : !tt.memdesc<128x16xf16, #shared> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> loc(#loc20)
@whitneywhtsang, @etiotto, please help check whether the existing mechanism in Triton is good enough for now. If so, we can close this issue with nothing to do for now.
In principle what you say makes sense because (a) the TritonGPUReduceDataDuplication transformation has replaced the convert layout operation with local_alloc + local_load operations and (b) the TritonGPUReorderInstructions transformation was able to hoist the local_alloc operation outside the loop. The loop then contains a local_load which loads data from shared memory into registers, but the register live range will be smaller (no longer the entire loop).
I was trying to reproduce the behavior in https://github.com/intel/intel-xpu-backend-for-triton/issues/950#issuecomment-2109052425. What is the benchmark you used? I tried 06-fused-attention.py and could not see the pattern you show above. Did you use any special branch?
I am using my branch, but I think those passes should be the same as on the llvm-target branch. I am running the flash attention forward kernel under the namespace triton.ops.flash_attention.
I will double-check the forward kernel in 06-fused-attention.py.
The flash attention kernel in 06-fused-attention.py is different. The DotOp layout has been backward-propagated to the tt.load and tt.make_tensor_ptr ops for the Q matrix.
#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [32, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
...
%15 = tt.make_tensor_ptr %11, [%13, %c64_i64], [%14, %c1_i64], [%12, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>> loc(#loc13)
%31 = tt.load %15 : !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>> loc(#loc23)
%35:5 = scf.for %arg20 = %c0_i32 to %12 step %c64_i32 iter_args(%arg21 = %cst_1, %arg22 = %cst, %arg23 = %cst_0, %arg24 = %c0_i64, %arg25 = %32) -> (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<64x64xf32, #mma>, tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, i64, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>>) : i32 {
...
%90 = tt.dot %31, %89, %cst, inputPrecision = tf32 : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<64x64xf32, #mma> loc(#loc85)
The TritonGPUReduceDataDuplication pass doesn't work on the value %31.
In this case, there seems to be no register pressure, because Q is in the A operand layout and the tiling is warpsPerCTA = [32, 1]. That means the value of Q is not duplicated across the different warps and threads.
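A back-of-envelope check, assuming the A operand is distributed evenly with no duplication (as stated above) and using the shapes and attributes from the dumped IR:

```python
# Values taken from the IR above: tensor<64x64xf16> for Q,
# warpsPerCTA = [32, 1], threadsPerWarp = 16.
elements = 64 * 64           # Q tile elements in the dot-operand (A) layout
threads = 32 * 16            # warps per CTA * threads per warp
per_thread = elements // threads
print(per_thread)            # 8 f16 values, i.e. ~16 bytes of registers per thread
```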
The key difference between the two flash attention implementations on Intel is:
- 06-fused-attention.py uses a block pointer for Q: https://github.com/intel/intel-xpu-backend-for-triton/blob/e84f153c05da274162c08c2cbefa3d27e4780296/python/tutorials/06-fused-attention.py#L122
- triton.ops.flash_attention uses pointer arithmetic for Q: https://github.com/intel/intel-xpu-backend-for-triton/blob/e84f153c05da274162c08c2cbefa3d27e4780296/python/triton/ops/flash_attention.py#L72

The two load styles are sketched below. Reopening this issue because spilling Q only works for a tensor of pointers; we also need to enable it for 06-fused-attention.py.
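For reference, the two Q-load styles look roughly like this inside a @triton.jit kernel. These are illustrative fragments, not the exact code from the linked files; the names and strides are assumptions.

```python
# 06-fused-attention.py style: a block pointer, which lowers to
# tt.make_tensor_ptr + tt.load on a block pointer.
Q_block_ptr = tl.make_block_ptr(
    base=Q, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_qm, stride_qk),
    offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL),
    order=(1, 0),
)
q = tl.load(Q_block_ptr)

# triton.ops.flash_attention style: plain pointer arithmetic, which lowers to
# a tensor of pointers and a tt.load on tensor<...x!tt.ptr<f16>>.
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, BLOCK_DMODEL)
q_ptrs = Q + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
q = tl.load(q_ptrs)
```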
Moved to blocked status. We'll resume work after the new Agama driver is released.
Spilling the Q matrix is not needed for the FA kernel with the default pass pipeline now; all the values of P, S, and K can be kept in registers with double GRF.
We can lower the priority of this issue in iteration 18 and close it as not required in iteration 19.
We can remove the unnecessary convert layout for the Q matrix. The dot layout can be combined backward all the way to the ops that load the Q operand, as shown in the MLIR snippets.
The layout has been successfully combined backward for K (%49) and V (%50), but not for Q (%32 -> %51). We need to improve that.