Initial investigation for perf_attn:06-fused-attention.forward.py.
Environment:
Commit: 4313cfee5ad7a4ff0b267be8c0775937227c25ec
Device: Intel(R) Level-Zero, Intel(R) Data Center GPU Max 1100
L0 version: 1.3.29735
IGC version: igc-1.0.16900.23
Results:
+-----+------+--------+--------+-----------+-----------+
| Z   | H    | N_CTX  | D_HEAD | Triton    | oneDNN    |
+-----+------+--------+--------+-----------+-----------+
| 4.0 | 48.0 | 1024.0 | 64.0   | 20.311264 | 40.937938 |
+-----+------+--------+--------+-----------+-----------+
oneDNN achieves roughly twice the Triton throughput here. If we look into the generated assembly, we see:
--- Triton
+++ XeTLA
-//.spill size 10368
-//.spill GRF est. ref count 1935
+//.spill size 7040
+//.spill GRF est. ref count 343
Triton spills way more variables than XeTLA.
Also:
--- Triton
+++ XeTLA
+ slm_size: 25600
XeTLA's kernel uses local memory (SLM) to cache data, whereas our kernel only uses global and private memory. This may explain the worse spilling on our side.
Memory access analyses (note the different Y-axis scales):
XeTLA: [memory-access chart]
Triton: [memory-access chart]
Note how the Triton kernel performs way more memory accesses. This might be caused by worse spilling.
Utilization analysis (note the different Y-axis scale):
XeTLA: [utilization chart]
Triton: [utilization chart]
This translates to worse utilization in our kernel: Triton's stall percentage is roughly double XeTLA's, and its utilization roughly half.
Looking at a stall event chart for these two kernels:
XeTLA: [stall event chart]
Triton: [stall event chart]
The purple line corresponds to SBID (scoreboard dependency) stalls, which are related to memory accesses. XeTLA's stalls are more grouped, with a handful of instructions dominating the graph, whereas for Triton we see purple spread all over the graph. This is yet another hint that spilling is hurting our FlashAttention-2 performance.
Using the options provided in collect.sh, we get the following data when running FlashAttention-2:
First of all, it's worth noting that increasing the GRF size fixes the spilling problem: no spilling is taking place now. This makes our kernel "better behaved" w.r.t. memory accesses:
Stalling-wise, we have also improved a lot, getting <40 % stalled threads. On the other hand, our occupancy is now around 50 %: since we are increasing the GRF size, the GPU cannot host as many hardware threads.
If we take a look at the stalling events:
We see the effects of getting rid of spilling. However, a couple of instructions now dominate the stalling (pending investigation), and we see more "other stall" events.
Looking into reduction differences and whether reduced occupancy may be detrimental to us.
Our reduction code is similar to (pseudocode):
%input0 : tensor<8x16xf32>
%input1 : tensor<8x16xf32>
; 8 slices for each tensor -> 16 slices
%slice_0 = extract %input0[0] : tensor<8x16xf32> -> tensor<16xf32>
%max_0 = reduce { max } (%slice_0) -> f32
...
%slice_7 = extract %input0[7] : tensor<8x16xf32> -> tensor<16xf32>
%max_7 = reduce { max } (%slice_7) -> f32
...
%slice_8 = extract %input1[0] : tensor<8x16xf32> -> tensor<16xf32>
%max_8 = reduce { max } (%slice_8) -> f32
...
%red_res = glue %max_0, %max_1, ..., %max_15 : tensor<16xf32>
In LLVM, this would be:
%input0 : <8 x float>
%input1 : <8 x float>
; 8 slices for each tensor
%el_0 = extractelement <8 x float> %input0, i64 0
%max_0 = call float @llvm.GenISA.WaveAll.f32(float %el_0, i8 12, i32 0)
; ...
%red_res = ; insertelement x 16
The issue with WaveAll is that it compiles to (SIMD16, pseudocode):
; Split input 16-element vector in two. '.8' means 8 elements offset
mov (8|M0) rhs lhs.8
; Element-wise max, store in lhs
sel (8|M0) ge lhs lhs rhs
; Do the same again, halving the width each time
mov (4|M0) rhs lhs.4
sel (4|M0) ge lhs lhs rhs
mov (2|M0) rhs lhs.2
sel (2|M0) ge lhs lhs rhs
mov (1|M0) rhs lhs.1
sel (1|M0) ge lhs lhs rhs
This results in 16 reductions of 8 instructions each, i.e. 16 × 8 = 128 = 2^7 operations. This is the case because we are not exploiting SIMD parallelism here: each WaveAll only reduces across the subgroup, so the instructions above run at execution sizes 8, 4, 2 and 1, underutilizing the SIMD lanes.
Ideally, we would like to do something like:
; Glued and transposed input matrices
%input : tensor<16x16xf32>
; 16 slices
%slice_0 = extract %input[0] : tensor<16x16xf32> -> tensor<16xf32>
...
; 15 element-wise max ops reduce all 16 slices at once
%max_0 = max %slice_0, %slice_1
%max_1 = max %max_0, %slice_2
...
%max_14 = max %max_13, %slice_15
; %max_14 holds one reduction result per lane -> tensor<16x1xf32>
Which, in LLVM, would be as simple as:
; Transposed input: this lane now holds all 16 elements of its row
%input : <16 x float>
%el_0 = extractelement <16 x float> %input, i64 0
%el_1 = extractelement <16 x float> %input, i64 1
%max_0 = call float @llvm.maxnum.f32(float %el_0, float %el_1)
; ... just 15 max operations required in total!
%max_14 = call float @llvm.maxnum.f32(float %max_13, float %el_15)
; Transpose output if the original layout is needed downstream
Here we could perform the same reduction in just 15 operations. Now, why does this "matrix transpose" work? Where even is the matrix? How can we transpose it?
Let's step back for a moment. Conceptually, we have a 16x16 matrix here, distributed as follows across the subgroup:
+--------------------+---------+
| sub_group_local_id | indices |
+--------------------+---------+
| 0 | 0-15 |
| 1 | 16-31 |
| 2 | 32-47 |
| ... | ... |
+--------------------+---------+
In this matrix, reducing across the second dimension is very expensive, as it requires a subgroup reduction. After a matrix transpose:
+--------------------+---------------------+
| sub_group_local_id | indices |
+--------------------+---------------------+
| 0 | 0,16,...,i*16,... |
| 1 | 1,17,...,i*16+1,... |
| 2 | 2,18,...,i*16+2,... |
| ... | ... |
+--------------------+---------------------+
Now we can reduce over the first dimension utilizing the SIMD lanes as we saw above.
We've seen the matrix and how nice it would be to transpose it. But how can we do that? This is yet to be explored; I'll build a POC for it. Possible solutions include going through local memory (probably the fastest; see the sketch below) or shufflevector + sub_group_shuffle to do everything in registers.
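To make the local-memory option concrete, here is a minimal standalone OpenCL C sketch of the idea. This is not the actual Triton/IGC lowering: the kernel, its arguments and the initial per-lane distribution are assumptions for illustration, with the tile fixed to a 16-wide subgroup.

// Assumed starting point: lane `lid` holds one element of each of the 16
// rows to reduce. After a round trip through SLM, lane `lid` holds all 16
// elements of row `lid`, so the reduction becomes a chain of 15 per-lane
// max operations, each of them a single SIMD16 instruction covering all
// 16 rows at once. In the real kernel the values would already live in
// registers (DPAS output); we load from global memory only to keep the
// sketch self-contained.
__attribute__((intel_reqd_sub_group_size(16)))
__kernel void rowmax_via_slm_transpose(__global const float *in,
                                       __global float *out) {
    const uint lid = get_sub_group_local_id(); // 0..15
    __local float tile[16][16];                // SLM scratch for the transpose

    // Each lane writes its 16 values into one SLM column...
    for (uint r = 0; r < 16; ++r)
        tile[r][lid] = in[r * 16 + lid];
    barrier(CLK_LOCAL_MEM_FENCE);              // work-group == sub-group here

    // ...and reads one SLM row back: the transpose.
    float row[16];
    for (uint c = 0; c < 16; ++c)
        row[c] = tile[lid][c];

    // 15 max operations reduce all 16 rows in parallel across the lanes.
    float row_max = row[0];
    for (uint c = 1; c < 16; ++c)
        row_max = fmax(row_max, row[c]);

    out[lid] = row_max;                        // lane `lid` owns max(row `lid`)
}

The register-only alternative would replace the SLM round trip with a sequence of sub_group_shuffle (or shufflevector) exchanges, trading SLM traffic for more index bookkeeping in registers.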
Transposing does add extra overhead, though.
IGC supports mixing SIMD and SIMT expression when the SIMT expression is limited (and there is no IGC-specific intrinsic that expresses the computation clearly, e.g. a WaveAll that takes a vector type as input).
Invoke SIMD and inline vISA can help generate micro-operations with the SIMD extension. In this case, I think invoke SIMD is more suitable.
The SIMD 2D region extension allows us to define the strides of the operands of a SIMD binary operation, which lets us make the best use of SIMD parallelism.
https://github.com/intel/vc-intrinsics/blob/master/GenXIntrinsics/docs/GenXLangRef.rst#2d-region
The GenX intrinsics may be a nice option. However, our kernels currently model work-item behavior, not EU behavior, so this would mean changing pretty much everything we have, right? I'd rather go with the local memory approach for now, get some numbers and see if that improves performance.
Yes, that makes sense. Let's check what we can get with a pure SIMT paradigm.
Prototype with SLM transpose taking shape. Overall structure finalized, missing last bits of lowering.
Pushed prototype. Shaping into what we want.
Couldn't work on the POC much this week. The DPAS output that feeds the reduction now has the expected size to exploit the matrix-transpose approach.
Lots of progress on the POC. Hitting one more issue (hopefully the last one); will try to work around it.
POC ready for review. Will push the latest perf report. Minor fixes to be pushed.
The PR with the latest needed code in llvm-target is ready for review and has one approval (#2109). The feature will be available behind an environment variable.
Looking into our FlashAttention-2 benchmark, we see that reduction codegen for XeTLA is better than Triton's. Investigate the reasons and whether the kernel code or any pass in our pipeline has to be tuned.