triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Low performance for mixed-type matmul due to lack of pipelining #1397

Open gflegar opened 1 year ago

gflegar commented 1 year ago

We are trying to generate a mixed-precision matmul kernel with a quantized LHS by modifying python/tutorials/03-matrix-multiplication.py (see patch below). However, on an A100 40GB, the quantized kernel achieves far worse performance (127 TFlop/s) than the unquantized version (227 TFlop/s).

Looking at the Triton GPU IR (from ~/.triton/cache/), it looks like the quantized argument is not pipelined, and there are unnecessary layout conversions between the type conversion and the dot operation. This is the first part of the main loop:

      %77 = tt.load %arg11 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xi8, #blocked>
      %78 = arith.sitofp %77 : tensor<128x32xi8, #blocked> to tensor<128x32xf16, #blocked>
      %79 = triton_gpu.convert_layout %78 : (tensor<128x32xf16, #blocked>) -> tensor<128x32xf16, #shared1>
      %80 = triton_gpu.convert_layout %79 : (tensor<128x32xf16, #shared1>) -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>>
      %81 = triton_gpu.convert_layout %arg14 : (tensor<32x256xf16, #shared>) -> tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>>
      %82 = tt.dot %80, %81, %arg10 {allowTF32 = true} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma}>> -> tensor<128x256xf32, #mma>

Profiling the two kernels with Nsight Compute seems to support that the lack of pipelining and the layout conversions are the issue: there are additional memory transfers through L1TEX and additional writes to shared memory, the number of cycles required to issue an instruction increases by 77%, and the "Stall Long Scoreboard" metric, which measures how long instructions are stalled waiting on data from L1TEX, goes up from practically 0 to >3 cycles.
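To make "pipelining" concrete, here is a toy schedule in plain Python. It is not Triton's actual pipelining pass, and the function name `pipelined_schedule` is hypothetical. With `num_stages` buffers, the load for iteration `k + num_stages - 1` is issued before the dot for iteration `k`, so a dot never waits on a load issued in the same iteration. With `num_stages=1` every load is immediately consumed, which is what the quantized LHS in the IR above effectively gets.

```python
def pipelined_schedule(num_iters, num_stages):
    """Return the event order of a toy software-pipelined K loop.

    Each K iteration needs one global->shared load ("load k") and one
    tensor-core step ("dot k"). With num_stages buffers, up to
    num_stages - 1 loads are kept in flight ahead of the dot that uses them.
    """
    if num_stages < 1:
        raise ValueError("num_stages must be >= 1")
    prefetch = num_stages - 1
    events = []
    # Prologue: issue the first loads before any compute.
    for k in range(min(prefetch, num_iters)):
        events.append(f"load {k}")
    # Steady state: issue the load for a future iteration, then compute.
    for k in range(num_iters):
        if prefetch == 0:
            # No pipelining: load and consume in the same iteration.
            events.append(f"load {k}")
        elif k + prefetch < num_iters:
            events.append(f"load {k + prefetch}")
        events.append(f"dot {k}")
    return events


print(pipelined_schedule(4, 3))
```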

Is this a known issue, and is it by any chance already being addressed? If not, would this be difficult to fix?


Patch that enables mixed precision in matmul tutorial (applies to 19e7238d):

diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py
index 8bcae2007..88029613a 100644
--- a/python/tutorials/03-matrix-multiplication.py
+++ b/python/tutorials/03-matrix-multiplication.py
@@ -163,14 +163,7 @@ import triton.language as tl
 #       provided configs
 @triton.autotune(
     configs=[
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
-        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
-        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
+        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=4),
     ],
     key=['M', 'N', 'K'],
 )
@@ -231,6 +224,7 @@ def matmul_kernel(
         # Load the next block of A and B, generate a mask by checking the K dimension.
         # If it is out of bounds, set it to 0.
         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+        a = a.to(tl.float16)
         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
         # We accumulate along the K dimension.
         accumulator += tl.dot(a, b)
@@ -272,7 +266,7 @@ def matmul(a, b, activation=""):
     M, K = a.shape
     K, N = b.shape
     # Allocates output.
-    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
+    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
     # 1D launch kernel where each block gets its own program.
     grid = lambda META: (
         triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
@@ -294,11 +288,14 @@ def matmul(a, b, activation=""):
 #
 # We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS).

+# lhs_type = torch.float16
+lhs_type = torch.int8
+
 torch.manual_seed(0)
-a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
+a = torch.randn((512, 512), device='cuda', dtype=torch.float16).to(lhs_type)
 b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
 triton_output = matmul(a, b)
-torch_output = torch.matmul(a, b)
+torch_output = torch.matmul(a.to(torch.float16), b)
 print(f"triton_output={triton_output}")
 print(f"torch_output={torch_output}")
 if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0):
@@ -336,11 +333,11 @@ else:
     )
 )
 def benchmark(M, N, K, provider):
-    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
+    a = torch.randn((M, K), device='cuda', dtype=torch.float16).to(lhs_type)
     b = torch.randn((K, N), device='cuda', dtype=torch.float16)
     quantiles = [0.5, 0.2, 0.8]
     if provider == 'cublas':
-        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a.to(torch.float16), b), quantiles=quantiles)
     if provider == 'triton':
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
     perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
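The `perf` lambda in the patch is where the TFlop/s figures in this thread come from: an M x K by K x N matmul performs M * N * K multiply-adds, i.e. 2 * M * N * K flops, and `do_bench` reports milliseconds. A standalone version of the same formula, with a hypothetical helper name:

```python
def tflops(M, N, K, ms):
    """Throughput in TFLOP/s of an (M x K) @ (K x N) matmul that took `ms` milliseconds.

    One multiply-add counts as 2 flops, so the matmul does 2 * M * N * K flops.
    """
    return 2 * M * N * K * 1e-12 / (ms * 1e-3)


# A 4096^3 matmul that takes 0.5929 ms runs at roughly 231.8 TFLOP/s.
print(round(tflops(4096, 4096, 4096, 0.5929), 1))
```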
ptillet commented 1 year ago

Aware of this and working on that. It's a substantial change that requires significant refactoring of the backend.

gflegar commented 1 year ago

Thanks. Do you already have an ETA for when this is expected to be fixed?

ptillet commented 1 year ago

Maybe a couple of weeks. It's challenging because in mixed-precision mode you can't just use ldmatrix on both operands.

cheshire commented 1 year ago

@ptillet if we want to start looking at this as well, do you have any pointers on where to start? If in general we want to be more involved, are there avenues we could use (calls/chats/etc)?

ptillet commented 1 year ago

This still requires a fair amount of refactoring, and I'm actually in the middle of it :p The main challenge is that, in a mixed-precision setting, both args can't be fetched with ldmatrix, and the current backend codegen doesn't support fetching tensor core arguments with LDS well.

ptillet commented 1 year ago

If in general you want to be more involved, can you slack me or email me at phil@openai.com ?

gflegar commented 1 year ago

Is it actually mixed precision that's the problem, or more that int8 causes the problem?

Because if it's mixed precision we could still do a "quick" fix that doesn't require codegen changes:

  1. pipeline loading from global -> shared for LHS
  2. before the matmul, load LHS from shared to regs in blocked layout
  3. convert to bf16
  4. store to shared, and load in mma layout
  5. do matmul

With this, at least the mma is not blocked on global loads.

ptillet commented 1 year ago

My hunch is that it's unlikely to be faster than not pipelining due to increased shared mem traffic. I think currently we get like 200TFLOPS on A100 with mixed precision without pipelining. I'm more confident in the mixed ldmatrix/lds approach, which hopefully I'll wrap up by the end of this week or next week. :)
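A back-of-envelope count shows where the extra shared-memory traffic in the proposed "quick fix" comes from. This is a rough sketch under stated assumptions: it counts only LHS bytes written to and read from shared memory per tile (the RHS path is the same in both variants), ignores bank conflicts and swizzling, and assumes the non-pipelined baseline matches the IR in the issue description (convert to f16 in registers, `convert_layout` to `#shared1`, then to the `dot_op` layout). The function name is hypothetical.

```python
def lhs_smem_bytes_per_tile(block_m, block_k, lo_bytes=1, hi_bytes=2,
                            pipelined_quick_fix=False):
    """Shared-memory bytes moved for one LHS tile (rough estimate).

    Both variants store the converted 16-bit tile to shared memory and
    reload it in the mma operand layout. The quick fix additionally stages
    the raw low-precision tile through shared memory (store via the
    pipelined global load, then reload into a blocked layout for the cast).
    """
    elems = block_m * block_k
    traffic = 2 * elems * hi_bytes          # store + reload of the 16-bit tile
    if pipelined_quick_fix:
        traffic += 2 * elems * lo_bytes     # store + reload of the int8 tile
    return traffic


base = lhs_smem_bytes_per_tile(128, 64)
quick = lhs_smem_bytes_per_tile(128, 64, pipelined_quick_fix=True)
print(base, quick, quick / base)  # 32768 49152 1.5
```

Under these assumptions the quick fix moves 1.5x the LHS shared-memory bytes, which is consistent with the concern that the extra traffic could eat the benefit of pipelining.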

gflegar commented 1 year ago

@ptillet any updates on this?

ptillet commented 1 year ago

The necessary refactor is nearing completion, but there are some shape/dtype/num-warps combinations for which lds still doesn't match the behavior of ldmatrix.

gflegar commented 1 year ago

Thanks for the update! In the meantime we're attempting something similar to what I outlined in https://github.com/openai/triton/issues/1397#issuecomment-1490799517, but we're hitting an issue because the shared->blocked layout conversion is not implemented.

This seems not too difficult to add at first glance (since we already have a blocked->shared conversion as a template). Would contributing shared->blocked layout conversion be interesting for upstream?

ptillet commented 1 year ago

Potentially. I think the use cases for shared -> blocked are quite limited, but it is true that this code path is implicit in the blocked -> blocked conversion (it effectively does `blocked -> shared -> blocked`). Maybe I misremember, but I thought the two steps were actually already decoupled. I am also not sure that the alternative you proposed will result in significantly better performance than the non-pipelined version.
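A toy, pure-Python model of the `blocked -> shared -> blocked` path: every thread stores the elements it owns under the source layout into a shared buffer, then (after a barrier on a real GPU) loads the elements it owns under the destination layout. The two layouts here are hypothetical stand-ins; real Triton `#blocked` encodings are parameterized by sizePerThread, threadsPerWarp, warpsPerCTA and order. The point is only that step 2 on its own is a shared -> blocked conversion.

```python
def row_layout(rows, cols, num_threads):
    """Toy layout: thread t owns rows t, t + num_threads, ..."""
    owned = {t: [] for t in range(num_threads)}
    for r in range(rows):
        for c in range(cols):
            owned[r % num_threads].append((r, c))
    return owned


def col_layout(rows, cols, num_threads):
    """Toy layout: thread t owns columns t, t + num_threads, ..."""
    owned = {t: [] for t in range(num_threads)}
    for c in range(cols):
        for r in range(rows):
            owned[c % num_threads].append((r, c))
    return owned


def convert_layout(regs, src, dst, rows, cols):
    """Move per-thread register values from layout `src` to layout `dst`."""
    smem = [[None] * cols for _ in range(rows)]
    # Step 1 (blocked -> shared): each thread stores what it owns in src.
    for t, coords in src.items():
        for value, (r, c) in zip(regs[t], coords):
            smem[r][c] = value
    # A barrier would go here on a GPU.
    # Step 2 (shared -> blocked): each thread loads what it owns in dst.
    return {t: [smem[r][c] for (r, c) in coords] for t, coords in dst.items()}


rows, cols, nt = 4, 6, 2
src, dst = row_layout(rows, cols, nt), col_layout(rows, cols, nt)
regs = {t: [r * cols + c for (r, c) in coords] for t, coords in src.items()}
print(convert_layout(regs, src, dst, rows, cols))
```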

FYI, #1557 implements the hardest part of the refactor. It's a small PR, but it took me a while because I got sidetracked a few times and it's also pretty challenging to get the logic right (I'm not 100% sure, but it passed all the tests I threw at it). The rest is pretty much just plumbing at this point, and I'm confident I'll be able to get it done by early next week.

ptillet commented 1 year ago

Over the weekend I got an FP8 x FP8 kernel that clocks in at 260 TFLOPS on A100, but my mixed FP16 x FP8 kernel is still at 220 (vs. 195 without pipelining...). I'll have to spend some time investigating where the perf hit comes from before starting the plumbing work.

ptillet commented 1 year ago

I think this should be (finally) solved on main. At least I tried float8e5 x float16 matmul and it worked properly. Beware though that only row x col mode will be fast due to ldmatrix limitations. Also beware that optimal tile size seems to be 128x128x64, num_stages=3, num_warps=4. My float8e5 x float16 kernel gets 220TFLOPS, but I am working with nvidia to resolve what I think is an issue in the ptxas optimizer that causes the slowdown. I don't think anything more can be done at the Triton level to improve perf.... but as usual I'm happy to be proven wrong :)

Sorry for the time it took, but this opened a big can of worms that required multiple refactors deep in the stack.

gflegar commented 1 year ago

Thanks! I tried this with the patch to the matmul tutorial from the description of this issue (updated to apply to the latest Triton head), but it unfortunately fails for me when running:

CUDA_LAUNCH_BLOCKING=1 python3 ./tutorials/03-matrix-multiplication.py

I get the following error:

Traceback (most recent call last):
  File "<string>", line 22, in matmul_kernel
KeyError: ('2-.-1-.-0--7d062dd833c9fd2e3a10cb919dff9593-d98f5ddab8f235cff4818270ea3be930-f32ad5a69e3017a5098d3d5532499369-cacecb5a01b695fe1eb376e18972d557-06b47813aaed5d9f42e68c0b9c8e48c0-1f9a5def50648a25b80eced738f870ed-30b6e6b1d71fdbb8ff5fc6b8488466bc-28cf33ff2045e5ba4563704dbc5615cc-84d6008a7b72daa9acf1b72e2349b605-293a6a48c2635776c9dee8b54603306f-aaf65dd0e0e16f803cf5f7b162164322-22daff6db48896bd8b64e2af96e88d58', (torch.int8, torch.float16, torch.float16, 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), (128, 128, 64, 8, ''), (True, True, True, (True, False), (True, False), (True, False), (True, False), (False, True), (True, False), (False, True), (True, False), (False, True)), 4, 3, False)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/google/home/gflegar/oss/triton/src/python/./tutorials/03-matrix-multiplication.py", line 297, in <module>
    triton_output = matmul(a, b)
  File "/usr/local/google/home/gflegar/oss/triton/src/python/./tutorials/03-matrix-multiplication.py", line 274, in matmul
    matmul_kernel[grid](
  File "/usr/local/google/home/gflegar/oss/triton/src/python/triton/runtime/autotuner.py", line 110, in run
    return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
  File "<string>", line 44, in matmul_kernel
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered

After running it through NVIDIA's compute-sanitizer, I get 400+ errors like the following:

========= COMPUTE-SANITIZER
========= Invalid __shared__ read of size 2 bytes
=========     at 0x2770 in matmul_kernel_0d1d2d3d4d5d6d7c8d9c10d11c
=========     by thread (34,0,0) in block (0,0,0)
=========     Address 0xf600 is out of bounds

This was done with int8 x float16. I'm not sure how you are setting up float8e5, since torch doesn't support it AFAICT, but I also tried doing a triton.reinterpret(a, tl.float8e5) of my int8 matrix just before launching the kernel (line 275 of the matmul tutorial at commit 19e7238) to make the kernel operate on fp8, and I get exactly the same compute-sanitizer error.

Can you reproduce the same thing?

This was done on my local RTX 3090 Ti (that was quickest for me to try), but I can also run it on A100 in case that helps.

gflegar commented 1 year ago

I ran it on an A100-SXM4-40GB and got the same invalid 2-byte shared access. The threads that make the invalid accesses seem different, though maybe it's just ~all threads and the ones that get reported are not deterministic. Here are the threads that caused invalid accesses in the A100 run:

=========     by thread (67,0,0) in block (0,0,0)
=========     by thread (71,0,0) in block (0,0,0)
=========     by thread (75,0,0) in block (0,0,0)
=========     by thread (79,0,0) in block (0,0,0)
=========     by thread (83,0,0) in block (0,0,0)
=========     by thread (87,0,0) in block (0,0,0)
=========     by thread (91,0,0) in block (0,0,0)
=========     by thread (95,0,0) in block (0,0,0)
=========     by thread (99,0,0) in block (0,0,0)
=========     by thread (103,0,0) in block (0,0,0)
=========     by thread (107,0,0) in block (0,0,0)
=========     by thread (111,0,0) in block (0,0,0)
=========     by thread (115,0,0) in block (0,0,0)
=========     by thread (119,0,0) in block (0,0,0)
=========     by thread (123,0,0) in block (0,0,0)
=========     by thread (127,0,0) in block (0,0,0)
=========     by thread (3,0,0) in block (0,0,0)
=========     by thread (7,0,0) in block (0,0,0)
=========     by thread (11,0,0) in block (0,0,0)
=========     by thread (15,0,0) in block (0,0,0)
=========     by thread (19,0,0) in block (0,0,0)
=========     by thread (23,0,0) in block (0,0,0)
=========     by thread (27,0,0) in block (0,0,0)
=========     by thread (31,0,0) in block (0,0,0)
=========     by thread (3,0,0) in block (1,0,0)
=========     by thread (7,0,0) in block (1,0,0)
=========     by thread (11,0,0) in block (1,0,0)
=========     by thread (15,0,0) in block (1,0,0)
=========     by thread (19,0,0) in block (1,0,0)
=========     by thread (23,0,0) in block (1,0,0)
=========     by thread (27,0,0) in block (1,0,0)
=========     by thread (31,0,0) in block (1,0,0)
=========     by thread (35,0,0) in block (1,0,0)
=========     by thread (39,0,0) in block (1,0,0)
=========     by thread (43,0,0) in block (1,0,0)
=========     by thread (47,0,0) in block (1,0,0)
=========     by thread (51,0,0) in block (1,0,0)
=========     by thread (55,0,0) in block (1,0,0)
=========     by thread (59,0,0) in block (1,0,0)
=========     by thread (63,0,0) in block (1,0,0)
=========     by thread (67,0,0) in block (1,0,0)
=========     by thread (71,0,0) in block (1,0,0)
=========     by thread (75,0,0) in block (1,0,0)
=========     by thread (79,0,0) in block (1,0,0)
=========     by thread (83,0,0) in block (1,0,0)
=========     by thread (87,0,0) in block (1,0,0)
=========     by thread (91,0,0) in block (1,0,0)
=========     by thread (95,0,0) in block (1,0,0)
=========     by thread (99,0,0) in block (1,0,0)
=========     by thread (103,0,0) in block (1,0,0)
=========     by thread (107,0,0) in block (1,0,0)
=========     by thread (111,0,0) in block (1,0,0)
=========     by thread (115,0,0) in block (1,0,0)
=========     by thread (119,0,0) in block (1,0,0)
=========     by thread (123,0,0) in block (1,0,0)
=========     by thread (127,0,0) in block (1,0,0)
=========     by thread (99,0,0) in block (2,0,0)
=========     by thread (103,0,0) in block (2,0,0)
=========     by thread (107,0,0) in block (2,0,0)
=========     by thread (111,0,0) in block (2,0,0)
=========     by thread (115,0,0) in block (2,0,0)
=========     by thread (119,0,0) in block (2,0,0)
=========     by thread (123,0,0) in block (2,0,0)
=========     by thread (127,0,0) in block (2,0,0)
=========     by thread (3,0,0) in block (2,0,0)
=========     by thread (7,0,0) in block (2,0,0)
=========     by thread (11,0,0) in block (2,0,0)
=========     by thread (15,0,0) in block (2,0,0)
=========     by thread (19,0,0) in block (2,0,0)
=========     by thread (23,0,0) in block (2,0,0)
=========     by thread (27,0,0) in block (2,0,0)
=========     by thread (31,0,0) in block (2,0,0)
=========     by thread (35,0,0) in block (2,0,0)
=========     by thread (39,0,0) in block (2,0,0)
=========     by thread (43,0,0) in block (2,0,0)
=========     by thread (47,0,0) in block (2,0,0)
=========     by thread (51,0,0) in block (2,0,0)
=========     by thread (55,0,0) in block (2,0,0)
=========     by thread (59,0,0) in block (2,0,0)
=========     by thread (63,0,0) in block (2,0,0)
=========     by thread (67,0,0) in block (2,0,0)
=========     by thread (71,0,0) in block (2,0,0)
=========     by thread (75,0,0) in block (2,0,0)
=========     by thread (79,0,0) in block (2,0,0)
=========     by thread (83,0,0) in block (2,0,0)
=========     by thread (87,0,0) in block (2,0,0)
=========     by thread (91,0,0) in block (2,0,0)
=========     by thread (95,0,0) in block (2,0,0)
=========     by thread (3,0,0) in block (3,0,0)
=========     by thread (7,0,0) in block (3,0,0)
=========     by thread (11,0,0) in block (3,0,0)
=========     by thread (15,0,0) in block (3,0,0)
=========     by thread (19,0,0) in block (3,0,0)
=========     by thread (23,0,0) in block (3,0,0)
=========     by thread (27,0,0) in block (3,0,0)
=========     by thread (31,0,0) in block (3,0,0)
=========     by thread (35,0,0) in block (3,0,0)
=========     by thread (39,0,0) in block (3,0,0)
=========     by thread (43,0,0) in block (3,0,0)
=========     by thread (47,0,0) in block (3,0,0)
ptillet commented 1 year ago

I see! https://github.com/openai/triton/blob/phil/float8-perf-repro/python/tutorials/03-matrix-multiplication.py is how the float16 x float8e5 case is set up. Could you try to get a diff of the kernel and launch config between this version and yours?

Also: now that all the refactoring is finally done, I think you guys should feel free to dig deep into this issue. I could assist.

ptillet commented 1 year ago

It's possible it's related to auto-tuning; I've only tried with big block sizes. I wouldn't be surprised if you ran into this error for very small block sizes.

cheshire commented 1 year ago

I get an OOB read in this config: MatmulTiling(BLOCK_M=32, BLOCK_N=128, BLOCK_K=32, SPLIT_K=1, num_stages=2, num_warps=4) for MatmulSize(M=1536, N=2048, K=12288, quantized_lhs=1), so the tiles are all rather large.

Maybe we could add some more unit tests to Triton to avoid such crashes?

gflegar commented 1 year ago

Got pulled into something else last week, so couldn't look at this closer, but should be able to focus on fixing this now.

I used the same shapes as you did when I got this error.

Looking at your example, one difference is that in your case the low-precision matrix is the RHS, and in mine the LHS.

You're also loading the LHS matrix in row-major and the RHS matrix in column-major order, if I understand correctly what the T operator does.

I tried my s8[512x512] x f16[512x512] matmul with various combinations of no-transpose (N) vs transpose (T):

I assume that the bug happens in all cases where the low-precision matrix is not in column-major order.

gflegar commented 1 year ago

I ran the matmul tutorial benchmark (modification from gflegar/triton@a412dcc) on A100-40GB with all the different combinations of low/high precision and row/column major layouts to see what performance we get, and where there are problems.

There are a few configurations where we crash with the new version of Triton, and a few where we get wrong results. I also could not reproduce the claimed performance improvement with int8; I could only get close to cuBLAS with float8e5 (the FNST configuration that @ptillet used in their example gets 205 TF/s vs. 228 TF/s for cuBLAS, but only 138 TF/s with int8).

@ptillet do we somewhere in codegen rely on having the specific float8e5 -> float16 conversion, which could explain why we're not seeing the same benefit on int8?

| Config | M=N=K | a9c87245b cuBLAS S=int8 (TF/s) | a9c87245b Triton S=int8 (TF/s) | 35b27e1ee Triton S=int8 (TF/s) | 35b27e1ee Triton S=float8e5 (TF/s) |
|---|---|---|---|---|---|
| FNSN | 4096 | 235 | 157 | 156 | 163 |
| FNST | 4096 | 228 | 161 | 138 | 205 |
| FTSN | 4096 | 236 | 158 | 144 | 155 |
| FTST | 4096 | 235 | 157 | Wrong result | Wrong result |
| SNFN | 4096 | 235 | 163 | Illegal access* | Illegal access* |
| SNFT | 4096 | 228 | 176 | Wrong result | Wrong result |
| STFN | 4096 | 236 | 162 | 153 | 163 |
| STFT | 4096 | 235 | 178 | 171 | 190 |
| FNFN | 4096 | 268 | 264 | 243 | N/A |
| FNFT | 4096 | 251 | 257 | 245 | N/A |
| FTFN | 4096 | 269 | 269 | 242 | N/A |
| FTFT | 4096 | 263 | 263 | 246 | N/A |
| SNSN | 4096 | 207 | 142 | Assertion** | Assertion** |
| SNST | 4096 | 202 | 138 | 105 | 241 |
| STSN | 4096 | 208 | 139 | 133 | 164 |
| STST | 4096 | 207 | 144 | Assertion** | Assertion** |

Legend for "Config" column:

The cuBLAS column first runs separate kernels to convert S -> F where needed, and then runs an F*F* kernel. Both are included in the timing.

* RuntimeError: CUDA error: an illegal memory access was encountered

** python3: /home/gflegar/triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp:209: mlir::LogicalResult {anonymous}::Prefetcher::initialize(): Assertion 'aKWidth == bKWidth' failed.
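The legend for the "Config" column did not survive. Judging only from how the names are used elsewhere in the thread (F vs. S for the 16-bit vs. small type, N vs. T for non-transposed vs. transposed, LHS letters first), a hypothetical decoder plus a "fraction of cuBLAS" comparison over a few rows copied from the table might look like this. The naming scheme is an assumption, not the author's stated legend.

```python
# Assumed scheme: <LHS type><LHS layout><RHS type><RHS layout>, e.g. "FNST".
TYPES = {"F": "float16", "S": "small (int8 or float8e5)"}
LAYOUTS = {"N": "non-transposed", "T": "transposed"}


def decode_config(name):
    """Decode a 4-letter config name under the assumed scheme above."""
    if len(name) != 4:
        raise ValueError(f"expected 4 letters, got {name!r}")
    lhs_t, lhs_l, rhs_t, rhs_l = name
    return {"lhs": (TYPES[lhs_t], LAYOUTS[lhs_l]),
            "rhs": (TYPES[rhs_t], LAYOUTS[rhs_l])}


# (cuBLAS, Triton int8 @35b27e1ee, Triton float8e5 @35b27e1ee) in TF/s, from the table.
ROWS = {"FNSN": (235, 156, 163), "FNST": (228, 138, 205), "STFT": (235, 171, 190)}


def fraction_of_cublas(row):
    """Triton throughput as a fraction of the cuBLAS column, rounded to 2 digits."""
    cublas, *triton = row
    return [round(t / cublas, 2) for t in triton]


for name, row in ROWS.items():
    print(name, decode_config(name), fraction_of_cublas(row))
```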

ptillet commented 1 year ago

Thanks for the report, it's helpful. I think the low performance you're seeing comes from the lack of an optimized int8x4 -> float16x4 code path in the LLVM codegen. This would be quite superficial and a pretty easy fix. The wrong results are worrisome, and I'll take a look.
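For illustration, one well-known bit trick for vectorized int8 -> float16 conversion avoids per-element int-to-float instructions entirely. It is shown here only as an example of what a vectorized path can look like; it is not necessarily what Triton's codegen or #1879 emits. Flipping each byte's sign bit maps x to x + 128 in [0, 255]; OR-ing that byte into the mantissa of the fp16 pattern 0x6400 (which is 1024.0) gives the half value 1024 + x + 128; one packed subtraction of 1152 recovers x. Two halves share a 32-bit register, so four int8 values become two half2 registers. The Python loop below stands in for the byte-permute instructions a GPU would use; all names are hypothetical.

```python
import struct

MAGIC_HALF = 0x6400        # fp16 bit pattern of 1024.0
MAGIC_BIAS = 1024.0 + 128  # 1024 from the exponent, 128 from the sign flip


def int8x4_to_half4(word):
    """Convert four int8 values packed little-endian in a 32-bit word."""
    biased = word ^ 0x80808080  # flip every sign bit: x -> x + 128
    out = []
    for pair in range(2):
        lo = (biased >> (16 * pair)) & 0xFF
        hi = (biased >> (16 * pair + 8)) & 0xFF
        # Each byte lands in the low mantissa bits of a half equal to 1024 + byte.
        packed = (MAGIC_HALF | lo) | ((MAGIC_HALF | hi) << 16)
        halves = struct.unpack("<2e", packed.to_bytes(4, "little"))
        # A single packed half2 subtraction would remove the bias on the GPU.
        out.extend(h - MAGIC_BIAS for h in halves)
    return out


def pack_int8x4(values):
    """Pack four Python ints in [-128, 127] into one little-endian 32-bit word."""
    return int.from_bytes(struct.pack("<4b", *values), "little")


print(int8x4_to_half4(pack_int8x4([-128, -1, 0, 127])))  # [-128.0, -1.0, 0.0, 127.0]
```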

ptillet commented 1 year ago

Yeah, I think there are some issues in the #shared -> #dotOperand<mmav2, kWidth=4> conversion code path for FP16. I will look into it. I believe that explains the wrong results for SNFT and FTST. A separate issue is that the logic that disables the new optimization for layouts other than NT doesn't seem to work well, leading to the runtime errors that you're seeing.

ptillet commented 1 year ago

I have a temporary fix that works for SNFT, SNST and FNST in a branch of mine. I still need to think more about it, but I'm confident I'll be able to push a fix tomorrow.

ptillet commented 1 year ago

https://github.com/openai/triton/pull/1695 should make mixed precision float16-float8 and float8-float16 work in all cases. This does not address int8 performance issues, which I suspect come from non-vectorized SIMD ops.

Also, note that performance will only be good for the row-col layout due to the fact that ldmatrix can't transpose b8 inputs.
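A simplified model of why a transposing load that works on 16-bit elements cannot by itself transpose 8-bit data: it moves whole 16-bit lanes, so the two bytes packed into one lane stay adjacent in the same output row instead of ending up in different rows. This pure-Python sketch ignores how ldmatrix distributes fragments across the threads of a warp; the function names are hypothetical.

```python
def transpose(m):
    """Exact element-wise transpose of a list-of-lists matrix."""
    return [list(col) for col in zip(*m)]


def bytes_as_b16_lanes(m):
    """View each row of bytes as 16-bit lanes: bytes (2k, 2k+1) form lane k."""
    return [[(row[k], row[k + 1]) for k in range(0, len(row), 2)] for row in m]


def lanes_to_bytes(lanes):
    """Flatten 16-bit lanes back into bytes, keeping byte order inside a lane."""
    return [[b for lane in row for b in lane] for row in lanes]


def transpose_via_b16(m):
    """Transpose an 8-bit matrix the way a 16-bit-lane transposing load sees it."""
    return lanes_to_bytes(transpose(bytes_as_b16_lanes(m)))


# 8 x 16 bytes, i.e. one 8 x 8 tile of 16-bit lanes.
m = [[16 * r + c for c in range(16)] for r in range(8)]
print(transpose_via_b16(m)[0][:4])  # [0, 1, 16, 17]: byte pairs from one row stay together
print(transpose(m)[0][:4])          # [0, 16, 32, 48]: what a true byte transpose gives
```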

gflegar commented 1 year ago

Thanks! I just verified that all the configs that had issues before now work on 034195346.

Though I think the fix introduced a new bug: the SNST config now fails with the same assertion that SNSN and STST hit before (those two now work), with both float8 and int8:

/triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp:208: mlir::LogicalResult (anonymous namespace)::Prefetcher::initialize(): Assertion `aKWidth == bKWidth' failed.

Maybe something to take a look at, though I don't think that combination matters a lot to us at this point.

I'll take a look at the vectorized converts for int8x4.

gflegar commented 1 year ago

@ptillet now that the casts are optimized as far as we can (it could be possible to still generate better SASS for those to produce 25% fewer instructions - we're talking to NVIDIA about that), I'm wondering if you think that the row-col restriction for using ldmatrix could be lifted here?

There are a couple of folks on our side (@chsigg, @pifon2a) in the process of understanding how Triton uses the ldmatrix instruction when there is a cast between ldmatrix and mma. Looking at the instruction, it always assumes that the elements are 16 bits wide. What tricks does Triton use to still get the correct fragments for the mma instruction, even if the original element is not 16 bits wide and we end up casting it to a type that requires a different fragment layout?

What breaks if we want to also use the .trans modifier in combination with casting?

(Also @ThomasRaoux since I've seen them working on related code in the repo.)

ptillet commented 1 year ago

Well triton can use ldmatrix for 8-bit x 8-bit, but as you said for 8-bit x 16-bit triton would use lds for the 16-bit operand :)

gflegar commented 1 year ago

Yes, that's part of my question actually: what does Triton do to be able to use ldmatrix even for 8-bit x 8-bit, since ldmatrix only has a .b16 version for a 16-bit data type?

HDCharles commented 1 year ago

Hey, I'm trying to enable weight-only gpu quantization in pytorch and we're trying to use triton kernels for the matmul.

It sounded like https://github.com/openai/triton/pull/1879 should have significantly improved the performance of the int8 -> bf16 cast, but I'm not sure whether it's not improving for my use case or whether I'm doing something wrong. I'm comparing a cuBLAS bf16 linear to a Triton kernel where the weight is cast to bf16 (and the scale is applied after the accumulation), but the Triton kernel takes ~2x as long as the matrix size increases (though for small matrices it seems comparable). There's also a Triton bf16 x bf16 kernel version meant to determine whether it's the cast that takes a long time (it seems like it is, since that one achieves close to cuBLAS perf despite the rescale and bias).
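Applying the scale after accumulation is valid when the scale is per output channel: with one scale s_j per column of W, sum_k x_ik (w_kj s_j) = s_j sum_k x_ik w_kj, so the int8 weight can be cast and fed to the dot unscaled and the accumulator rescaled once at the end. A minimal pure-Python sketch of that identity follows; the per-column scale is an assumption (scales grouped along K would not factor out this way), and the function names are hypothetical.

```python
def matmul(a, b):
    """Plain triple-loop matmul on lists of lists."""
    inner = len(b)
    return [[sum(row[k] * b[k][j] for k in range(inner)) for j in range(len(b[0]))]
            for row in a]


def weight_only_linear(x, w_int8, scale):
    """x: M x K floats, w_int8: K x N ints, scale: N per-output-channel floats.

    The int8 weight is cast to float for the matmul and the per-column scale
    is applied once to the accumulator, after the K reduction.
    """
    acc = matmul(x, [[float(v) for v in row] for row in w_int8])
    return [[a * s for a, s in zip(row, scale)] for row in acc]


def dequant_first_linear(x, w_int8, scale):
    """Reference: dequantize the weight first, then do the matmul."""
    w = [[v * s for v, s in zip(row, scale)] for row in w_int8]
    return matmul(x, w)


x = [[1.0, 2.0, -1.0], [0.5, 0.0, 3.0]]
w = [[1, -2], [3, 4], [-5, 6]]
scale = [0.5, 0.25]
print(weight_only_linear(x, w, scale))  # [[6.0, 0.0], [-7.25, 4.25]]
```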

https://gist.github.com/HDCharles/9f783e7ae3531127e8a2233760b52a65

Is that expected? I would have thought that, in a vacuum, it would be as fast or a bit faster, depending on the gains from the int8 load time vs. the loss from the cast time. Do I need to do something in particular to enable https://github.com/openai/triton/pull/1879? I'm on an A100 with CUDA 12.1.

table data

| X . W | X . W.t() | X.t() . W | X.t() . W.t() | kernel | M=K=N |
|--------|--------|--------|--------|--------------|--------|
| 0.0102 | 0.0113 | 0.0102 | 0.0113 | cublas linear | 256 |
| 0.0102 | 0.0102 | 0.0102 | 0.0102 | bfloat16 linear | 256 |
| 0.0102 | 0.0102 | 0.0123 | 0.0113 | int8 linear | 256 |
| 0.0154 | 0.0143 | 0.0154 | 0.0133 | uint4x2 linear | 256 |
| 0.0236 | 0.0236 | 0.0246 | 0.0236 | cublas linear | 1024 |
| 0.0225 | 0.0225 | 0.0225 | 0.0225 | bfloat16 linear | 1024 |
| 0.0338 | 0.0532 | 0.0338 | 0.0348 | int8 linear | 1024 |
| 0.0430 | 0.0481 | 0.0451 | 0.0492 | uint4x2 linear | 1024 |
| 0.5847 | 0.5868 | 0.5847 | 0.5816 | cublas linear | 4096 |
| 0.6113 | 0.5929 | 0.5929 | 0.6001 | bfloat16 linear | 4096 |
| 0.9713 | 1.2657 | 0.9626 | 0.9605 | int8 linear | 4096 |
| 1.0568 | 1.1950 | 1.0977 | 1.1868 | uint4x2 linear | 4096 |

sergachev commented 1 year ago

@HDCharles are you sure that you are using the fastest config in each case? They are typically quite different from case to case and unlikely to be present in the ~10 that are usually copied in all Triton GEMM examples inside @triton.autotune.
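One way to go beyond the usual ~10 copied configs is to enumerate a wider grid and prune it by an estimate of the pipelined shared-memory footprint, num_stages * (BLOCK_M * BLOCK_K * a_bytes + BLOCK_K * BLOCK_N * b_bytes). This is a sketch under assumptions: the candidate values and the 160 KiB budget are illustrative rather than exact A100 limits, and the estimate ignores any extra scratch space used by layout conversions or the epilogue. The function name is hypothetical.

```python
import itertools


def candidate_configs(a_bytes, b_bytes, smem_budget=160 * 1024):
    """Enumerate tile configs whose pipelined operand buffers fit in smem_budget bytes."""
    configs = []
    for bm, bn, bk, stages, warps in itertools.product(
            (16, 32, 64, 128, 256), (32, 64, 128, 256), (32, 64, 128),
            (2, 3, 4, 5), (2, 4, 8)):
        # One A tile and one B tile per pipeline stage.
        smem = stages * (bm * bk * a_bytes + bk * bn * b_bytes)
        if smem <= smem_budget:
            configs.append(dict(BLOCK_SIZE_M=bm, BLOCK_SIZE_N=bn, BLOCK_SIZE_K=bk,
                                GROUP_SIZE_M=8, num_stages=stages, num_warps=warps))
    return configs


# int8 LHS (1 byte) x 16-bit RHS (2 bytes).
cfgs = candidate_configs(1, 2)
print(len(cfgs))
```

Each dict could then be split into the block-size kwargs and the `num_stages`/`num_warps` arguments of `triton.Config`, as in the tutorial's `@triton.autotune` list; with hundreds of candidates, autotuning time grows accordingly, so further pruning is probably worthwhile.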

HDCharles commented 1 year ago

@sergachev

Yeah. To confirm, I edited the Triton matmul.py file slightly to do int8 -> bfloat16; since it's already set up to cast a and b, I just needed to edit it to pick the correct type. I also used a similar config generation scheme for my kernels. No real speed difference, though.

https://gist.github.com/HDCharles/d1bd21b07748ce58ee2d1f5b2d487710

Also, in ~/.triton/cache/.../int8_weight_only_linear_kernel.ptx I see the lines https://github.com/openai/triton/pull/1879/files#diff-6f55a62fcb89f319e5988dcabc8aa874f510220fd8f21dcb8b2a3d53fc8c9551R230-R239 (see https://gist.github.com/HDCharles/d50b5da7365ea7f256a00c771e7fc0ec).

So that means the changes are getting applied, right? https://gist.github.com/HDCharles/d50b5da7365ea7f256a00c771e7fc0ec

Still, the cuBLAS kernel is about 2x faster than all the Triton kernels doing int8 -> bfloat16 for larger matrices. Is that result surprising given the casting fix, or is this expected?

table of data

| X . W | X . W.t() | X.t() . W | X.t() . W.t() | model | (M, N, K) | config |
|-|-|-|-|-|-|-|
| 0.0092 | 0.0102 | 0.0113 | 0.0102 | cublas linear | (256, 256, 256) | None |
| 0.0123 | 0.0102 | 0.0123 | 0.0133 | triton matmul | (256, 256, 256) | BLOCK_M: 16, BLOCK_N: 128, BLOCK_K: 64, SPLIT_K: 1, num_warps: 4, num_stages: 4 |
| 0.0123 | 0.0092 | 0.0123 | 0.0123 | int8 linear | (256, 256, 256) | BLOCK_SIZE_M: 16, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 2, num_stages: 4 |
| 0.0113 | 0.0123 | 0.0123 | 0.0133 | uint4x2 linear | (256, 256, 256) | BLOCK_SIZE_M: 16, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 2, num_stages: 5 |
| 0.0246 | 0.0236 | 0.0236 | 0.0236 | cublas linear | (1024, 1024, 1024) | None |
| 0.0573 | 0.0389 | 0.0563 | 0.0573 | triton matmul | (1024, 1024, 1024) | BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, SPLIT_K: 1, num_warps: 4, num_stages: 4 |
| 0.0389 | 0.0348 | 0.0399 | 0.0379 | int8 linear | (1024, 1024, 1024) | BLOCK_SIZE_M: 32, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 3 |
| 0.0440 | 0.0389 | 0.0451 | 0.0410 | uint4x2 linear | (1024, 1024, 1024) | BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 32, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
| 0.5929 | 0.5827 | 0.5898 | 0.5888 | cublas linear | (4096, 4096, 4096) | None |
| 1.0793 | 0.9318 | 1.0547 | 1.0291 | triton matmul | (4096, 4096, 4096) | BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, SPLIT_K: 1, num_warps: 4, num_stages: 4 |
| 0.9871 | 0.9467 | 1.0527 | 0.9871 | int8 linear | (4096, 4096, 4096) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
| 1.0127 | 1.0086 | 1.0793 | 1.0281 | uint4x2 linear | (4096, 4096, 4096) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
| 36.4964 | 36.0643 | 36.0300 | 35.5374 | cublas linear | (16384, 16384, 16384) | None |
| 63.9580 | 57.1423 | 63.4399 | 62.9699 | triton matmul | (16384, 16384, 16384) | BLOCK_M: 128, BLOCK_N: 128, BLOCK_K: 32, SPLIT_K: 1, num_warps: 4, num_stages: 4 |
| 65.0670 | 61.5434 | 66.1012 | 62.3503 | int8 linear | (16384, 16384, 16384) | BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 128, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |
| 61.3018 | 63.1890 | 63.6826 | 62.2889 | uint4x2 linear | (16384, 16384, 16384) | BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 64, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 8, num_warps: 4, num_stages: 4 |

manishucsd commented 1 year ago

https://github.com/openai/triton/issues/1397#issuecomment-1545744014

I am seeing similar-ish behaviour on F16*F16, i.e.,

I have posted my reproducer on Slack here

jon-chuang commented 1 year ago

@manishucsd there is some related investigation here:

May I ask if you are on the latest Triton main? I observed that the transpose issues went away for my use case after updating to main.

I think Google might have been using Triton from about 1.5 months back?

Could you also post your reproducer here?