intel / intel-xpu-backend-for-triton

OpenAI Triton backend for Intel® GPUs
MIT License

[Performance] Spill the Q matrix to SLM to reduce the register pressure in flash attention. #950

Open chengjunlu opened 6 months ago

chengjunlu commented 6 months ago

We can remove the unnecessary convert layout for the Q matrix. The dot layout can be backward combined all the way back to the ops that load the Q operand, as shown in this piece of the MLIR:

%28 = tt.load %27 : tensor<128x16x!tt.ptr<f16>, #blocked> loc(#loc18)
%29 = tt.splat %16 : f32 -> tensor<128x16xf32, #blocked> loc(#loc19)
%30 = arith.extf %28 : tensor<128x16xf16, #blocked> to tensor<128x16xf32, #blocked> loc(#loc19)
%31 = arith.mulf %30, %29 : tensor<128x16xf32, #blocked> loc(#loc19)
%32 = arith.truncf %31 : tensor<128x16xf32, #blocked> to tensor<128x16xf16, #blocked> loc(#loc20)
%33:5 = scf.for %arg22 = %c0_i32 to %arg20 step %c64_i32 iter_args(%arg23 = %cst_3, %arg24 = %cst_2, %arg25 = %cst_1, %arg26 = %6, %arg27 = %8) -> (tensor<128x16xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr<tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>>, !tt.ptr<tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>>)  : i32 {
      %49 = tt.load %arg26 : !tt.ptr<tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>> loc(#loc22)
      %50 = tt.load %arg27 : !tt.ptr<tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>> loc(#loc23)
      %51 = triton_gpu.convert_layout %32 : tensor<128x16xf16, #blocked> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> loc(#loc20)
      %52 = tt.dot %51, %49, %cst, inputPrecision = tf32 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<128x64xf32, #mma> loc(#loc24)
...
      %66 = tt.dot %65, %50, %63, inputPrecision = tf32 : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<128x16xf32, #mma> loc(#loc37)

The layout has been successfully backward combined for K (%49) and V (%50), but not for Q (%32 -> %51). We need to improve that.

chengjunlu commented 6 months ago

Q is loaded once, outside the SCF loop body. The convert layout op inside the SCF loop body is loop-invariant and should be moved out of it; TritonGPUReorderInstructions can do this. A potential issue in the TTGIR is overly aggressive register use: Q stays resident in registers through the whole SCF body. We can use SLM as scratch space to hold Q instead of keeping it in registers. This is already achieved by the Triton pass pipeline.
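
To illustrate why the convert layout op counts as loop-invariant, here is a minimal Python sketch of the check, not the actual pass. The `Op` class and `split_invariant` helper are hypothetical; the value names are taken from the first MLIR excerpt in this issue. The sketch ignores side effects and memory ordering, which a real pass must consider.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    name: str
    result: str
    operands: tuple


def split_invariant(body, loop_carried):
    """Partition loop-body ops into (hoistable, remaining).

    An op is hoistable here if none of its operands is defined inside the
    loop (loop-carried arguments or results of non-hoistable body ops).
    """
    defined_in_loop = set(loop_carried)
    hoistable, remaining = [], []
    for op in body:
        if any(v in defined_in_loop for v in op.operands):
            defined_in_loop.add(op.result)
            remaining.append(op)
        else:
            hoistable.append(op)
    return hoistable, remaining


# Ops from the first excerpt: %32 (Q) is defined before the scf.for.
body = [
    Op("tt.load", "%49", ("%arg26",)),
    Op("tt.load", "%50", ("%arg27",)),
    Op("triton_gpu.convert_layout", "%51", ("%32",)),
    Op("tt.dot", "%52", ("%51", "%49", "%cst")),
]
loop_carried = ["%arg22", "%arg23", "%arg24", "%arg25", "%arg26", "%arg27"]

hoistable, remaining = split_invariant(body, loop_carried)
print([op.result for op in hoistable])  # ['%51']
print([op.result for op in remaining])  # ['%49', '%50', '%52']
```

Only the conversion of Q depends solely on values defined before the loop; the K and V loads and the dot depend on loop-carried pointers.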

The convert layout operation can be optimized by the public Triton passes TritonGPUReduceDataDuplication and TritonGPUReorderInstructions.

The original input:

%26 = tt.load %25 : tensor<128x16x!tt.ptr<f16>, #blocked> loc(#loc18)
%27 = tt.splat %14 : f32 -> tensor<128x16xf32, #blocked> loc(#loc19)
%28 = arith.extf %26 : tensor<128x16xf16, #blocked> to tensor<128x16xf32, #blocked> loc(#loc19)
%29 = arith.mulf %28, %27 : tensor<128x16xf32, #blocked> loc(#loc19)
%30 = arith.truncf %29 : tensor<128x16xf32, #blocked> to tensor<128x16xf16, #blocked> loc(#loc20)
%31:5 = scf.for %arg22 = %c0_i32 to %arg20 step %c64_i32 iter_args(%arg23 = %cst_3, %arg24 = %cst_2, %arg25 = %cst_1, %arg26 = %5, %arg27 = %5) -> (tensor<128x16xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, i64, i64)  : i32 {
      ...
      %94 = triton_gpu.convert_layout %30 : tensor<128x16xf16, #blocked> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> loc(#loc20)

After TritonGPUReduceDataDuplication:

%26 = tt.load %25 : tensor<128x16x!tt.ptr<f16>, #blocked> loc(#loc18)
%27 = tt.splat %14 : f32 -> tensor<128x16xf32, #blocked> loc(#loc19)
%28 = arith.extf %26 : tensor<128x16xf16, #blocked> to tensor<128x16xf32, #blocked> loc(#loc19)
%29 = arith.mulf %28, %27 : tensor<128x16xf32, #blocked> loc(#loc19)
%30 = arith.truncf %29 : tensor<128x16xf32, #blocked> to tensor<128x16xf16, #blocked> loc(#loc20)
%31:5 = scf.for %arg22 = %c0_i32 to %arg20 step %c64_i32 iter_args(%arg23 = %cst_3, %arg24 = %cst_2, %arg25 = %cst_1, %arg26 = %5, %arg27 = %5) -> (tensor<128x16xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, i64, i64)  : i32 {
      ...
      %96 = triton_gpu.local_alloc %30 : (tensor<128x16xf16, #blocked>) -> !tt.memdesc<128x16xf16, #shared1> loc(#loc20)
      %97 = triton_gpu.local_load %96 : !tt.memdesc<128x16xf16, #shared1> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> loc(#loc20)

After TritonGPUReorderInstructions:

%26 = tt.load %25 : tensor<128x16x!tt.ptr<f16>, #blocked> loc(#loc18)
%27 = tt.splat %14 : f32 -> tensor<128x16xf32, #blocked> loc(#loc19)
%28 = arith.extf %26 : tensor<128x16xf16, #blocked> to tensor<128x16xf32, #blocked> loc(#loc19)
%29 = arith.mulf %28, %27 : tensor<128x16xf32, #blocked> loc(#loc19)
%30 = arith.truncf %29 : tensor<128x16xf32, #blocked> to tensor<128x16xf16, #blocked> loc(#loc20)
%31 = triton_gpu.local_alloc %30 : (tensor<128x16xf16, #blocked>) -> !tt.memdesc<128x16xf16, #shared> loc(#loc20)
%32:5 = scf.for %arg22 = %c0_i32 to %arg20 step %c64_i32 iter_args(%arg23 = %cst_3, %arg24 = %cst_2, %arg25 = %cst_1, %arg26 = %5, %arg27 = %5) -> (tensor<128x16xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, i64, i64)  : i32 {
      ...
      %95 = triton_gpu.local_load %31 : !tt.memdesc<128x16xf16, #shared> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> loc(#loc20)

@whitneywhtsang, @etiotto: please help check whether the existing mechanism in Triton is good enough for now. If it is, we can close this issue as there is nothing to do for now.

etiotto commented 6 months ago

In principle what you say makes sense because (a) the TritonGPUReduceDataDuplication transformation has replaced the convert layout operation with local_alloc + local_load operations and (b) the TritonGPUReorderInstructions transformation was able to hoist the local_alloc operation outside the loop. The loop then contains a local_load which loads data from shared memory to registers, but the registers' live range will be smaller (not the entire loop anymore).
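
The live-range argument can be sketched with a toy model in Python. Everything here is an assumption for illustration: the `live_slots` helper, the simplified loop bodies, the `softmax` placeholder op and the value names (`q_dot`, `q_slm`) are hypothetical and do not come from the compiler.

```python
def live_slots(body, value):
    """Number of loop-body slots in which `value` is held in registers.

    `body` is a list of (op_name, result, operands) tuples in program order.
    """
    uses = [i for i, (_, _, operands) in enumerate(body) if value in operands]
    if not uses:
        return 0
    defs = [i for i, (_, result, _) in enumerate(body) if result == value]
    if not defs:
        # Defined before the loop and used in every iteration:
        # it stays live across the whole body, including the back edge.
        return len(body)
    return uses[-1] - defs[0] + 1


# Q converted to the dot-operand layout once, before the loop.
hoisted_convert = [
    ("tt.load", "k", ("k_ptr",)),
    ("tt.load", "v", ("v_ptr",)),
    ("tt.dot", "qk", ("q_dot", "k")),
    ("softmax", "p", ("qk",)),
    ("tt.dot", "acc", ("p", "v")),
]

# Q kept in SLM and reloaded right before its use.
slm_reload = [
    ("tt.load", "k", ("k_ptr",)),
    ("tt.load", "v", ("v_ptr",)),
    ("triton_gpu.local_load", "q_dot", ("q_slm",)),
    ("tt.dot", "qk", ("q_dot", "k")),
    ("softmax", "p", ("qk",)),
    ("tt.dot", "acc", ("p", "v")),
]

print(live_slots(hoisted_convert, "q_dot"))  # 5
print(live_slots(slm_reload, "q_dot"))       # 2
```

In this model the reloaded Q occupies registers only between the local_load and the first dot, at the cost of one SLM read per iteration.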

I was trying to reproduce the behavior in https://github.com/intel/intel-xpu-backend-for-triton/issues/950#issuecomment-2109052425. What benchmark did you use? I tried 06-fused-attention.py and could not see the pattern you show above. Did you use a special branch?

chengjunlu commented 6 months ago

> I was trying to reproduce the behavior in #950 (comment). What benchmark did you use? I tried 06-fused-attention.py and could not see the pattern you show above. Did you use a special branch?

I am using my own branch, but I think those passes should be the same as in llvm-target.

I am running the flash attention forward kernel under the namespace triton.ops.flash_attention.

I will double-check the forward kernel in 06-fused-attention.py.

chengjunlu commented 6 months ago

The flash attention kernel in 06-fused-attention.py is different.

The DotOp layout has been backward propagated to the tt.load and tt.make_tensor_ptr for the Q matrix.

#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [32, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
...
%15 = tt.make_tensor_ptr %11, [%13, %c64_i64], [%14, %c1_i64], [%12, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>> loc(#loc13)
%31 = tt.load %15 : !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>> loc(#loc23)
%35:5 = scf.for %arg20 = %c0_i32 to %12 step %c64_i32 iter_args(%arg21 = %cst_1, %arg22 = %cst, %arg23 = %cst_0, %arg24 = %c0_i64, %arg25 = %32) -> (tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<64x64xf32, #mma>, tensor<64xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, i64, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>>)  : i32 {
    ...
    %90 = tt.dot %31, %89, %cst, inputPrecision = tf32 : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<64x64xf32, #mma> loc(#loc85)

TritonGPUReduceDataDuplication doesn't apply to the value %31. In this case there seems to be no register pressure, because Q has the A operand layout and the tiling is warpsPerCTA = [32, 1]. That means the value of Q is not duplicated across the different warps and threads.
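
A simplified way to reason about the duplication claim, as a Python sketch under stated assumptions: the `operand_duplication` helper is hypothetical, the [8, 4] tiling is an invented comparison point, and the model deliberately ignores any wrap-around replication that occurs when a tile has fewer rows or columns than the warp grid covers.

```python
def operand_duplication(warps_per_cta, op_idx):
    """Copies of a dot operand held across warps in a simplified model.

    The C tile is split over a [warps_m, warps_n] grid of warps. Warps in the
    same grid row need the same rows of A (opIdx = 0), and warps in the same
    grid column need the same columns of B (opIdx = 1).
    """
    warps_m, warps_n = warps_per_cta
    if op_idx == 0:
        return warps_n
    if op_idx == 1:
        return warps_m
    raise ValueError("op_idx must be 0 (A) or 1 (B)")


print(operand_duplication([32, 1], op_idx=0))  # 1
print(operand_duplication([8, 4], op_idx=0))   # 4
```

With warpsPerCTA = [32, 1] there is only one warp along N, so in this model no two warps hold the same rows of Q.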

chengjunlu commented 6 months ago

The key difference between the two flash attention implementations, as far as the Intel backend is concerned, is:

  1. 06-fused-attention.py uses a block pointer for Q. https://github.com/intel/intel-xpu-backend-for-triton/blob/e84f153c05da274162c08c2cbefa3d27e4780296/python/tutorials/06-fused-attention.py#L122
  2. triton.ops.flash_attention uses pointer arithmetic for Q. https://github.com/intel/intel-xpu-backend-for-triton/blob/e84f153c05da274162c08c2cbefa3d27e4780296/python/triton/ops/flash_attention.py#L72
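
The two addressing styles can be contrasted with a pure-Python analogy. This is not Triton code: `pointer_grid`, `BlockPtr`, `load_pointers` and `load_block` are hypothetical names, and memory is modeled as a flat list. In the MLIR above, the first style shows up as tensor<128x16x!tt.ptr<f16>, #blocked> (a tensor of pointers) and the second as !tt.ptr<tensor<64x64xf16, ...>> (a pointer to a tensor).

```python
from dataclasses import dataclass


def pointer_grid(base, offs_m, offs_k, stride_m, stride_k):
    """Pointer-arithmetic style: materialize one address per element."""
    return [[base + m * stride_m + k * stride_k for k in offs_k] for m in offs_m]


def load_pointers(mem, ptrs):
    """Load through a tensor of pointers, element by element."""
    return [[mem[p] for p in row] for row in ptrs]


@dataclass
class BlockPtr:
    """Block-pointer style: one descriptor for the whole tile."""
    base: int
    strides: tuple
    offsets: tuple
    block_shape: tuple


def load_block(mem, bp):
    """Derive the addresses from the descriptor only at load time."""
    (sm, sk), (om, ok), (bm, bk) = bp.strides, bp.offsets, bp.block_shape
    return [
        [mem[bp.base + (om + i) * sm + (ok + j) * sk] for j in range(bk)]
        for i in range(bm)
    ]


mem = list(range(100, 132))  # a 4x8 row-major matrix stored flat

ptrs = pointer_grid(0, offs_m=[2, 3], offs_k=[0, 1, 2, 3], stride_m=8, stride_k=1)
bp = BlockPtr(base=0, strides=(8, 1), offsets=(2, 0), block_shape=(2, 4))

tile_a = load_pointers(mem, ptrs)
tile_b = load_block(mem, bp)
print(tile_a == tile_b)  # True
```

Both styles read the same tile; what differs is the type the compiler sees for Q, which is why the layout passes treat the two kernels differently.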

chengjunlu commented 3 months ago

Reopening this issue, as spilling Q only works for a tensor of pointers. We also need to enable it in 06-fused-attention.py.

vlad-penkin commented 2 months ago

Moved to blocked status. We'll resume work after new Agama is released.

chengjunlu commented 1 week ago

Spilling the Q matrix is no longer needed for the FA kernel with the default passes. All the values of P, S and K can be kept in registers with double GRF.

We can lower the priority of this issue in iteration 18 and close it as not required in iteration 19.