ROCm / triton

Development repository for the Triton language and compiler

[Upstream backend] [PyTorch UT]: `Callback: Queue 0x7ef32c400000 aborting with error : HSA_STATUS_ERROR_EXCEPTION: An HSAIL operation resulted in a hardware exception. code: 0x1016` #559

Closed: jataylo closed this issue 5 months ago

jataylo commented 5 months ago

Problem Description

Triton branch: upstream

This change (https://github.com/openai/triton/commit/ebb065f2e54d55a1f7ef80e1e0dababed8991f67) has caused all inductor unit tests (UTs) to fail.

Triton-level reproducer:

from ctypes import c_void_p, c_long
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch import device
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor import triton_helpers
from torch._inductor.triton_helpers import libdevice, math as tl_math

from torch._dynamo.testing import rand_strided

#@triton.jit
#def _any_combine(a, b):
#    return a | b

#@triton.jit
#def triton_any(a, dim):
#    return tl.reduce(a, dim, _any_combine)

@triton.jit
def triton_fn(in_out_ptr0, in_out_ptr1, in_ptr0, out_ptr0, out_ptr1, xnumel, rnumel):
    XBLOCK: tl.constexpr = 64
    xnumel = 1
    rnumel = 64
    #RBLOCK: tl.constexpr = 64 
    #xoffset = tl.program_id(0) * XBLOCK
    #xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    #xmask = xindex < xnumel
    #rindex = tl.arange(0, RBLOCK)[None, :]
    #roffset = 0
    #rmask = rindex < rnumel
    #r0 = rindex
    #tmp0 = tl.load(in_ptr0 + (r0), rmask, other=0.0)
    #tmp1 = (tmp0 != 0)
    #tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
    #tmp4 = tl.where(rmask, tmp2, 0)
    #tmp5 = triton_any(tmp4, 1)[:, None]
    #tmp6 = libdevice.isinf(tmp0).to(tl.int1)
    #tmp7 = tl.broadcast_to(tmp6, [XBLOCK, RBLOCK])
    #tmp9 = tl.where(rmask, tmp7, 0)
    #tmp10 = triton_any(tmp9, 1)[:, None]
    #tmp11 = tmp6 == 0
    #tmp12 = tmp11.to(tl.int64)
    #tmp13 = (tmp12 != 0)
    #tmp14 = tl.broadcast_to(tmp13, [XBLOCK, RBLOCK])
    #tmp16 = tl.where(rmask, tmp14, 0)
    #tmp17 = triton_any(tmp16, 1)[:, None]
    #tmp18 = tmp11 == 0
    #tmp19 = tl.broadcast_to(tmp18, [XBLOCK, RBLOCK])
    #tmp21 = tl.where(rmask, tmp19, 0)
    #tmp22 = triton_any(tmp21, 1)[:, None]
    #tmp23 = tmp17 == 0
    #tmp24 = tmp22 == 0
    #tl.debug_barrier()
    #tl.store(in_out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp23, None)
    #tl.debug_barrier()
    #tl.store(in_out_ptr1 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp24, None)
    #tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp5, None)
    #tl.store(out_ptr1 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp10, None)

from torch import empty_strided
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda

arg0_1 = rand_strided((64, ), (1, ), device='cuda:0', dtype=torch.float32)
buf0 = empty_strided_cuda((), (), torch.bool)
buf1 = empty_strided_cuda((), (), torch.bool)
buf2 = empty_strided_cuda((), (), torch.bool)
buf3 = empty_strided_cuda((), (), torch.bool)

# Compile the kernel from its AST with an explicit signature: five pointer
# arguments (*i1 = bool, *fp32 = float32) followed by two i32 scalars.
src = triton.compiler.ASTSource(fn=triton_fn, signature="*i1, *i1, *fp32, *i1, *i1, i32, i32")
test = triton.compile(src)
# Launching a single program instance aborts with the HSA_STATUS_ERROR_EXCEPTION from the issue title.
test[(1, 1, 1)](buf2, buf3, arg0_1, buf0, buf1, 1, 64)
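
For reference, the same kernel could also be launched through the ordinary @triton.jit call path instead of triton.compiler.ASTSource/triton.compile; a minimal sketch (the grid and arguments mirror the launch above; this path is an assumption and not part of the original report):

# Sketch: equivalent launch via the regular JIT call path (assumed, not from the report).
grid = (1,)
triton_fn[grid](buf2, buf3, arg0_1, buf0, buf1, 1, 64)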

Reproducer:

Operating System

-

CPU

-

GPU

AMD Instinct MI250X

ROCm Version

ROCm 6.0.0

ROCm Component

No response

Steps to Reproduce

No response

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response

zhanglx13 commented 5 months ago

This is due to an old ROCm version (5.3). Switching to ROCm 5.7 resolved the issue.
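
A quick way to confirm which ROCm/HIP runtime an environment is actually using is to query the PyTorch build metadata; a minimal sketch, assuming a ROCm build of PyTorch (not part of the original report):

import torch

# HIP/ROCm version this PyTorch build was compiled against (None on CUDA-only builds).
print(torch.version.hip)
# Confirm the GPU backend is usable and which device is visible,
# e.g. the AMD Instinct MI250X referenced above.
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0))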