The tasks to support this, based on the insights for PVC, are now:
- `tt.dot` ops with prefetching.
- `tt.dot` operation, so that we can use packed 2D load/store to load the operands A and B. -- #969

I will create sub-issues to track the progress of each task individually.
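For reference, a minimal sketch of the block-pointer GEMM shape these tasks target (the kernel name, tile sizes, and the assumption that M/N/K divide the block sizes are illustrative, not the actual kernel):

```python
import triton
import triton.language as tl

# Block-pointer GEMM sketch. With block pointers, the backend can lower the
# loads of A and B to packed 2D block loads and the tl.dot to DPAS.
@triton.jit
def gemm_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                stride_am, stride_ak, stride_bk, stride_bn,
                stride_cm, stride_cn,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    a_bptr = tl.make_block_ptr(a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                               offsets=(pid_m * BLOCK_M, 0),
                               block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    b_bptr = tl.make_block_ptr(b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                               offsets=(0, pid_n * BLOCK_N),
                               block_shape=(BLOCK_K, BLOCK_N), order=(1, 0))
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):
        a = tl.load(a_bptr)  # operand A: candidate for a packed 2D block load
        b = tl.load(b_bptr)  # operand B: candidate for a packed 2D block load
        acc = tl.dot(a, b, acc)
        a_bptr = tl.advance(a_bptr, (0, BLOCK_K))
        b_bptr = tl.advance(b_bptr, (BLOCK_K, 0))
    c_bptr = tl.make_block_ptr(c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
                               offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                               block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    tl.store(c_bptr, acc.to(tl.float16))
```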
@chengjunlu, there is a similar issue #773 for improving GEMM and flash attention performance. Could you please provide some examples that #773 cannot handle and that would go down this path?
Based on the discussion and plan, #773 focuses on Triton kernels written with block pointers. This issue tracks the long-tail variant Triton kernels beyond the block-pointer cases.
As there is a trend in the Triton community toward deprecating the block pointer, the following task may need a higher priority: -- ToP (tensor of pointers): https://github.com/intel/intel-xpu-backend-for-triton/issues/880
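For contrast, a minimal sketch of the tensor-of-pointers (ToP) addressing that #880 covers (the kernel name and shapes are illustrative): the kernel builds one scalar pointer per element, so the backend has to recover the 2D block structure from the pointer arithmetic instead of getting it directly from a block pointer.

```python
import triton
import triton.language as tl

# Tensor-of-pointers tile copy: every element gets its own pointer, computed
# from explicit row/column offsets, instead of a single block pointer.
@triton.jit
def copy_tile_kernel(src_ptr, dst_ptr, M, N, stride_m, stride_n,
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # 2D tensor of pointers built by broadcasting the row/column offsets.
    ptrs = src_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tile = tl.load(ptrs, mask=mask, other=0.0)
    tl.store(dst_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n,
             tile, mask=mask)
```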
Flash attention optimization notes: most passes are aligned with the GEMM optimization. Here are some notes specific to the flash attention optimization:
// Load the Q tile via a tensor of pointers (#blocked layout).
%28 = tt.load %27 : tensor<128x16x!tt.ptr<f16>, #blocked> loc(#loc18)
// Scale Q by the softmax scale in f32 and truncate back to f16;
// this multiply chain is hoisted out of the inner loop.
%29 = tt.splat %16 : f32 -> tensor<128x16xf32, #blocked> loc(#loc19)
%30 = arith.extf %28 : tensor<128x16xf16, #blocked> to tensor<128x16xf32, #blocked> loc(#loc19)
%31 = arith.mulf %30, %29 : tensor<128x16xf32, #blocked> loc(#loc19)
%32 = arith.truncf %31 : tensor<128x16xf32, #blocked> to tensor<128x16xf16, #blocked> loc(#loc20)
// Online-softmax loop over the K/V blocks.
%33:5 = scf.for %arg22 = %c0_i32 to %arg20 step %c64_i32 iter_args(%arg23 = %cst_3, %arg24 = %cst_2, %arg25 = %cst_1, %arg26 = %6, %arg27 = %8) -> (tensor<128x16xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr<tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>>, !tt.ptr<tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>>) : i32 {
// The K and V tiles are loaded through block pointers, already in the dot-operand layout.
%49 = tt.load %arg26 : !tt.ptr<tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>> loc(#loc22)
%50 = tt.load %arg27 : !tt.ptr<tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>> loc(#loc23)
// The scaled Q is converted from #blocked to the dot-operand layout for the first dot.
%51 = triton_gpu.convert_layout %32 : tensor<128x16xf16, #blocked> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> loc(#loc20)
%52 = tt.dot %51, %49, %cst, inputPrecision = tf32 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<128x64xf32, #mma> loc(#loc24)
// The DPAS (#mma) layout used by the dots:
#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 8, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], A = [8, 16], B = [16, 8], C = [8, 8]}>
...
// The first dot's result (#mma layout) must be converted to the opIdx = 0
// dot-operand layout to feed the second dot.
%65 = triton_gpu.convert_layout %64 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> loc(#loc36)
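For orientation, a Triton-level sketch of the inner loop the IR above comes from, modeled on the standard flash attention tutorial kernel (the helper name `_attn_inner`, the block-pointer arguments, and the exact running-softmax arithmetic are illustrative): Q is scaled once before the loop, the first `tl.dot` produces the #mma-layout result, and feeding that result into the second `tl.dot` is what requires the `convert_layout` at %65.

```python
import triton
import triton.language as tl

# Sketch of the flash attention inner loop; q: (BLOCK_M, D), K block: (D, BLOCK_N),
# V block: (BLOCK_N, D). Called from the enclosing attention kernel.
@triton.jit
def _attn_inner(q, acc, m_i, l_i, k_bptr, v_bptr, sm_scale,
                SEQ_LEN, BLOCK_N: tl.constexpr):
    q = (q * sm_scale).to(tl.float16)        # hoisted scaling (IR: %29..%32)
    for _ in range(0, SEQ_LEN, BLOCK_N):
        k = tl.load(k_bptr)                  # IR: %49
        v = tl.load(v_bptr)                  # IR: %50
        qk = tl.dot(q, k)                    # first dot; result in #mma layout
        m_new = tl.maximum(m_i, tl.max(qk, 1))
        p = tl.exp(qk - m_new[:, None])
        alpha = tl.exp(m_i - m_new)
        l_i = l_i * alpha + tl.sum(p, 1)
        acc = acc * alpha[:, None]
        # p is the first dot's #mma-layout output; using it as operand A of the
        # second dot forces the #mma -> dot_op<opIdx = 0> conversion (IR: %65).
        acc = tl.dot(p.to(tl.float16), v, acc)
        m_i = m_new
        k_bptr = tl.advance(k_bptr, (0, BLOCK_N))
        v_bptr = tl.advance(v_bptr, (BLOCK_N, 0))
    return acc, m_i, l_i
```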
Helping with refactoring the PR.
Updating the status of the flash attention performance on the fallback path from issue https://github.com/intel/intel-xpu-backend-for-triton/issues/878.
Already finished:
Things that need to be finished:
Updating the status of the GEMM performance on the fallback path from issue #878.
Already finished:
- `tt.dot` ops with prefetching.

Things that need to be finished:
- `tt.dot` operation, so that we can use packed 2D load/store to load the operands A and B. -- #969

@etiotto I'd like to close this issue as all the tasks have been finished.
The new changes and tasks are tracked in the new issue: https://github.com/intel/intel-xpu-backend-for-triton/issues/2177
Any concerns about closing this issue as finished?
This issue is to track the tasks to improve the performance of the Triton GEMM/flash attention kernels with the default Triton pass pipeline, so that it reaches a reasonable number (~80% of the XeTLA kernel).
There are many different variant implementations of the Triton kernels for the invariant GEMM/flash attention algorithms. There are two basic reasons we need to enhance the kernel performance with the default Triton pass pipeline:
It is important to support those long-tail variant Triton kernels to get reasonable performance.