triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

fp32 addmm returns incorrect result on A100 #5204

Status: Open. davidberard98 opened this issue 1 day ago.

davidberard98 commented 1 day ago

Describe the bug

PyTorch reference: https://github.com/pytorch/pytorch/issues/141079

repro:

import torch

import triton
import triton.language as tl

@triton.jit
def triton_mm(in_ptr0, arg_A, arg_B, out_ptr0, bias_ptr):
    GROUP_M : tl.constexpr = 8
    EVEN_K : tl.constexpr = False
    ALLOW_TF32 : tl.constexpr = False
    # ALLOW_TF32 : tl.constexpr = True
    ACC_TYPE : tl.constexpr = tl.float32
    B_PROLOGUE_CAST_TYPE : tl.constexpr = None
    BLOCK_M : tl.constexpr = 64
    BLOCK_N : tl.constexpr = 64
    BLOCK_K : tl.constexpr = 64
    A = arg_A
    B = arg_B

    M = 512
    N = 34
    K = 33
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    stride_am = 33
    stride_ak = 1
    stride_bk = 1
    stride_bn = 33

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    else:
        ram = rm % M
    if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    else:
        rbn = rn % N
    rk = tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            b = b.to(B_PROLOGUE_CAST_TYPE)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    xindex = idx_n + (34*idx_m)
    tmp0 = tl.load(in_ptr0 + (tl.broadcast_to(idx_n, acc.shape)), mask, eviction_policy='evict_last')
    # tl.store(bias_ptr + tl.broadcast_to(xindex, acc.shape), tmp0, mask)
    tmp1 = acc + tmp0
    tl.store(out_ptr0 + (tl.broadcast_to(xindex, acc.shape)), tmp1, mask)
    # tl.store(bias_ptr + tl.broadcast_to(xindex, acc.shape), tmp0, mask)

BLOCK_M : tl.constexpr = 64
BLOCK_N : tl.constexpr = 64
BLOCK_K : tl.constexpr = 64
dtype = torch.float32

A = torch.ones(512, 33, device="cuda", dtype=dtype)
B = torch.ones(33, 34, device="cuda", dtype=dtype)
bias = torch.ones(34, device="cuda", dtype=dtype)
out = torch.empty(512, 34, device="cuda", dtype=dtype)
bias_out = torch.empty(512, 34, device="cuda", dtype=dtype)

# ret = triton_mm[(512 // 64,)](bias, A, B, out, bias_out, debug=True)
ret = triton_mm[(512 // 64,)](bias, A, B, out, bias_out)

'''
from triton.tools.disasm import get_sass
sass = get_sass(ret.asm['cubin'])
with open("sass.sass", "w") as f:
    f.write(sass)
'''

expect = A @ B + bias
# breakpoint()
print((expect-out).abs().max())
assert (expect-out).abs().max().item() < 0.01
# v = ((expect-out).abs() > 0.01).to(torch.int32).tolist()
# print("\n".join([str(x) for x in v]))
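
As a sanity check on what the repro expects, the kernel's index arithmetic can be replayed in plain Python. The sketch below copies the constants from `triton_mm` and mirrors its grouped `pid_m`/`pid_n` reordering and its masked K loop. It runs on the CPU without Triton or a GPU, so it only shows what each program is supposed to compute, not what ptxas actually emits.

```python
# Plain-Python replay of triton_mm's index arithmetic (no GPU, no Triton).
# Constants are copied from the repro.
M, N, K = 512, 34, 33
BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 64
GROUP_M = 8


def tile_for_pid(pid):
    """Return (pid_m, pid_n) using the same grouped reordering as the kernel."""
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size
    return pid_m, pid_n


num_programs = 512 // 64  # the repro launches a 1-D grid of (8,)
tiles = [tile_for_pid(pid) for pid in range(num_programs)]
for pid, (pid_m, pid_n) in enumerate(tiles):
    first_row = pid_m * BLOCK_M
    last_col = min((pid_n + 1) * BLOCK_N, N) - 1  # columns >= N are masked on store
    print(f"pid={pid}: rows {first_row}-{first_row + BLOCK_M - 1}, "
          f"cols {pid_n * BLOCK_N}-{last_col}")

# The K loop: range(K, 0, -BLOCK_K) runs once (k = 33), and the mask rk < k
# keeps 33 of the 64 lanes; masked lanes load other=0.
k_iters = list(range(K, 0, -BLOCK_K))
active_lanes = [sum(1 for rk in range(BLOCK_K) if rk < k) for k in k_iters]

# With all-ones A, B and bias, every valid output element should be K + 1.
expected_value = K * 1 + 1
print("K iterations:", k_iters, "active lanes:", active_lanes,
      "expected:", expected_value)
```

For this launch each of the 8 programs owns one 64-row tile with `pid_n == 0`, and the single K iteration keeps 33 of 64 lanes, so every element of `out` should be 34 (33 products of ones plus the bias of 1). Any element of `expect - out` that is nonzero therefore points at code generation rather than at the tiling logic.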

Details:

We think this could be a ptxas bug.

For Triton 3.2, we're considering reverting https://github.com/triton-lang/triton/pull/4582 on the release branch to work around the issue (cc @bertmaher).

cc @embg who helped with debugging.

Environment details

Triton: main branch (Nov 19)
GPU: A100
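The fields above can also be collected programmatically when re-running the repro elsewhere. This is a minimal sketch: the helper name `collect_env` is hypothetical, it relies only on `torch` and `triton` exposing `__version__`, and it falls back to placeholder strings when a package or GPU is missing. PyTorch also ships `python -m torch.utils.collect_env` for a fuller report.

```python
import importlib
import importlib.util
import platform


def collect_env():
    """Gather the fields reported under 'Environment details' (hypothetical helper)."""
    env = {"python": platform.python_version()}
    for name in ("torch", "triton"):
        if importlib.util.find_spec(name) is None:
            env[name] = "not installed"
        else:
            module = importlib.import_module(name)
            env[name] = getattr(module, "__version__", "unknown")
    env["gpu"] = "unknown"
    if env["torch"] != "not installed":
        import torch
        if torch.cuda.is_available():
            env["gpu"] = torch.cuda.get_device_name(0)
            # Compute capability; an A100 reports (8, 0), i.e. sm_80.
            env["capability"] = torch.cuda.get_device_capability(0)
    return env


print(collect_env())
```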

davidberard98 commented 1 day ago

@bertmaher https://github.com/davidberard98/triton/tree/revert-4582 is the revert commit for convenience, since the revert had a conflict. I verified that it fixes this issue, but I haven't run any broader inductor/triton tests.

bertmaher commented 1 day ago

cc @ThomasRaoux since #4582 is yours, and we're a bit stumped as to whether there's an issue with that PR or whether this is just uncovering a ptxas bug.

ThomasRaoux commented 1 day ago

Huh, yes, I'm very surprised this would cause such a failure.

> Confusingly, when we set init=false here,

This does suggest a problem in ptxas, as this init is there only to help ptxas figure out the liveness of this register.

> ptxas optimization level of 0 or 1 fixes the issue.

I guess that's another hint in this direction.
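One way to narrow down what changes between optimization levels is to reassemble the same PTX at each level and diff the resulting SASS. The helpers below are a sketch under stated assumptions: they expect the CUDA toolkit's `ptxas` on `PATH`, default to `sm_80` (A100), and the function and file names are hypothetical. The PTX itself can be taken from the compiled kernel in the repro via `ret.asm['ptx']`.

```python
import shutil
import subprocess
from pathlib import Path


def ptxas_command(ptx_path, opt_level, arch="sm_80", out_path=None):
    """Build a ptxas command line that assembles ptx_path at -O<opt_level>."""
    if opt_level not in (0, 1, 2, 3):
        raise ValueError("ptxas accepts optimization levels 0-3")
    if out_path is None:
        # e.g. triton_mm.ptx -> triton_mm.O1.cubin
        out_path = Path(ptx_path).with_suffix(f".O{opt_level}.cubin")
    return ["ptxas", f"-arch={arch}", f"-O{opt_level}", str(ptx_path),
            "-o", str(out_path)]


def assemble_all_levels(ptx_path, arch="sm_80"):
    """Assemble the same PTX at -O0..-O3; returns [] if ptxas is not on PATH."""
    if shutil.which("ptxas") is None:
        return []
    outputs = []
    for level in range(4):
        cmd = ptxas_command(ptx_path, level, arch)
        subprocess.run(cmd, check=True)
        outputs.append(cmd[-1])
    return outputs


# Hypothetical usage with the repro's compiled kernel `ret`:
#   Path("triton_mm.ptx").write_text(ret.asm["ptx"])
#   cubins = assemble_all_levels("triton_mm.ptx")
```

The resulting cubins can be disassembled with `get_sass` from `triton.tools.disasm` (as in the repro, after reading each file's bytes) or with `cuobjdump -sass`. This only compares generated code; it does not relaunch the kernel with the reassembled cubins.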

Should we file a bug against ptxas? Not sure what else to do here

bertmaher commented 22 hours ago

> Should we file a bug against ptxas? Not sure what else to do here

Yeah, I think @davidberard98 has filed an nvbug for ptxas (is there a shareable link for this one?).

I'm thinking we should revert the PR for the 3.2 release since it causes a few failures for PyTorch, but leave main alone and just see what happens with ptxas. How does that sound?

bertmaher commented 12 hours ago

Reverted on the rc/3.2.x branch as https://github.com/triton-lang/triton/commit/35c6c7c6284582b3f41c71c150e11b517acf074a