intel / intel-xpu-backend-for-triton

OpenAI Triton backend for Intel® GPUs
MIT License

Reland [TUTORIAL] persistent softmax kernel #1495

Open victor-eds opened 1 week ago

victor-eds commented 1 week ago

Reland 01c3e984490cbff3164ef97bf79cbd36628281dc, a263360050e1887a2cda0c2cac811ddd3ccaab1e, a5b32a8718b64786290a030c289c899554364e53 and 8ffdec13e6c36e12c378c3e4ffb6df5ef27150e5.

These commits introduce tuning for NVIDIA GPUs. This PR modifies them for better tuning on XPU devices:

The occupancy calculation is based on https://oneapi-src.github.io/oneAPI-samples/Tools/GPU-Occupancy-Calculator/
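As a rough illustration of that occupancy rule, here is a minimal Python sketch mirroring the occupancy() helper added in the diff below. The function names are hypothetical, and the constants are the same placeholders the PR marks as TODO. One assumption differs in form: the sketch expresses SLM capacity in bytes (assuming the PR's slm_size_per_sub_slice = 128 means 128 KiB) so it can be compared directly with a per-work-group shared-memory size in bytes.

```python
# Hypothetical hardware constants, mirroring the TODO placeholders in this PR
# (WARPS_PER_EU, EU_PER_SM, MAX_NUM_WG); real values should come from the device.
WARPS_PER_EU = 8
EU_PER_SM = 8
MAX_NUM_WG = 64
SLM_BYTES_PER_SM = 128 * 1024  # assumption: 128 KiB of SLM per Xe-core, in bytes


def work_groups_per_sm(num_warps: int, slm_bytes_per_wg: int) -> int:
    """Resident work-groups per Xe-core, limited by threads, SLM and a hard cap."""
    warps_per_sm = WARPS_PER_EU * EU_PER_SM
    by_threads = warps_per_sm // num_warps
    # A kernel that uses no SLM is only limited by the hard work-group cap.
    by_slm = MAX_NUM_WG if slm_bytes_per_wg == 0 else SLM_BYTES_PER_SM // slm_bytes_per_wg
    return min(by_threads, by_slm, MAX_NUM_WG)


def occupancy(num_sm: int, num_warps: int, slm_bytes_per_wg: int) -> int:
    """Total number of work-groups that can be resident on the whole GPU."""
    return num_sm * work_groups_per_sm(num_warps, slm_bytes_per_wg)


print(occupancy(64, 4, 0))          # thread-limited: 64 * 16 = 1024
print(occupancy(64, 4, 32 * 1024))  # SLM-limited: 64 * 4 = 256
```

With a hypothetical 64 Xe-cores, a 4-warp kernel without SLM is limited by thread slots (16 work-groups per core), while needing 32 KiB of SLM per work-group drops that to 4 per core.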

Closes #1099

victor-eds commented 1 week ago

[Figure: softmax-performance] vs. upstream: [Figure: upstream]

I kept the sub-group size heuristics, but use 32 instead of 16 for bigger sizes, as that gives better performance. The upstream constant size did not give good results at all. There is no need to fine-tune this too much since, according to the tutorial:

    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
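The heuristic in the diff, num_warps = min(max_num_warps, max(1, BLOCK_SIZE // (WARP_SIZE * 16))), can be tried in isolation with the sketch below. pick_num_warps and the local next_power_of_2 are hypothetical stand-ins (the latter follows the contract of triton.next_power_of_2 for n >= 1), and max_num_warps=32 assumes a maximum work-group size of 1024 with 32-wide sub-groups.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1 (same contract as triton.next_power_of_2)."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def pick_num_warps(block_size: int, warp_size: int, max_num_warps: int,
                   elems_per_thread: int = 16) -> int:
    """Aim for `elems_per_thread` elements per thread, clamped to [1, max_num_warps]."""
    return min(max_num_warps, max(1, block_size // (warp_size * elems_per_thread)))


BLOCK_SIZE = next_power_of_2(781)  # 1024
print(pick_num_warps(BLOCK_SIZE, warp_size=32, max_num_warps=32))  # 2
print(pick_num_warps(BLOCK_SIZE, warp_size=16, max_num_warps=64))  # 4
print(pick_num_warps(1 << 16, warp_size=32, max_num_warps=32))     # 32 (clamped)
```

For the tutorial's 781-column input, the block is 1024 wide, which yields 2 warps with 32-wide sub-groups and 4 with 16-wide ones; very wide rows are clamped to the hardware limit.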
victor-eds commented 1 week ago

Difference w.r.t. upstream:

diff -u --label \#\<buffer\ 02-fused-softmax.py\<triton\>\> --label \#\<buffer\ 02-fused-softmax.py\<intel-xpu-backend-for-triton\>\> /tmp/buffer-content-rKiB7q /tmp/buffer-content-wpJgCi
--- #<buffer 02-fused-softmax.py<triton>>
+++ #<buffer 02-fused-softmax.py<intel-xpu-backend-for-triton>>
@@ -22,6 +22,7 @@
 # Let us consider instead the case of a simple (numerically stabilized) softmax operation:

 import torch
+import intel_extension_for_pytorch  # type: ignore # noqa: F401

 import triton
 import triton.language as tl
@@ -71,12 +72,12 @@

 @triton.jit
-def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
-                   num_stages: tl.constexpr):
+def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols,
+                   BLOCK_SIZE: tl.constexpr):
     # starting row of the program
     row_start = tl.program_id(0)
     row_step = tl.num_programs(0)
-    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
+    for row_idx in tl.range(row_start, n_rows, row_step):
         # The stride represents how much we need to increase the pointer to advance 1 row
         row_start_ptr = input_ptr + row_idx * input_row_stride
         # The block size is the next power of two greater than n_cols, so we can fit each
@@ -101,30 +102,40 @@
 # %%
 # We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

-device = torch.cuda.current_device()
+device = torch.xpu.current_device()
 properties = driver.active.utils.get_device_properties(device)
 NUM_SM = properties["multiprocessor_count"]
-NUM_REGS = properties["max_num_regs"]
-SIZE_SMEM = properties["max_shared_mem"]
-WARP_SIZE = properties["warpSize"]
+WARPS_PER_EU = 8  # TODO: Get from properties
+EU_PER_SM = 8  # TODO: Get from properties
+MAX_NUM_WG = 64  # TODO: Get from properties
+WARP_SIZE = properties["sub_group_sizes"][-1]
+WG_SIZE = properties["max_work_group_size"]
+max_num_warps = WG_SIZE // WARP_SIZE
 target = triton.runtime.driver.active.get_current_target()
+warps_per_sm = WARPS_PER_EU * EU_PER_SM
+max_num_resident_warps = NUM_SM * warps_per_sm
+slm_size_per_sub_slice = 128  # TODO: Get from properties
 kernels = {}

 def softmax(x):
-    n_rows, n_cols = x.shape

+    def occupancy(num_warps, size_smem):
+        num_wg_threads = warps_per_sm // num_warps
+        num_wg_slm = MAX_NUM_WG if size_smem == 0 else slm_size_per_sub_slice // size_smem
+        num_wg = min(num_wg_threads, num_wg_slm, MAX_NUM_WG)
+        return NUM_SM * num_wg
+
+    n_rows, n_cols = x.shape
     # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
     BLOCK_SIZE = triton.next_power_of_2(n_cols)

-    # Another trick we can use is to ask the compiler to use more threads per row by
-    # increasing the number of warps (`num_warps`) over which each row is distributed.
+    # Simple heuristic depending on `BLOCK_SIZE`. We aim for 16 elements per
+    # thread. As the maximum number of warps is limited by hardware, we need to
+    # make sure we do not surpass that limit.
     # You will see in the next tutorial how to auto-tune this value in a more natural
     # way so you don't have to come up with manual heuristics yourself.
-    num_warps = 8
-
-    # Number of software piepling stages.
-    num_stages = 4 if SIZE_SMEM > 200000 else 2
+    num_warps = min(max_num_warps, max(1, BLOCK_SIZE // (WARP_SIZE * 16)))

     # Allocate output
     y = torch.empty_like(x)
@@ -132,27 +143,25 @@
     # pre-compile kernel to get register usage and compute thread occupancy.
     kernel, num_programs = kernels.get(BLOCK_SIZE, (None, 0))
     if kernel is None:
-        kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
-                                       num_stages=num_stages, num_warps=num_warps, grid=(1, ))
+        kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, num_warps=num_warps,
+                                       threads_per_warp=WARP_SIZE, BLOCK_SIZE=BLOCK_SIZE, grid=(1, ))
         kernel._init_handles()
-        n_regs = kernel.n_regs
         size_smem = kernel.metadata.shared
-        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
-        occupancy = min(occupancy, SIZE_SMEM // size_smem)
-        num_programs = NUM_SM * occupancy
+        # If we cannot reach maximum occupancy, we will not go with persistent programs and schedule a program per row.
+        # Occupancy could be maximized by tweaking `num_warps` and `threads_per_warp`, but it is worth remembering
+        # higher occupancy does not always translate to better performance.
+        # Persistent kernels may show better performance when the occupancy is 100 %, but this may not be the case in
+        # other cases, as work-group preemption will help hide stall GPU cycles.
+        num_programs = occupancy(num_warps, size_smem)
+        if num_programs * num_warps < max_num_resident_warps:
+            num_programs = n_rows
         kernels[BLOCK_SIZE] = (kernel, num_programs)

     num_programs = min(num_programs, n_rows)

     # Create a number of persistent programs.
-    kernel[(num_programs, 1, 1)](
-        y,
-        x,
-        x.stride(0),
-        y.stride(0),
-        n_rows,
-        n_cols,
-    )
+    kernel[(num_programs, )](y, x, x.stride(0), y.stride(0), n_rows, n_cols)
     return y

@@ -165,7 +174,7 @@
 # This will allow us to verify that our padding mechanism works.

 torch.manual_seed(0)
-x = torch.randn(1823, 781, device='cuda')
+x = torch.randn(1823, 781, device='xpu')
 y_triton = softmax(x)
 y_torch = torch.softmax(x, axis=1)
 assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)
@@ -197,9 +206,9 @@
         args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
     ))
 def benchmark(M, N, provider):
-    x = torch.randn(M, N, device='cuda', dtype=torch.float32)
-    stream = torch.cuda.Stream()
-    torch.cuda.set_stream(stream)
+    x = torch.randn(M, N, device='xpu', dtype=torch.float32)
+    stream = torch.xpu.Stream()
+    torch.xpu.set_stream(stream)
     if provider == 'torch':
         ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
     if provider == 'triton':

Diff finished.  Tue Jul  2 16:24:16 2024
victor-eds commented 6 days ago

Further tuned for better performance at smaller BLOCK_SIZE values:

[Figure: softmax-performance]

victor-eds commented 23 hours ago

The non-persistent run gives better performance for n_cols > 1024: that is the 100 % occupancy threshold. Past it, we cannot fill the GPU with this warp size, and persistent kernels perform worse than oversubscribing:

[Figure: softmax-performance]

Why is that? Let's see the occupancy and stall report:

With persistent kernels here, we can only use 50 % of the GPU (which shows how good I am and that the occupancy calculation method works :wink:), but, of course, we are missing out on half of the hardware's capabilities:

[Figure: persistent]

For the non-persistent run, we get >90 % occupancy for most of the run! This shows we do not always want to run persistent kernels: PVC preempts threads and makes use of the stall cycles we miss out on with persistent kernels.

[Figure: non-persistent]

If we focus on the smaller block sizes, where persistent kernels give better performance...

The non-persistent kernel reaches 100 % occupancy for a while and then drops (the tracing resolution is a bit coarse, and the short runtime does not help much here). We need more than one wave, one of them at suboptimal occupancy:

[Figure: non-persistent]
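To make the wave argument concrete: when every row gets its own program, the launch needs ceil(n_rows / resident_programs) waves and the last one is usually only partly full, whereas a persistent launch caps the grid at the resident capacity and runs in a single wave. The sketch below uses hypothetical helper names and a hypothetical capacity of 1024 resident programs; 1823 rows matches the tutorial's test tensor.

```python
import math


def num_waves(num_programs: int, resident_programs: int) -> int:
    """Number of waves needed to run `num_programs` work-groups."""
    return math.ceil(num_programs / resident_programs)


def last_wave_fill(num_programs: int, resident_programs: int) -> float:
    """Fraction of the resident slots that the final wave occupies."""
    remainder = num_programs % resident_programs
    return 1.0 if remainder == 0 else remainder / resident_programs


n_rows, capacity = 1823, 1024  # hypothetical resident capacity
# Non-persistent: one program per row -> 2 waves, the second about 78 % full.
print(num_waves(n_rows, capacity), last_wave_fill(n_rows, capacity))
# Persistent: grid capped at the resident capacity -> a single wave.
print(num_waves(min(n_rows, capacity), capacity))
```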

In the persistent case, we save thread-spawning overhead because we use a single wave at 100 % occupancy (again, the tracing resolution is low and the runtime is short, but we get the idea):

[Figure: persistent]

For this reason, I added a check to the code that detects whether we would underutilize the GPU and, in that case, launches n_rows programs, as that gives the best performance on the target hardware.
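The check can be read as the small decision function below, a sketch mirroring the `if num_programs * num_warps < max_num_resident_warps` branch in the diff. choose_num_programs is a hypothetical name, and the example numbers assume 64 Xe-cores with 64 hardware threads each (the PR's WARPS_PER_EU * EU_PER_SM placeholders).

```python
def choose_num_programs(n_rows: int, resident_programs: int, num_warps: int,
                        max_resident_warps: int) -> int:
    """Grid size: persistent if it fills every thread slot, else one program per row."""
    if resident_programs * num_warps < max_resident_warps:
        # A persistent launch would leave thread slots idle: oversubscribe instead
        # and let the hardware schedule one program per row.
        return n_rows
    # Enough resident programs to fill the GPU: each one loops over several rows.
    return min(resident_programs, n_rows)


MAX_RESIDENT_WARPS = 64 * 8 * 8  # hypothetical: 64 Xe-cores x 8 EUs x 8 threads
print(choose_num_programs(4096, 2048, 2, MAX_RESIDENT_WARPS))  # 2048 (persistent)
print(choose_num_programs(4096, 512, 4, MAX_RESIDENT_WARPS))   # 4096 (one per row)
```

Falling back to n_rows programs gives up the thread-spawning savings of the persistent loop in exchange for the latency hiding that oversubscription provides, which is the trade-off the traces above illustrate.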