intel / intel-xpu-backend-for-triton

OpenAI Triton backend for Intel® GPUs
MIT License

Investigate reduction performance difference between XeTLA and Triton #1637

Closed victor-eds closed 1 month ago

victor-eds commented 3 months ago

Looking into our FlashAttention-2 benchmark, we see that the reduction codegen for XeTLA is better than for Triton.

Investigate reasons and whether the kernel code or any pass in our pipeline has to be tuned.

victor-eds commented 3 months ago

Initial investigation for perf_attn:06-fused-attention.forward.py.

Environment:

Commit: 4313cfee5ad7a4ff0b267be8c0775937227c25ec
Device: Intel(R) Level-Zero, Intel(R) Data Center GPU Max 1100
L0 version: 1.3.29735
IGC version: igc-1.0.16900.23

Results:

     Z     H   N_CTX  D_HEAD     Triton     oneDNN
0  4.0  48.0  1024.0    64.0  20.311264  40.937938

If we look into the generated assembly, we see:

--- Triton
+++ XeTLA
-//.spill size 10368
-//.spill GRF est. ref count 1935
+//.spill size 7040
+//.spill GRF est. ref count 343

Triton spills far more than XeTLA: about 1.5x the spill size and about 5.6x the estimated spilled-GRF references.

Also:

--- Triton
+++ XeTLA
+      slm_size:        25600

XeTLA's kernel uses local memory (SLM) to cache data, whereas our kernel uses only global and private memory. This may explain the worse spilling on our side.

victor-eds commented 3 months ago

Memory access analyses (Note different Y axis scales):

XeTLA: image

Triton: image

Note how the Triton kernel performs way more memory accesses. This might be caused by worse spilling.

victor-eds commented 3 months ago

Utilization analysis (Note different Y axis scale):

XeTLA: image

Triton: image

This translates to worse utilization in our kernel. Note that Triton's stall percentage is about double XeTLA's, and its utilization about half.

victor-eds commented 3 months ago

Looking at a stall event chart for these two kernels:

XeTLA: image

Triton: image

The purple line corresponds to SBID (scoreboard dependency) stalls, which are related to memory accesses. XeTLA's stalls are more clustered, and a handful of instructions dominate the graph. For Triton, however, purple is spread all over the graph. This is yet another hint that spilling is hurting our FlashAttention-2 performance.

victor-eds commented 3 months ago

Using the options provided in collect.sh, we get the following data when running FlashAttention-2:

First of all, it's worth noting that increasing the GRF size fixes the spilling problem: no spilling is taking place now. This makes our kernel "better behaved" w.r.t. memory accesses:

image

Stalling-wise, we have also improved a lot, getting <40 % stalled threads. Our occupancy, however, is only around 50 %: because we're increasing the GRF size, our GPU cannot host as many threads.
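
A back-of-the-envelope sketch of why a larger GRF size can halve occupancy. The figures used here (64-byte registers, a 64 KiB register file per vector engine, 128 vs. 256 registers per thread) are assumptions based on the commonly documented GRF modes for this GPU family, not values read from the profiler output in this thread.

```python
# Assumed hardware figures (not taken from this thread's profiler data).
GRF_BYTES = 64                     # one register is 512 bits wide
REGISTER_FILE_BYTES = 64 * 1024    # register file per vector engine


def threads_per_engine(grf_per_thread):
    """Hardware threads that fit when each thread owns `grf_per_thread` registers."""
    return REGISTER_FILE_BYTES // (grf_per_thread * GRF_BYTES)


default_threads = threads_per_engine(128)   # default GRF mode
large_threads = threads_per_engine(256)     # large GRF mode

# Occupancy of the large-GRF kernel relative to the default mode.
relative_occupancy = large_threads / default_threads
print(default_threads, large_threads, relative_occupancy)
```

Under these assumptions the kernel can keep at most half as many threads resident, which is consistent with the roughly 50 % occupancy reported above.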

image

If we take a look at the stalling events:

image

We see the effects of getting rid of spilling. However, we see a couple of instructions dominating the stalling. Pending investigation. Also, we see more "other stall" events.

victor-eds commented 3 months ago

Looking into reduction differences and whether reduced occupancy may be detrimental for us.

victor-eds commented 3 months ago

Our reduction code is similar to (pseudocode):

%input0 : tensor<8x16xf32>
%input1 : tensor<8x16xf32>
; 8 slices for each tensor -> 16 slices
%slice_0 = extract %input0[0] : tensor<8x16xf32> -> tensor<16xf32>
%max_0 = reduce { max } (%slice_0) -> f32
...
%slice_7 = extract %input0[7] : tensor<8x16xf32> -> tensor<16xf32>
%max_7 = reduce { max } (%slice_7) -> f32
...
%slice_8 = extract %input1[0] : tensor<8x16xf32> -> tensor<16xf32>
%max_8 = reduce { max } (%slice_8) -> f32
%red_res = glue %max_0 %max_1 ... : tensor<16xf32>

In LLVM, this would be:

%input0 : <8 x float>
%input1 : <8 x float>
; 8 slices for each tensor
%el_0 = extractelement <8 x float> %input0, i64 0
%max_0 = call float @llvm.GenISA.WaveAll.f32(float %el_0, i8 12, i32 0)
; ...
%red_res = ; insertelement x 16

The issue with WaveAll is that it compiles to (SIMD16, pseudocode):

; Split input 16-element vector in two. '.8' means 8 elements offset
mov (8|M0) rhs lhs.8
; Element-wise max, store in lhs
sel (8|M0) ge lhs lhs rhs
; Halve the width and repeat until a single element remains
mov (4|M0) rhs lhs.4
sel (4|M0) ge lhs lhs rhs
mov (2|M0) rhs lhs.2
sel (2|M0) ge lhs lhs rhs
mov (1|M0) rhs lhs.1
sel (1|M0) ge lhs lhs rhs

This results in 16 reductions of 8 instructions each, i.e. 128 (2^7) instructions. The reason is that we are not exploiting SIMD parallelism here: we're just reducing over the subgroup, underutilizing the SIMD lanes.
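
The instruction counts can be checked with a small simulation. This is only a sketch: the input data is made up, `wave_all_max` and `transposed_max` are hypothetical helper names, and the cost of the transpose itself is not counted.

```python
SUBGROUP_SIZE = 16   # SIMD16, one value per lane
NUM_SLICES = 16      # 8 rows from each of the two 8x16 inputs


def wave_all_max(lanes):
    """Tree-reduce one value per lane, mirroring the mov/sel listing.

    Returns (maximum, number of instructions issued).
    """
    vals = list(lanes)
    issued = 0
    width = len(vals) // 2
    while width >= 1:
        rhs = vals[width:2 * width]                                   # mov
        vals[:width] = [max(a, b) for a, b in zip(vals[:width], rhs)]  # sel
        issued += 2
        width //= 2
    return vals[0], issued


def transposed_max(slices):
    """Element-wise max across slices; every op uses all SIMD lanes."""
    acc = list(slices[0])
    issued = 0
    for s in slices[1:]:
        acc = [max(a, b) for a, b in zip(acc, s)]
        issued += 1
    return acc, issued


# Made-up data: matrix[i][lane] is the value that lane holds for slice i.
matrix = [[(7 * i + 3 * lane) % 31 for lane in range(SUBGROUP_SIZE)]
          for i in range(NUM_SLICES)]

# Current scheme: one cross-lane reduction per slice.
per_slice = [wave_all_max(row) for row in matrix]
subgroup_result = [m for m, _ in per_slice]
subgroup_ops = sum(n for _, n in per_slice)

# Alternative: transpose first, then reduce element-wise.
transposed = [list(col) for col in zip(*matrix)]
simd_result, simd_ops = transposed_max(transposed)

print(subgroup_ops, simd_ops)
```

Both schemes produce the same 16 maxima. The subgroup version issues 128 narrow instructions, while the element-wise version issues 15 full-width ones.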

Ideally, we would like to do something like:

; Glued and transposed input matrices
%input : tensor<16x16xf32>
; 16 slices
%slice_0 = extract %input[0] : tensor<16x16xf32> -> tensor<16xf32>
...
; 15 element-wise max ops reduce all 16 slices
%max_0 = max %slice_0, %slice_1
%max_1 = max %max_0, %slice_2
...
%max_t = glue %max_0, %max_1 -> tensor<16x1xf32>

Which, in LLVM, would be as simple as:

; Transpose original input
%input: <16 x float>
%el_0 = extractelement %input[0]
%el_1 = extractelement %input[1]
%max_0 = max %el_0, %el_1
; Just 15 max operations required!
; ...
%max = max %max_n1 %max_n2
; Transpose output

Here we could perform the same reduction in just 15 operations. Now, why does this "matrix transpose" work? Where is even the matrix? How can we transpose it?

Let's step back for a moment. Conceptually, we have a 16x16 matrix here distributed as follows across the subgroup:

+--------------------+---------+
| sub_group_local_id | indices |
+--------------------+---------+
|                  0 | 0-15    |
|                  1 | 16-31   |
|                  2 | 32-47   |
|                ... | ...     |
+--------------------+---------+

In this matrix, reducing across the second dimension is very expensive, as we need to perform a subgroup reduction. After a matrix transpose:

+--------------------+---------------------+
| sub_group_local_id |       indices       |
+--------------------+---------------------+
|                  0 | 0,16,...,i*16,...   |
|                  1 | 1,17,...,i*16+1,... |
|                  2 | 2,18,...,i*16+2,... |
|                ... | ...                 |
+--------------------+---------------------+

Now we can reduce over the first dimension utilizing the SIMD lanes as we saw above.

We've seen the matrix and how nice it is to transpose it. However, how can we do that? This is yet to be explored. I'll build a POC for this. Possible solutions include going through local memory (probably fastest) or shufflevector + sub_group_shuffle to do everything in registers.
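
A minimal sketch of the local-memory route, assuming a 16-lane subgroup in which every lane holds 16 values. The flat list `slm` stands in for shared local memory and `slm_transpose` is a hypothetical helper; the actual lowering in the PoC may look quite different.

```python
SG = 16  # subgroup size; each lane also holds 16 values (one per register)


def slm_transpose(regs):
    """Return t with t[lane][k] == regs[k][lane], going through a flat buffer.

    `slm` stands in for shared local memory; on real hardware a barrier
    would separate the store phase from the load phase.
    """
    slm = [None] * (SG * SG)
    for lane in range(SG):
        for k in range(SG):
            slm[lane * SG + k] = regs[lane][k]        # store: row-major
    return [[slm[k * SG + lane] for k in range(SG)]   # load: column-major
            for lane in range(SG)]


# Layout from the first table: lane l holds indices l*16 .. l*16+15.
before = [[lane * SG + k for k in range(SG)] for lane in range(SG)]
after = slm_transpose(before)

# Layout from the second table: lane l holds l, 16+l, 32+l, ...
print(after[0][:4])   # [0, 16, 32, 48]
print(after[1][:4])   # [1, 17, 33, 49]

# Each lane can now reduce its own 16 values without cross-lane traffic.
per_lane_max = [max(values) for values in after]
```

For f32 data this buffer needs 16 x 16 x 4 bytes = 1 KiB of local memory per subgroup.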

chengjunlu commented 3 months ago

Transposing adds extra overhead.

IGC supports mixing SIMD and SIMT code when the SIMT form is too limiting and there is no IGC-specific intrinsic that expresses the computation clearly (e.g., a WaveAll that takes a vector type as input).

Invoke SIMD and inline vISA can both help generate micro-operations using the SIMD extension.

In this case, I think invoke SIMD is more suitable.

https://github.com/intel/llvm/blob/3c2c938769914eedd935c4fff097591f340ee0de/sycl/doc/extensions/experimental/sycl_ext_oneapi_invoke_simd.asciidoc

The SIMD 2D region extension also allows us to define the strides of the operands of a SIMD binary operation, which lets us make the best use of SIMD parallelism.

https://github.com/intel/vc-intrinsics/blob/master/GenXIntrinsics/docs/GenXLangRef.rst#2d-region

victor-eds commented 3 months ago

The GENX intrinsics may be a nice option. However, our kernels model work-item behavior as of now, not EU behavior, so this would mean changing pretty much all we have as of now, right? I'd rather go with the local memory approach for now, get some numbers and see if that improves performance.

chengjunlu commented 3 months ago

Yes, makes sense. Let's see what we can get with the pure SIMT paradigm.

victor-eds commented 3 months ago

Prototype with SLM transpose taking shape. Overall structure finalized, missing last bits of lowering.

victor-eds commented 2 months ago

Pushed prototype. Shaping into what we want.

victor-eds commented 2 months ago

Couldn't work on the PoC much this week. The DPAS output that feeds the reduction now has the expected size to exploit the matrix-transpose approach.

victor-eds commented 2 months ago

Lots of progress on the PoC. Hitting one issue that needs a workaround (hopefully the last one). Will try to work around it.

victor-eds commented 1 month ago

PoC ready to review. Will push latest perf report. Minor fixes to be pushed.

victor-eds commented 1 month ago

PR with the latest needed code in llvm-target is ready for review and has one approval (#2109). The feature will be available behind an environment variable.