
`03-matrix-multiplication.py` results in inaccurate numerics on H100 and MI300 using `float32` #4202

ravil-mobile closed this issue 3 weeks ago

ravil-mobile commented 1 month ago

I can see that the generated Triton code produces inaccurate numerics on both supported backends, i.e., cuda and hip.

The original `03-matrix-multiplication.py` tutorial compares the numerics of Triton against Torch on the GPU, using `float16`. Looking through the `test_cast_matmul.py` regression test, I observed that both the relative and absolute tolerances are set to very high values, i.e., 1e-2 and 0.3, respectively.

https://github.com/triton-lang/triton/blob/810e04611f22970aacdb49da1fcda6d0ac909216/python/test/regression/test_cast_matmul.py#L97

Such loose tolerances prevent the CI pipeline from failing.
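As a quick illustration of how loose these tolerances are (a hypothetical check, not taken from the test): `torch.allclose` accepts any deviation up to `atol + rtol * |reference|`, so with `atol=0.3` and `rtol=1e-2` two values near 100 can differ by a full 1.0 and still compare as equal.

import torch

# With atol=0.3 and rtol=1e-2, the allowed deviation for a reference value of
# 101.0 is 0.3 + 0.01 * 101.0 = 1.31, so an absolute error of 1.0 still passes.
ref = torch.tensor([101.0])
out = torch.tensor([100.0])
print(torch.allclose(out, ref, atol=0.3, rtol=1e-2))  # True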

I changed the code to compare the numerical results against two CPU implementations: 1) NumPy and 2) a naive GEMM (three nested loops). As one can see, the results obtained with Triton deviate significantly from the baseline (the three-nested-loops implementation). I checked the Triton GEMM kernel and it looked fine, so I suspect the problem is with the generated code.

Used Triton commit: c7a37a9d6

Modified GEMM Example

import torch
import triton
import triton.language as tl
import numpy as np
from numba import jit as numba_jit
from collections import OrderedDict
import argparse

parser = argparse.ArgumentParser(
  prog="GEMM validation",
  description="",
  allow_abbrev=False,
)

parser.add_argument("-t", "--type", choices=['fp16', 'fp32'],
                    default='fp32',
                    help="test data type")

parser.add_argument("--torch-cpu", action='store_false',
                    help="use CPU for Torch")

parser.add_argument("--rtol", type=float, default=1e-5,
                    help="relative tolerance")

parser.add_argument("--atol", type=float, default=1e-2,
                    help="absolute tolerance")

args = parser.parse_args()

def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"

def is_hip_mi200():
    target = triton.runtime.driver.active.get_current_target()
    return target.backend == 'hip' and target.arch == 'gfx90a'

def device_fp_type():
    if args.type == 'fp16':
        return torch.float16
    elif args.type == 'fp32':
        return torch.float32
    else:
        raise RuntimeError(f'`{args.type}` type is not supported')

def host_fp_type():
    return np.float32

def get_cuda_autotune_config():
    return [
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5,
                      num_warps=2),
        # Good config for fp8 inputs.
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4)
    ]

def get_hip_autotune_config():
    return [
        triton.Config(
            {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
            num_warps=4, num_stages=0),
        triton.Config(
            {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2},
            num_warps=8, num_stages=0),
        triton.Config(
            {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2},
            num_warps=8, num_stages=0),
        triton.Config(
            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3},
            num_warps=4, num_stages=0),
        triton.Config(
            {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8},
            num_warps=4, num_stages=0),
    ]

def get_autotune_config():
    if is_cuda():
        return get_cuda_autotune_config()
    else:
        return get_hip_autotune_config()

# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
#   - A list of `triton.Config` objects that define different configurations of
#       meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
#   - An auto-tuning *key* whose change in values will trigger evaluation of all the
#       provided configs
@triton.autotune(
    configs=get_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
        # Pointers to matrices
        a_ptr, b_ptr, c_ptr,
        # Matrix dimensions
        M, N, K,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
        # by to get the element one row down (A has M rows).
        stride_am, stride_ak,  #
        stride_bk, stride_bn,  #
        stride_cm, stride_cn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  #
        GROUP_SIZE_M: tl.constexpr,  #
        ACTIVATION: tl.constexpr  #
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See above `L2 Cache Optimizations` section for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See above `Pointer Arithmetic` section for details
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        accumulator = tl.dot(a, b, accumulator)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # You can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    if ACTIVATION == "leaky_relu":
        accumulator = leaky_relu(accumulator)
    c = accumulator.to(tl.float16)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`.
@triton.jit
def leaky_relu(x):
    return tl.where(x >= 0, x, 0.01 * x)

# %%
# We can now create a convenience wrapper function that only takes two input tensors,
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.

def matmul(a, b, activation=""):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=device_fp_type())
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        a, b, c,  #
        M, N, K,  #
        a.stride(0), a.stride(1),  #
        b.stride(0), b.stride(1),  #
        c.stride(0), c.stride(1),  #
        ACTIVATION=activation  #
    )
    return c

@numba_jit
def naive_gemm(a, b, c):
    # Reference GEMM: three nested loops over M, N, K, accumulating on the host.
    M, K = a.shape
    K, N = b.shape
    for m in range(M):
        for n in range(N):
            acc = 0.0
            for k in range(K):
                acc += a[m, k] * b[k, n]
            c[m, n] = acc
    return c

# %%
# Unit Test
# ---------
#
# We test our custom matrix multiplication against Torch (cuBLAS on the GPU, or a CPU matmul), NumPy, and the naive loop implementation.

torch.manual_seed(0)

# `--torch-cpu` uses `store_false`, so `args.torch_cpu` is True unless the flag is passed.
is_torch_gpu = args.torch_cpu

results = OrderedDict()
test_sizes = [128, 256, 512, 1024]
for S in test_sizes:
    M = N = K = S
    print(f"M=N=K : {S}")
    print(f"{args.atol=}")
    print(f"{args.rtol=}")

    a = torch.randn((M, K), device='cuda', dtype=device_fp_type())
    b = torch.randn((K, N), device='cuda', dtype=device_fp_type())
    triton_output = matmul(a, b)
    torch_output = torch.matmul(a if is_torch_gpu else a.cpu(), b if is_torch_gpu else b.cpu())
    print(f"triton_output={triton_output}")
    print(f"torch_output={torch_output}")
    # Bigger tolerance for AMD MI200 devices.
    # MI200 devices use reduced precision fp16 and bf16 and flush input and
    # output denormal values to zero. Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices

    a_numpy = a.cpu().numpy().astype(host_fp_type())
    b_numpy = b.cpu().numpy().astype(host_fp_type())
    numpy_output = np.matmul(a_numpy, b_numpy)
    numpy_output_tensor = torch.from_numpy(numpy_output).type(device_fp_type())
    print(f"numpy_output={numpy_output_tensor}")

    c_naive = np.zeros([M, N], dtype=host_fp_type())
    naive_output = naive_gemm(a_numpy, b_numpy, c_naive)
    naive_output_tensor = torch.from_numpy(naive_output).type(device_fp_type())
    print(f"naive_output={naive_output_tensor}")

    rtol = 1e-2 if is_hip_mi200() else args.rtol
    if torch.allclose(triton_output.cpu(), torch_output.cpu(), atol=args.atol, rtol=rtol):
      print("✅ Triton and Torch match")
    else:
      print("❌ Triton and Torch differ")

    key = f'{M}x{N}x{K}'
    results[key] = OrderedDict()
    vs_triton = torch.allclose(naive_output_tensor, triton_output.cpu(), atol=args.atol, rtol=rtol)
    vs_torch = torch.allclose(naive_output_tensor, torch_output.cpu(), atol=args.atol, rtol=rtol)
    vs_numpy = torch.allclose(naive_output_tensor, numpy_output_tensor, atol=args.atol, rtol=rtol)

    results[key][f"vs. triton"] = vs_triton
    results[key][f"vs. torch {'gpu' if is_torch_gpu else 'cpu'}"] = vs_torch
    results[key][f"vs. numpy"] = vs_numpy

    print(f"vs_triton={'✅' if vs_triton else '❌'}")
    print(f"vs_torch_{'gpu' if is_torch_gpu else 'cpu'}={'✅' if vs_torch else '❌'}")
    print(f"vs_numpy={'✅' if vs_numpy else '❌'}")
    print('-'*80)

row_names = list(results.keys())
num_rows = len(row_names)

first_key = row_names[0]
column_names = list(results[first_key].keys())

header = "| gemm config | "
sep = "| ----------- | "
for name in column_names:
  header += f" {name} |"
  sep += f" {'-'*len(name)} |"
print(header)
print(sep)

for key in results.keys():
  line = f"| {key} |"
  for item in results[key].values():
    symbol = '✅' if item else '❌'
    line += f" {symbol} |"
  print(line)

MI300 Results (fp32)

python3 ./03-gemm.py -t fp32 --torch-cpu --rtol=1e-4 --atol=1e-2
| gemm config    | vs. triton | vs. torch cpu | vs. numpy |
| -------------- | ---------- | ------------- | --------- |
| 128x128x128    |            |               |           |
| 256x256x256    |            |               |           |
| 512x512x512    |            |               |           |
| 1024x1024x1024 |            |               |           |

H100 Results (fp32)

python3 ./03-gemm.py -t fp32 --torch-cpu --rtol=1e-4 --atol=1e-2
| gemm config    | vs. triton | vs. torch cpu | vs. numpy |
| -------------- | ---------- | ------------- | --------- |
| 128x128x128    |            |               |           |
| 256x256x256    |            |               |           |
| 512x512x512    |            |               |           |
| 1024x1024x1024 |            |               |           |
ravil-mobile commented 1 month ago

Here are some intermediate results

H100

M=N=K : 1024
args.atol=0.01
args.rtol=0.0001
triton_output=tensor([[  1.2031,  37.3750, -29.4062,  ...,  43.4688,  93.7500, -15.1328],
        [-11.3516,  23.7031,  -3.3965,  ...,  39.2812, -15.3438, -79.2500],
        [ -4.6289,  30.0000, -42.2500,  ...,  -5.1094,  53.6875,  16.0000],
        ...,
        [-85.3125,   3.0156,  20.2031,  ..., -30.9844, -30.5156, -44.1875],
        [ 44.8438, -21.9688,  10.4141,  ...,  72.0000,   6.9492,  -1.6816],
        [-37.9375,  -8.3750,  41.1250,  ..., -12.8750,  38.2188, -22.5469]],
       device='cuda:0')
torch_output=tensor([[  1.2027,  37.4007, -29.4260,  ...,  43.4901,  93.8217, -15.1370],
        [-11.3525,  23.7217,  -3.4030,  ...,  39.3042, -15.3559, -79.3293],
        [ -4.6234,  30.0324, -42.3040,  ...,  -5.1261,  53.7173,  16.0252],
        ...,
        [-85.3604,   3.0140,  20.1974,  ..., -30.9894, -30.5337, -44.1995],
        [ 44.8633, -21.9839,  10.4117,  ...,  72.0516,   6.9408,  -1.6693],
        [-37.9492,  -8.3891,  41.1487,  ..., -12.8913,  38.2381, -22.5666]])
numpy_output=tensor([[  1.2026,  37.4007, -29.4260,  ...,  43.4901,  93.8217, -15.1370],
        [-11.3525,  23.7217,  -3.4030,  ...,  39.3042, -15.3559, -79.3293],
        [ -4.6234,  30.0324, -42.3040,  ...,  -5.1262,  53.7173,  16.0252],
        ...,
        [-85.3604,   3.0140,  20.1974,  ..., -30.9894, -30.5337, -44.1995],
        [ 44.8633, -21.9839,  10.4117,  ...,  72.0516,   6.9408,  -1.6693],
        [-37.9492,  -8.3891,  41.1487,  ..., -12.8913,  38.2381, -22.5666]])
naive_output=tensor([[  1.2027,  37.4007, -29.4260,  ...,  43.4901,  93.8217, -15.1370],
        [-11.3525,  23.7217,  -3.4030,  ...,  39.3042, -15.3559, -79.3293],
        [ -4.6234,  30.0324, -42.3040,  ...,  -5.1261,  53.7173,  16.0252],
        ...,
        [-85.3604,   3.0140,  20.1974,  ..., -30.9894, -30.5337, -44.1995],
        [ 44.8633, -21.9839,  10.4117,  ...,  72.0516,   6.9408,  -1.6693],
        [-37.9492,  -8.3891,  41.1487,  ..., -12.8913,  38.2381, -22.5667]])

MI300

M=N=K : 1024
args.atol=0.01
args.rtol=0.0001
triton_output=tensor([[-13.5234, -28.3125,  18.6562,  ...,  76.5000, -75.8125, -25.1406],
        [ 24.9062,   9.4609,  20.9062,  ..., -19.0469,  45.0312,   5.7617],
        [-51.4375, -98.1250, -31.6406,  ..., -17.2969,  14.1328, -29.6562],
        ...,
        [ -2.1914,   5.2891, -62.9062,  ..., -11.7344,  17.3594,  11.5156],
        [-10.4766,  12.7812, -11.8047,  ...,  70.8125, -16.6406,  -8.8594],
        [ 29.0156,  22.3438,  28.6250,  ...,   1.4102, -22.1719,   8.0547]],
       device='cuda:0')
torch_output=tensor([[-13.5221, -28.3083,  18.6528,  ...,  76.4760, -75.8272, -25.1379],
        [ 24.9082,   9.4600,  20.9104,  ..., -19.0470,  45.0211,   5.7626],
        [-51.4403, -98.1155, -31.6342,  ..., -17.2929,  14.1310, -29.6568],
        ...,
        [ -2.1921,   5.2877, -62.8914,  ..., -11.7347,  17.3666,  11.5192],
        [-10.4733,  12.7849, -11.8028,  ...,  70.8333, -16.6389,  -8.8601],
        [ 29.0169,  22.3487,  28.6278,  ...,   1.4103, -22.1648,   8.0536]])
numpy_output=tensor([[-13.5221, -28.3082,  18.6528,  ...,  76.4760, -75.8272, -25.1380],
        [ 24.9082,   9.4600,  20.9104,  ..., -19.0470,  45.0211,   5.7626],
        [-51.4403, -98.1155, -31.6342,  ..., -17.2930,  14.1310, -29.6568],
        ...,
        [ -2.1921,   5.2877, -62.8914,  ..., -11.7347,  17.3666,  11.5192],
        [-10.4733,  12.7849, -11.8028,  ...,  70.8333, -16.6389,  -8.8601],
        [ 29.0169,  22.3487,  28.6278,  ...,   1.4103, -22.1648,   8.0536]])
naive_output=tensor([[-13.5221, -28.3082,  18.6528,  ...,  76.4760, -75.8271, -25.1380],
        [ 24.9082,   9.4600,  20.9104,  ..., -19.0470,  45.0211,   5.7626],
        [-51.4403, -98.1155, -31.6342,  ..., -17.2929,  14.1310, -29.6568],
        ...,
        [ -2.1921,   5.2877, -62.8914,  ..., -11.7347,  17.3666,  11.5192],
        [-10.4733,  12.7849, -11.8028,  ...,  70.8333, -16.6389,  -8.8601],
        [ 29.0169,  22.3487,  28.6278,  ...,   1.4103, -22.1648,   8.0536]])
ravil-mobile commented 1 month ago

Using Torch-GPU

H100

| gemm config    | vs. triton | vs. torch gpu | vs. numpy |
| -------------- | ---------- | ------------- | --------- |
| 128x128x128    |            |               |           |
| 256x256x256    |            |               |           |
| 512x512x512    |            |               |           |
| 1024x1024x1024 |            |               |           |

MI300

| gemm config    | vs. triton | vs. torch gpu | vs. numpy |
| -------------- | ---------- | ------------- | --------- |
| 128x128x128    |            |               |           |
| 256x256x256    |            |               |           |
| 512x512x512    |            |               |           |
| 1024x1024x1024 |            |               |           |
ravil-mobile commented 1 month ago

Disabling the L2 cache optimizations (the grouped program-id ordering) didn't solve the problem either:

    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See above `L2 Cache Optimizations` section for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    #num_pid_in_group = GROUP_SIZE_M * num_pid_n
    #group_id = pid // num_pid_in_group
    #first_pid_m = group_id * GROUP_SIZE_M
    #group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    #pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    #pid_n = (pid % num_pid_in_group) // group_size_m
    pid_m = pid // num_pid_m
    pid_n = pid % num_pid_m
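Note: since M = N = K for all test sizes, `num_pid_m == num_pid_n` here, so this simplified mapping is equivalent to the general row-major form `pid_m = pid // num_pid_n; pid_n = pid % num_pid_n`.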
IsaacMirandaCamargos commented 1 month ago

I wrote my own matrix multiplication and ran it on the CPU using triton_viz, and I ran into a precision problem with float16 but not with float32:

import torch
import triton
import triton.language as tl
import triton_viz

@triton.jit
def mat_mult_kernel(A, B, C, M, N, K, BLOCK_SIZE: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_am = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_bn = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_k = tl.arange(0, BLOCK_SIZE)

    #print(offs_am)

    acc = tl.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=tl.float32)

    for k in range(0, K, BLOCK_SIZE):
        a_ptr = offs_am[:, None] * K + (offs_k[None, :] + k)
        a_mask = (offs_am[:, None] < M) & ((offs_k[None, :] + k) < K)
        a = tl.load(A + a_ptr, mask=a_mask, other=0.0)

        b_ptr = (offs_k[:, None] + k) * N + offs_bn[None, :]
        b_mask = ((offs_k[:, None] + k) < K) & (offs_bn[None, :] < N)
        b = tl.load(B + b_ptr, mask=b_mask, other=0.0)
        acc += tl.dot(a, b)

    mask_m = offs_am[:, None] < M
    mask_n = offs_bn[None, :] < N
    tl.store(C + offs_am[:, None] * N + offs_bn[None, :], acc, mask=mask_m & mask_n)

def mat_mult():
    M = 4
    N = 4
    K = 4
    BLOCK_SIZE = 16

    # Initialize the matrices (kept on the CPU here for triton_viz)
    A = torch.randn((M, K), dtype=torch.float32)
    B = torch.randn((K, N), dtype=torch.float32)
    C = torch.zeros((M, N), dtype=torch.float32)
    C_real = A @ B

    # Define the launch grid
    grid = (triton.cdiv(M, BLOCK_SIZE), triton.cdiv(N, BLOCK_SIZE))
    triton_viz.trace(mat_mult_kernel)[grid](A, B, C, M, N, K, BLOCK_SIZE, num_warps=4)
    triton_viz.launch()

    print(C)
    print(C_real)
    print(f'Is the same result? {"✅" if torch.allclose(C, C_real, atol=1e-2) else "❌"}')

mat_mult()
ravil-mobile commented 3 weeks ago

I spotted the problem. For fp32, it is the following line:

  c = accumulator.to(tl.float16)
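
A minimal sketch of a possible fix, assuming the cast should follow the output tensor's dtype rather than always fp16 (`c_ptr.dtype.element_ty` is the element type of the tensor `c_ptr` points to):

    # Sketch: cast to the output element type instead of hard-coding fp16,
    # so fp32 outputs are not rounded through fp16.
    c = accumulator.to(c_ptr.dtype.element_ty)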