Closed: ravil-mobile closed this issue 3 weeks ago
Here are some intermediate results:
```
M=N=K : 1024
args.atol=0.01
args.rtol=0.0001
triton_output=tensor([[ 1.2031, 37.3750, -29.4062, ..., 43.4688, 93.7500, -15.1328],
        [-11.3516, 23.7031, -3.3965, ..., 39.2812, -15.3438, -79.2500],
        [ -4.6289, 30.0000, -42.2500, ..., -5.1094, 53.6875, 16.0000],
        ...,
        [-85.3125, 3.0156, 20.2031, ..., -30.9844, -30.5156, -44.1875],
        [ 44.8438, -21.9688, 10.4141, ..., 72.0000, 6.9492, -1.6816],
        [-37.9375, -8.3750, 41.1250, ..., -12.8750, 38.2188, -22.5469]],
       device='cuda:0')
torch_output=tensor([[ 1.2027, 37.4007, -29.4260, ..., 43.4901, 93.8217, -15.1370],
        [-11.3525, 23.7217, -3.4030, ..., 39.3042, -15.3559, -79.3293],
        [ -4.6234, 30.0324, -42.3040, ..., -5.1261, 53.7173, 16.0252],
        ...,
        [-85.3604, 3.0140, 20.1974, ..., -30.9894, -30.5337, -44.1995],
        [ 44.8633, -21.9839, 10.4117, ..., 72.0516, 6.9408, -1.6693],
        [-37.9492, -8.3891, 41.1487, ..., -12.8913, 38.2381, -22.5666]])
numpy_output=tensor([[ 1.2026, 37.4007, -29.4260, ..., 43.4901, 93.8217, -15.1370],
        [-11.3525, 23.7217, -3.4030, ..., 39.3042, -15.3559, -79.3293],
        [ -4.6234, 30.0324, -42.3040, ..., -5.1262, 53.7173, 16.0252],
        ...,
        [-85.3604, 3.0140, 20.1974, ..., -30.9894, -30.5337, -44.1995],
        [ 44.8633, -21.9839, 10.4117, ..., 72.0516, 6.9408, -1.6693],
        [-37.9492, -8.3891, 41.1487, ..., -12.8913, 38.2381, -22.5666]])
naive_output=tensor([[ 1.2027, 37.4007, -29.4260, ..., 43.4901, 93.8217, -15.1370],
        [-11.3525, 23.7217, -3.4030, ..., 39.3042, -15.3559, -79.3293],
        [ -4.6234, 30.0324, -42.3040, ..., -5.1261, 53.7173, 16.0252],
        ...,
        [-85.3604, 3.0140, 20.1974, ..., -30.9894, -30.5337, -44.1995],
        [ 44.8633, -21.9839, 10.4117, ..., 72.0516, 6.9408, -1.6693],
        [-37.9492, -8.3891, 41.1487, ..., -12.8913, 38.2381, -22.5667]])

M=N=K : 1024
args.atol=0.01
args.rtol=0.0001
triton_output=tensor([[-13.5234, -28.3125, 18.6562, ..., 76.5000, -75.8125, -25.1406],
        [ 24.9062, 9.4609, 20.9062, ..., -19.0469, 45.0312, 5.7617],
        [-51.4375, -98.1250, -31.6406, ..., -17.2969, 14.1328, -29.6562],
        ...,
        [ -2.1914, 5.2891, -62.9062, ..., -11.7344, 17.3594, 11.5156],
        [-10.4766, 12.7812, -11.8047, ..., 70.8125, -16.6406, -8.8594],
        [ 29.0156, 22.3438, 28.6250, ..., 1.4102, -22.1719, 8.0547]],
       device='cuda:0')
torch_output=tensor([[-13.5221, -28.3083, 18.6528, ..., 76.4760, -75.8272, -25.1379],
        [ 24.9082, 9.4600, 20.9104, ..., -19.0470, 45.0211, 5.7626],
        [-51.4403, -98.1155, -31.6342, ..., -17.2929, 14.1310, -29.6568],
        ...,
        [ -2.1921, 5.2877, -62.8914, ..., -11.7347, 17.3666, 11.5192],
        [-10.4733, 12.7849, -11.8028, ..., 70.8333, -16.6389, -8.8601],
        [ 29.0169, 22.3487, 28.6278, ..., 1.4103, -22.1648, 8.0536]])
numpy_output=tensor([[-13.5221, -28.3082, 18.6528, ..., 76.4760, -75.8272, -25.1380],
        [ 24.9082, 9.4600, 20.9104, ..., -19.0470, 45.0211, 5.7626],
        [-51.4403, -98.1155, -31.6342, ..., -17.2930, 14.1310, -29.6568],
        ...,
        [ -2.1921, 5.2877, -62.8914, ..., -11.7347, 17.3666, 11.5192],
        [-10.4733, 12.7849, -11.8028, ..., 70.8333, -16.6389, -8.8601],
        [ 29.0169, 22.3487, 28.6278, ..., 1.4103, -22.1648, 8.0536]])
naive_output=tensor([[-13.5221, -28.3082, 18.6528, ..., 76.4760, -75.8271, -25.1380],
        [ 24.9082, 9.4600, 20.9104, ..., -19.0470, 45.0211, 5.7626],
        [-51.4403, -98.1155, -31.6342, ..., -17.2929, 14.1310, -29.6568],
        ...,
        [ -2.1921, 5.2877, -62.8914, ..., -11.7347, 17.3666, 11.5192],
        [-10.4733, 12.7849, -11.8028, ..., 70.8333, -16.6389, -8.8601],
        [ 29.0169, 22.3487, 28.6278, ..., 1.4103, -22.1648, 8.0536]])
```
| GEMM config (MxNxK) | vs. Triton | vs. Torch (GPU) | vs. NumPy |
|---|---|---|---|
| 128x128x128 | ❌ | ✅ | ✅ |
| 256x256x256 | ❌ | ✅ | ✅ |
| 512x512x512 | ❌ | ✅ | ✅ |
| 1024x1024x1024 | ❌ | ✅ | ✅ |

| GEMM config (MxNxK) | vs. Triton | vs. Torch (GPU) | vs. NumPy |
|---|---|---|---|
| 128x128x128 | ❌ | ✅ | ✅ |
| 256x256x256 | ❌ | ✅ | ✅ |
| 512x512x512 | ❌ | ✅ | ✅ |
| 1024x1024x1024 | ❌ | ✅ | ✅ |
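For context, the pass/fail marks above presumably come from `torch.allclose`-style checks using the `atol`/`rtol` values printed in the logs. A minimal sketch of such a check (the `compare` helper and its names are hypothetical, not the actual harness):

```python
import torch

def compare(name_a, a, name_b, b, atol=0.01, rtol=0.0001):
    # Bring both results to fp32 on the CPU before comparing.
    a = torch.as_tensor(a, dtype=torch.float32).cpu()
    b = torch.as_tensor(b, dtype=torch.float32).cpu()
    ok = torch.allclose(a, b, atol=atol, rtol=rtol)
    print(f"{name_a} vs. {name_b}: {'✅' if ok else '❌'}")
    return ok
```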
Disabling the L2 cache optimizations didn't help to solve the problem either, i.e., with the block-index mapping changed to:
```python
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# The tutorial's grouped ordering (which promotes L2 data reuse,
# see its `L2 Cache Optimizations` section) is commented out here
# in favor of a plain row-major mapping.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# num_pid_in_group = GROUP_SIZE_M * num_pid_n
# group_id = pid // num_pid_in_group
# first_pid_m = group_id * GROUP_SIZE_M
# group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
# pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
# pid_n = (pid % num_pid_in_group) // group_size_m
pid_m = pid // num_pid_m  # row-major; works here because M == N, so num_pid_m == num_pid_n
pid_n = pid % num_pid_m
```
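As a sanity check, here is how this simplified mapping visits the output tiles, sketched in plain Python for a hypothetical 4x4 grid of blocks (square, matching the M = N = K configurations above):

```python
num_pid_m = 4  # hypothetical square grid of output tiles (num_pid_m == num_pid_n)
for pid in range(num_pid_m * num_pid_m):
    pid_m = pid // num_pid_m
    pid_n = pid % num_pid_m
    print(pid, (pid_m, pid_n))  # visits (0,0), (0,1), ..., (0,3), (1,0), ..., (3,3)
```

Since the tiles are now visited in plain row-major order and the mismatch persists, the grouped launch ordering itself is evidently not the source of the deviation.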
I implemented my own matrix multiplication and ran it on the CPU using triton_viz, and I ran into a precision problem with float16 but not with float32:
```python
import torch
import triton
import triton.language as tl
import triton_viz


@triton.jit
def mat_mult_kernel(A, B, C, M, N, K, BLOCK_SIZE: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_am = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_bn = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_k = tl.arange(0, BLOCK_SIZE)
    acc = tl.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32)
    for k in range(0, K, BLOCK_SIZE):
        a_ptr = offs_am[:, None] * K + (offs_k[None, :] + k)
        a_mask = (offs_am[:, None] < M) & ((offs_k[None, :] + k) < K)
        a = tl.load(A + a_ptr, mask=a_mask, other=0.0)
        b_ptr = (offs_k[:, None] + k) * N + offs_bn[None, :]
        b_mask = ((offs_k[:, None] + k) < K) & (offs_bn[None, :] < N)
        b = tl.load(B + b_ptr, mask=b_mask, other=0.0)
        acc += tl.dot(a, b)
    mask_m = offs_am[:, None] < M
    mask_n = offs_bn[None, :] < N
    tl.store(C + offs_am[:, None] * N + offs_bn[None, :], acc, mask=mask_m & mask_n)


def mat_mult():
    M = 4
    N = 4
    K = 4
    BLOCK_SIZE = 16
    # Initialize the matrices (kept on the CPU for triton_viz)
    A = torch.randn((M, K), dtype=torch.float32)
    B = torch.randn((K, N), dtype=torch.float32)
    C = torch.zeros((M, N), dtype=torch.float32)
    C_real = A @ B
    # Define the grid
    grid = (triton.cdiv(M, BLOCK_SIZE), triton.cdiv(N, BLOCK_SIZE))
    triton_viz.trace(mat_mult_kernel)[grid](A, B, C, M, N, K, BLOCK_SIZE, num_warps=4)
    triton_viz.launch()
    print(C)
    print(C_real)
    print(f'Is the same result? {"✅" if torch.allclose(C, C_real, atol=1e-2) else "❌"}')


mat_mult()
```
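The snippet above uses `float32` throughout and passes; presumably the failing `float16` run differs only in the tensor dtypes, along these lines (an assumption based on the description above, not the original code):

```python
# Hypothetical fp16 variant of the setup above: half-precision inputs and
# output, while the kernel still accumulates in fp32 and downcasts on store.
A = torch.randn((M, K), dtype=torch.float16)
B = torch.randn((K, N), dtype=torch.float16)
C = torch.zeros((M, N), dtype=torch.float16)
C_real = (A.float() @ B.float()).half()  # fp32 reference, rounded to fp16
```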
I spotted the problem. Regarding fp32, it is the following line in the tutorial kernel:

```python
c = accumulator.to(tl.float16)
```
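That downcast alone explains deviations of the observed size; a minimal sketch in plain PyTorch, independent of Triton:

```python
import torch

torch.manual_seed(0)
a = torch.randn(1024, 1024)
b = torch.randn(1024, 1024)

ref = a @ b                           # fp32 accumulation, fp32 result
cast = ref.to(torch.float16).float()  # round-trip through fp16, as the store above does
print((ref - cast).abs().max())       # a few 1e-2 at these magnitudes, which can already
                                      # exceed the atol=0.01 / rtol=1e-4 budget used above
```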
I can see that the generated Triton code results in inaccurate numerics for both supported backends, i.e., `cuda` and `hip`. The original `03-matrix-multiplication.py` compares the numerics of `triton` against `torch` on the GPU, using `float16`. I looked through the `test_cast_matmul.py` regression test and observed that both the relative and absolute tolerances are set to very high values, i.e., `1e-2` and `0.3`, respectively: https://github.com/triton-lang/triton/blob/810e04611f22970aacdb49da1fcda6d0ac909216/python/test/regression/test_cast_matmul.py#L97 This prevents the CI pipeline from failing.
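To see how loose that budget is, take two of the mismatching entries from the first dump above (a sketch; `torch.allclose` passes when `|a - b| <= atol + rtol * |b|`):

```python
import torch

triton_val = torch.tensor([93.7500])  # from triton_output above
torch_val = torch.tensor([93.8217])   # from torch_output above

# Passes under the regression test's budget (rtol=1e-2, atol=0.3) ...
print(torch.allclose(triton_val, torch_val, rtol=1e-2, atol=0.3))   # True
# ... but fails under the stricter budget used in this report.
print(torch.allclose(triton_val, torch_val, rtol=1e-4, atol=0.01))  # False
```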
I changed the code to compare numerical results using two CPU implementations, i.e., 1) `numpy` and 2) a naive (3 nested loops) GEMM implementation (sketched below). As one can see, the results obtained with Triton deviate significantly from the baseline (i.e., the 3-nested-loops implementation). I checked the Triton GEMM kernel and it seemed okayish; I suspect that the problem is with the generated code. Used Triton commit: `c7a37a9d6`
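For reference, a minimal sketch of what such a naive 3-nested-loops fp32 baseline could look like (the function name is hypothetical, not the actual code used):

```python
import numpy as np

def naive_gemm(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # Plain fp32 GEMM with explicit loops; no blocking, no vectorization.
    m, k = a.shape
    k2, n = b.shape
    assert k == k2
    c = np.zeros((m, n), dtype=np.float32)
    for i in range(m):
        for j in range(n):
            acc = np.float32(0.0)
            for p in range(k):
                acc += np.float32(a[i, p]) * np.float32(b[p, j])
            c[i, j] = acc
    return c
```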
Attachments:
- Modified GEMM Example
- MI300 Results (fp32)
- H100 Results (fp32)