ROCm / triton

Development repository for the Triton language and compiler
MIT License
80 stars 22 forks source link

[Upstream Backend] [PyTorch UT] `error: failed to legalize operation 'tt.mulhiui' that was explicitly marked illegal` #548

Closed jataylo closed 2 months ago

jataylo commented 3 months ago

Problem Description

Environment:
- Docker image: rocm/pytorch-private:rocm_inductor_triton_upstream_migration_v1
- Triton branch: https://github.com/jataylo/triton/tree/jack-triton-inductor-migration
- PyTorch branch: https://github.com/pytorch/pytorch/tree/rocm-inductor-hip-device

PyTorch UT: inductor/test_torchinductor.py::test_dropout2_cuda

Reproducer:

import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor
from torch._inductor.triton_helpers import libdevice, math as tl_math

@triton.jit
def randint64(seed, offset, low, high):
    # Draw one random 64-bit integer per element of `offset`, reduced into
    # [low, high) (assuming high > low).
    # Only the first two of the four 32-bit outputs of tl.randint4x are used.
    # NOTE(review): this is the only call here into triton.language.random,
    # which is where the issue's stack trace places the failing 'tt.mulhiui'.
    r0, r1, r2, r3 = tl.randint4x(seed, offset)
    # Widen to 64 bits and splice: r0 supplies the low 32 bits, r1 the high 32.
    # NOTE(review): assumes randint4x yields unsigned 32-bit values so the
    # widening zero-extends and the OR does not clobber r1's bits - confirm.
    r0 = r0.to(tl.uint64)
    r1 = r1.to(tl.uint64)
    result = r0 | (r1 << 32)
    # Reduce modulo the range width using unsigned 64-bit arithmetic ...
    size = high - low
    result = result % size.to(tl.uint64)
    # ... then reinterpret as signed and offset into [low, high).
    result = result.to(tl.int64) + low
    return result

