intel / intel-xpu-backend-for-triton

OpenAI Triton backend for Intel® GPUs
MIT License

[Performance] Enhance the Triton GEMM/Flash attention kernel performance for the default Triton passes pipeline #878

Closed chengjunlu closed 2 weeks ago

chengjunlu commented 5 months ago

This issue tracks the tasks to improve the performance of the Triton GEMM/flash attention kernels with the default Triton passes pipeline, so that they reach a reasonable level (~80% of the XeTLA kernels).

There are many variant implementations of the Triton kernels for the same underlying GEMM/flash attention algorithms. There are two basic reasons we need to enhance kernel performance with the default Triton passes pipeline:

  1. Users may implement the kernel with their preferred Triton syntax and style, which may not fit our first-class solution. (Notably, the Triton community is deprecating the block pointer.)
  2. Some Triton ops required by the kernel, such as atomic ops, do not support block pointers (e.g. K-dim parallel-reduction GEMM, FlashAttention V3 with K-dim parallel online softmax).

It is important to support those long-tail Triton kernel variants with reasonable performance; a minimal sketch of such a kernel is included below.
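For reference, here is a minimal sketch, not taken from the issue, of the kind of long-tail kernel meant in point 2: a split-K GEMM that addresses its tensors with plain tensors of pointers and combines per-slice partial results with `tl.atomic_add`, which rules out a block-pointer store for C. All names and tile sizes are illustrative, and it assumes M, N, and K divide evenly by the tile sizes (K by `BLOCK_K * SPLIT_K`), C is float32, and C is zero-initialized before launch.

```python
import triton
import triton.language as tl


@triton.jit
def splitk_gemm_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    pid_k = tl.program_id(2)  # each program reduces one slice of the K dimension

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)

    # Tensor-of-pointers addressing: explicit 2D grids of addresses, no tl.make_block_ptr.
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_K * SPLIT_K):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        acc += tl.dot(a, b)
        # Advance along K past the chunks owned by the other SPLIT_K programs.
        a_ptrs += BLOCK_K * SPLIT_K * stride_ak
        b_ptrs += BLOCK_K * SPLIT_K * stride_bk

    # Partial results from different K slices are combined with an atomic op,
    # so the C tile cannot be written through a block pointer.
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.atomic_add(c_ptrs, acc)
```

Such a kernel would be launched on a 3D grid of roughly `(cdiv(M, BLOCK_M), cdiv(N, BLOCK_N), SPLIT_K)` programs, so the K-dim reduction happens across concurrent programs rather than inside one program.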

chengjunlu commented 5 months ago

The tasks to support this, based on the current insights for PVC:

I will create sub-issues to track the progress of each task individually.

tdeng5 commented 5 months ago

@chengjunlu, there is a similar issue #773 for improving GEMM and flash attention performance. Could you please provide some examples that #773 cannot handle and that would go down this path?

chengjunlu commented 5 months ago

> @chengjunlu, there is a similar issue #773 for improving GEMM and flash attention performance. Could you please provide some examples that #773 cannot handle and that would go down this path?

Based on the discussion and plan, #773 focuses on Triton kernels that use the block pointer. This issue tracks the long-tail variant Triton kernels that do not use block pointers.

chengjunlu commented 5 months ago

As the trend in the Triton community is to deprecate the block pointer, we may need to raise the priority of the ToP (tensor of pointers) task: https://github.com/intel/intel-xpu-backend-for-triton/issues/880
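To make the distinction concrete, here is a minimal illustration, not taken from the issue, of the two addressing styles for loading one 2D tile; the helper names and parameters are hypothetical device functions that would be called from a kernel. The block-pointer form is what the existing block-pointer path optimizes; the tensor-of-pointers form is what the ToP task needs to bring to comparable performance.

```python
import triton
import triton.language as tl


@triton.jit
def load_tile_block_ptr(x_ptr, M, N, stride_m, stride_n,
                        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Block-pointer style: shape, strides, and offsets travel with the pointer itself.
    blk = tl.make_block_ptr(base=x_ptr, shape=(M, N), strides=(stride_m, stride_n),
                            offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    return tl.load(blk, boundary_check=(0, 1))


@triton.jit
def load_tile_tensor_of_ptrs(x_ptr, M, N, stride_m, stride_n,
                             BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Tensor-of-pointers style: an explicit 2D grid of addresses plus a mask.
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    ptrs = x_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    return tl.load(ptrs, mask=mask, other=0.0)
```

In the block-pointer form the layout information is explicit in the IR; in the tensor-of-pointers form the backend has to recover the same information by analyzing the address and mask computation before it can use block loads.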

chengjunlu commented 5 months ago

Flash attention optimization notes: most of the passes are aligned with the GEMM optimization. Here are some notes specific to the flash attention optimization:

%28 = tt.load %27 : tensor<128x16x!tt.ptr<f16>, #blocked> loc(#loc18)
%29 = tt.splat %16 : f32 -> tensor<128x16xf32, #blocked> loc(#loc19)
%30 = arith.extf %28 : tensor<128x16xf16, #blocked> to tensor<128x16xf32, #blocked> loc(#loc19)
%31 = arith.mulf %30, %29 : tensor<128x16xf32, #blocked> loc(#loc19)
%32 = arith.truncf %31 : tensor<128x16xf32, #blocked> to tensor<128x16xf16, #blocked> loc(#loc20)
%33:5 = scf.for %arg22 = %c0_i32 to %arg20 step %c64_i32 iter_args(%arg23 = %cst_3, %arg24 = %cst_2, %arg25 = %cst_1, %arg26 = %6, %arg27 = %8) -> (tensor<128x16xf32, #mma>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>>, !tt.ptr<tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>>, !tt.ptr<tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>>)  : i32 {
      %49 = tt.load %arg26 : !tt.ptr<tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>> loc(#loc22)
      %50 = tt.load %arg27 : !tt.ptr<tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>> loc(#loc23)
      %51 = triton_gpu.convert_layout %32 : tensor<128x16xf16, #blocked> -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> loc(#loc20)
      %52 = tt.dot %51, %49, %cst, inputPrecision = tf32 : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<128x64xf32, #mma> loc(#loc24)
#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 8, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], A = [8, 16], B = [16, 8], C = [8, 8]}>
...
%65 = triton_gpu.convert_layout %64 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> loc(#loc36)
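For context, the IR above corresponds roughly to an inner loop like the hypothetical Triton-level sketch below. The names, tile sizes, and online-softmax bookkeeping are illustrative rather than copied from the benchmarked kernel; the sketch only shows where the two tt.dot ops and the convert_layout ops come from.

```python
import triton
import triton.language as tl


@triton.jit
def attn_inner_loop(q, k_block_ptr, v_block_ptr, sm_scale, N_CTX,
                    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr):
    # Q is scaled once before the loop (the extf/mulf/truncf sequence before scf.for).
    q = (q.to(tl.float32) * sm_scale).to(tl.float16)

    m_i = tl.full((BLOCK_M,), float("-inf"), dtype=tl.float32)  # running row max
    l_i = tl.zeros((BLOCK_M,), dtype=tl.float32)                # running row sum
    acc = tl.zeros((BLOCK_M, BLOCK_D), dtype=tl.float32)        # output accumulator

    for _ in range(0, N_CTX, BLOCK_N):
        k = tl.load(k_block_ptr)          # first tt.load in the loop body
        v = tl.load(v_block_ptr)          # second tt.load in the loop body
        qk = tl.dot(q, k)                 # first tt.dot, result in #mma layout

        m_new = tl.maximum(m_i, tl.max(qk, 1))
        p = tl.exp(qk - m_new[:, None])   # online-softmax numerator
        alpha = tl.exp(m_i - m_new)
        l_i = l_i * alpha + tl.sum(p, 1)
        acc = acc * alpha[:, None]

        # Feeding p into the second dot is where the backend converts the #mma
        # result layout into the dot_op (opIdx = 0) operand layout.
        acc += tl.dot(p.to(tl.float16), v)  # second tt.dot

        m_i = m_new
        k_block_ptr = tl.advance(k_block_ptr, (0, BLOCK_N))
        v_block_ptr = tl.advance(v_block_ptr, (BLOCK_N, 0))

    return acc / l_i[:, None]
```

In the IR this shows up as the triton_gpu.convert_layout of the first dot's #mma result into the opIdx = 0 dot_op layout (e.g. %65 above), which is one of the layout conversions the flash-attention-specific work needs to handle efficiently.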
etiotto commented 5 months ago

Helping with refactoring the PR.

chengjunlu commented 3 months ago

Updating the status of the flash attention performance on the fallback path tracked in https://github.com/intel/intel-xpu-backend-for-triton/issues/878.

Flash attention optimization notes: most of the passes are aligned with the GEMM optimization. Here are some notes specific to the flash attention optimization:

Already finished:

Things that need to be finished:

chengjunlu commented 3 months ago

Updating the status of the GEMM performance on the fallback path tracked in issue #878.

Already finished:

Things that need to be finished:

chengjunlu commented 3 weeks ago

@etiotto I'd like to close this issue as all the tasks have been finished.

The new changes and tasks are tracked in the new issue: https://github.com/intel/intel-xpu-backend-for-triton/issues/2177

Any concerns about closing this issue as finished?

etiotto commented 2 weeks ago

#950 is still open, but I think we can track that one separately in #2177. So yes.