@triton.jit
def triton_fn(in_ptr0, in_ptr1, out_ptr0, ks0, load_seed_offset, load_seed_offset1):
    # Inductor-generated kernel from test_dropout2_cuda, kept as generated so
    # that compiling it reproduces the backend failure.
    # Each program handles XBLOCK rows; the value for row x0 is the XOR
    # reduction of a per-element int32 value, stored to out_ptr0 + x0.
    XBLOCK : tl.constexpr = 16
    RBLOCK : tl.constexpr = 8192
    xnumel : tl.constexpr = 13
    rnumel : tl.constexpr = 1
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    # XBLOCK (16) exceeds xnumel (13), so rows past the end are masked off.
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    # Running XOR accumulator, one int32 per (row, reduction lane).
    _tmp27 = tl.full([XBLOCK, RBLOCK], 0, tl.int32)
    # rnumel == 1, so this loop body runs exactly once with roffset == 0.
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        # Flat element index: row x0 starts at x0 * ((12 + ks0) // 13), which
        # is ceil(ks0 / 13) for non-negative ks0.
        tmp0 = r1 + (x0*((12 + ks0) // 13))
        tmp1 = ks0
        # In-bounds predicate for the flat index.
        tmp2 = tmp0 < tmp1
        # First random stream, seeded from in_ptr0[load_seed_offset].
        tmp3 = tl.load(in_ptr0 + load_seed_offset)
        tmp4 = randint64(tmp3, (tmp0).to(tl.uint32), -2147483648, 2147483648)
        tmp5 = tmp4.to(tl.int32)
        tmp6 = tl_math.abs(tmp5)
        # tmp10 = tmp6 mod 2147483647 with Python (floor) semantics: the where()
        # adds the divisor back when a non-zero remainder's sign differs from
        # the divisor's sign.
        tmp7 = tl.full([1, 1], 2147483647, tl.int32)
        tmp8 = tmp6 % tmp7
        tmp9 = tmp8 + tmp7
        tmp10 = tl.where(((tmp8 != 0) & ((tmp8 < 0) != (tmp7 < 0))), tmp9, tmp8)
        tmp11 = tl.full([1, 1], 1, tl.int32)
        tmp12 = tmp10 + tmp11
        tmp13 = tmp12.to(tl.int64)
        # Masked load of the input element; other=0.0 fills masked-off lanes.
        tmp14 = tl.load(in_ptr1 + (r1 + (x0*((12 + ks0) // 13))), rmask & tmp2 & xmask, eviction_policy='evict_first', other=0.0)
        tmp15 = tmp14.to(tl.int64)
        tmp16 = tmp13 * tmp15
        # Second random stream, seeded from in_ptr0[load_seed_offset1].
        tmp17 = tl.load(in_ptr0 + load_seed_offset1)
        tmp18 = randint64(tmp17, (tmp0).to(tl.uint32), -2147483648, 2147483648)
        tmp19 = tmp18.to(tl.int32)
        tmp20 = tl_math.abs(tmp19)
        tmp21 = tmp20.to(tl.int64)
        # Combine in int64, truncate to int32, and zero out-of-bounds elements.
        tmp22 = tmp16 + tmp21
        tmp23 = tmp22.to(tl.int32)
        tmp24 = tl.full(tmp23.shape, 0, tmp23.dtype)
        tmp25 = tl.where(tmp2, tmp23, tmp24)
        tmp26 = tl.broadcast_to(tmp25, [XBLOCK, RBLOCK])
        # XOR into the accumulator only where both row and lane masks hold.
        tmp28 = _tmp27 ^ tmp26
        _tmp27 = tl.where(rmask & xmask, tmp28, _tmp27)
    # XOR-reduce across the reduction axis and store one value per valid row.
    tmp27 = tl.xor_sum(_tmp27, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp27, xmask)

# Compile ahead of time from the AST. Per the stack trace, the failure occurs
# inside triton.compile (the AMD make_llir stage), so no kernel launch is needed.
# Signature order follows triton_fn's parameters: in_ptr0 (*i64), in_ptr1 (*i32),
# out_ptr0 (*i32), then ks0, load_seed_offset, load_seed_offset1 (all i32).
src = triton.compiler.ASTSource(fn=triton_fn, signature="*i64, *i32, *i32, i32, i32, i32")
test = triton.compile(src)

Stacktrace:

loc(callsite(callsite(callsite(callsite("/tmp/triton/python/triton/language/random.py":35:28 at "/tmp/triton/python/triton/language/random.py":61:57) at "/tmp/triton/python/triton/language/random.py":94:44) at "repro.py":10:40) at "repro.py":39:66)): error: failed to legalize operation 'tt.mulhiui' that was explicitly marked illegal
Traceback (most recent call last):
  File "repro.py", line 68, in <module>
    test = triton.compile(src)
  File "/tmp/triton/python/triton/compiler/compiler.py", line 268, in compile
    next_module = compile_ir(module, metadata)
  File "/tmp/triton/python/triton/backends/amd/compiler.py", line 223, in <lambda>
    stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, 90)
  File "/tmp/triton/python/triton/backends/amd/compiler.py", line 163, in make_llir
    pm.run(mod)
RuntimeError: PassManager::run failed

Operating System

-

CPU

-

GPU

AMD Instinct MI250X

ROCm Version

ROCm 6.0.0

ROCm Component

No response

Steps to Reproduce

No response

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response

jataylo commented 3 months ago

Here is the same reproducer with the Inductor dependency stripped out entirely:

import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

@triton.jit
def randint64(seed, offset, low, high):
    # Draw one random 64-bit integer per element of `offset`, reduced into
    # [low, high) (assuming high > low).
    # Only the first two of the four 32-bit outputs of tl.randint4x are used.
    # NOTE(review): this is the only call here into triton.language.random,
    # which is where the issue's stack trace places the failing 'tt.mulhiui'.
    r0, r1, r2, r3 = tl.randint4x(seed, offset)
    # Widen to 64 bits and splice: r0 supplies the low 32 bits, r1 the high 32.
    # NOTE(review): assumes randint4x yields unsigned 32-bit values so the
    # widening zero-extends and the OR does not clobber r1's bits - confirm.
    r0 = r0.to(tl.uint64)
    r1 = r1.to(tl.uint64)
    result = r0 | (r1 << 32)
    # Reduce modulo the range width using unsigned 64-bit arithmetic ...
    size = high - low
    result = result % size.to(tl.uint64)
    # ... then reinterpret as signed and offset into [low, high).
    result = result.to(tl.int64) + low
    return result

@triton.jit
def triton_fn(in_ptr0, in_ptr1, out_ptr0, ks0, load_seed_offset, load_seed_offset1):
    # Inductor-generated kernel from test_dropout2_cuda with the
    # torch._inductor dependency removed (tl.math.abs stands in for the
    # Inductor helper's abs); kept as generated so compiling it reproduces
    # the backend failure.
    # Each program handles XBLOCK rows; the value for row x0 is the XOR
    # reduction of a per-element int32 value, stored to out_ptr0 + x0.
    XBLOCK : tl.constexpr = 16
    RBLOCK : tl.constexpr = 8192
    xnumel : tl.constexpr = 13
    rnumel : tl.constexpr = 1
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    # XBLOCK (16) exceeds xnumel (13), so rows past the end are masked off.
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    # Running XOR accumulator, one int32 per (row, reduction lane).
    _tmp27 = tl.full([XBLOCK, RBLOCK], 0, tl.int32)
    # rnumel == 1, so this loop body runs exactly once with roffset == 0.
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r1 = rindex
        # Flat element index: row x0 starts at x0 * ((12 + ks0) // 13), which
        # is ceil(ks0 / 13) for non-negative ks0.
        tmp0 = r1 + (x0*((12 + ks0) // 13))
        tmp1 = ks0
        # In-bounds predicate for the flat index.
        tmp2 = tmp0 < tmp1
        # First random stream, seeded from in_ptr0[load_seed_offset].
        tmp3 = tl.load(in_ptr0 + load_seed_offset)
        tmp4 = randint64(tmp3, (tmp0).to(tl.uint32), -2147483648, 2147483648)
        tmp5 = tmp4.to(tl.int32)
        tmp6 = tl.math.abs(tmp5)
        # tmp10 = tmp6 mod 2147483647 with Python (floor) semantics: the where()
        # adds the divisor back when a non-zero remainder's sign differs from
        # the divisor's sign.
        tmp7 = tl.full([1, 1], 2147483647, tl.int32)
        tmp8 = tmp6 % tmp7
        tmp9 = tmp8 + tmp7
        tmp10 = tl.where(((tmp8 != 0) & ((tmp8 < 0) != (tmp7 < 0))), tmp9, tmp8)
        tmp11 = tl.full([1, 1], 1, tl.int32)
        tmp12 = tmp10 + tmp11
        tmp13 = tmp12.to(tl.int64)
        # Masked load of the input element; other=0.0 fills masked-off lanes.
        tmp14 = tl.load(in_ptr1 + (r1 + (x0*((12 + ks0) // 13))), rmask & tmp2 & xmask, eviction_policy='evict_first', other=0.0)
        tmp15 = tmp14.to(tl.int64)
        tmp16 = tmp13 * tmp15
        # Second random stream, seeded from in_ptr0[load_seed_offset1].
        tmp17 = tl.load(in_ptr0 + load_seed_offset1)
        tmp18 = randint64(tmp17, (tmp0).to(tl.uint32), -2147483648, 2147483648)
        tmp19 = tmp18.to(tl.int32)
        tmp20 = tl.math.abs(tmp19)
        tmp21 = tmp20.to(tl.int64)
        # Combine in int64, truncate to int32, and zero out-of-bounds elements.
        tmp22 = tmp16 + tmp21
        tmp23 = tmp22.to(tl.int32)
        tmp24 = tl.full(tmp23.shape, 0, tmp23.dtype)
        tmp25 = tl.where(tmp2, tmp23, tmp24)
        tmp26 = tl.broadcast_to(tmp25, [XBLOCK, RBLOCK])
        # XOR into the accumulator only where both row and lane masks hold.
        tmp28 = _tmp27 ^ tmp26
        _tmp27 = tl.where(rmask & xmask, tmp28, _tmp27)
    # XOR-reduce across the reduction axis and store one value per valid row.
    tmp27 = tl.xor_sum(_tmp27, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp27, xmask)

# Compile ahead of time from the AST. Per the issue's stack trace, the failure
# occurs inside triton.compile, so no kernel launch is needed.
# Signature order follows triton_fn's parameters: in_ptr0 (*i64), in_ptr1 (*i32),
# out_ptr0 (*i32), then ks0, load_seed_offset, load_seed_offset1 (all i32).
src = triton.compiler.ASTSource(fn=triton_fn, signature="*i64, *i32, *i32, i32, i32, i32")
test = triton.compile(src)
jataylo commented 3 months ago

Note that this is blocking us from creating reproducers for two other failing categories:

Once this is unblocked, I will create reproducers for the above.

zhanglx13 commented 3 months ago

triton_gpu.local_load

This was recently added upstream and has not yet been pulled into our fork. @jataylo Are you using upstream or the fork?

jataylo commented 3 months ago

triton_gpu.local_load

This was recently added in upstream and not pulled in our fork. @jataylo Are you using upstream or the fork?

This is using the upstream backend. I can raise these issues at openai/triton if we think that is more appropriate.

micmelesse commented 3 months ago

This branch fixes the issue: https://github.com/micmelesse/triton/tree/micmelesse/pytorch_2. We will work to upstream this.

jataylo commented 3 months ago

@micmelesse Thank you! I can confirm the UT is passing with this change. I'll keep the issue open until the upstream PR is scoped.

zhanglx13 commented 3 months ago

This upstream PR should fix the issue: https://github.com/openai/triton/pull/3563