ROCm / triton

Development repository for the Triton language and compiler
MIT License
80 stars 22 forks source link

[Upstream Backend] [PyTorch UT]: `error: failed to legalize operation 'triton_gpu.local_load' that was explicitly marked illegal` #553

Closed jataylo closed 2 months ago

jataylo commented 3 months ago

Problem Description

Environment:
- Docker image: rocm/pytorch-private:rocm_inductor_triton_upstream_migration_v1
- Triton branch: https://github.com/jataylo/triton/tree/jack-triton-inductor-migration
- PyTorch branch: https://github.com/pytorch/pytorch/tree/rocm-inductor-hip-device

Unit test: `TORCHINDUCTOR_COMPILE_THREADS=1 python test/inductor/test_benchmark_fusion.py --verbose -k "test_avoid_register_spilling_cuda"`

Reproducer:

import torch
import functools
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

@triton.jit
def triton_mm(arg_A, arg_B, out_ptr0):
    """Tiled matmul kernel with every size and stride hard-coded.

    Each program instance accumulates one BLOCK_M x BLOCK_N tile of the
    product of the matrices behind ``arg_A`` and ``arg_B`` in ``ACC_TYPE``
    and stores it through ``out_ptr0``. M, N and K are fixed at 2048, so
    the kernel takes no shape or stride arguments.

    Args:
        arg_A: base pointer of the left operand; element (m, k) is read
            at offset ``m * 2048 + k``.
        arg_B: base pointer of the right operand; element (k, n) is read
            at offset ``k + n * 2048``.
        out_ptr0: base pointer of the output; element (m, n) is written
            at offset ``n + 2048 * m``.
    """
    # Compile-time configuration, baked in as constexpr locals.
    GROUP_M : tl.constexpr = 8
    EVEN_K : tl.constexpr = True
    ALLOW_TF32 : tl.constexpr = False
    ACC_TYPE : tl.constexpr = tl.float32
    B_PROLOGUE_CAST_TYPE : tl.constexpr = None
    BLOCK_M : tl.constexpr = 64
    BLOCK_N : tl.constexpr = 64
    BLOCK_K : tl.constexpr = 32
    matrix_instr_nonkdim : tl.constexpr = 0
    A = arg_A
    B = arg_B

    # Problem size is fixed rather than passed in.
    M = 2048
    N = 2048
    K = 2048
    if M * N == 0:
        # early exit due to zero-size input(s)
        return
    # Element strides: A steps by 2048 along m and 1 along k, while B steps
    # by 1 along k and 2048 along n.
    stride_am = 2048
    stride_ak = 1
    stride_bk = 1
    stride_bn = 2048

    # based on triton.ops.matmul
    pid = tl.program_id(0)
    # Number of tiles along each output dimension (ceiling division).
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    # Programs are grouped into bands of up to GROUP_M tile rows; the last
    # band is clamped by the min() so it never runs past grid_m.
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    # Row/column indices covered by this program's output tile.
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Indices are wrapped modulo M / N before being used for loads, and
    # annotated with multiple_of / max_contiguous compiler hints.
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    # Pointer blocks for the first K-slice: A is (BLOCK_M, BLOCK_K) and
    # B is (BLOCK_K, BLOCK_N).
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    # k counts down from K in steps of BLOCK_K; it is only consulted by the
    # masked loads in the non-EVEN_K branch.
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            # EVEN_K is True in this configuration, so loads are unmasked.
            # NOTE(review): the reported loc test.py:55:24 appears to point
            # at this load if the reproducer starts at line 1 - confirm.
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        if B_PROLOGUE_CAST_TYPE is not None:
            # Skipped here because B_PROLOGUE_CAST_TYPE is None.
            b = b.to(B_PROLOGUE_CAST_TYPE)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        # Advance both pointer blocks to the next K-slice.
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    # Guard the store against indices outside the M x N output.
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    # Flat output offset: column index plus 2048 times the row index.
    xindex = idx_n + (2048*idx_m)
    tl.store(out_ptr0 + (tl.broadcast_to(xindex, mask.shape)), acc, mask)

# Driver: build float16 operands on the GPU and benchmark a single launch
# configuration of the kernel.
matrix_shape = [2048, 2048]
half = torch.float16

# Random inputs are drawn first and in the same order, then the zeroed output.
tensor1 = torch.randn(matrix_shape, dtype=half).cuda()
tensor2 = torch.randn(matrix_shape, dtype=half).cuda()
output_tensor = torch.zeros(matrix_shape, dtype=half).cuda()

input_tensors = (tensor1, tensor2)
extra_args = ()
warmup_arg = dict(warmup=False)
run_method = triton_mm.run

# Launch options, kept in their original keyword order.
launch_options = {
    "grid": (1024, 1, 1),
    **warmup_arg,
    "num_stages": 1,
    "num_warps": 4,
}

# Bind every argument up front so the benchmark harness can call fn()
# with no arguments.
fn = functools.partial(
    run_method,
    *input_tensors,
    output_tensor,
    *extra_args,
    **launch_options,
)

triton.testing.do_bench(fn)

Traceback

loc("test.py":55:24): error: failed to legalize operation 'triton_gpu.local_load' that was explicitly marked illegal
Traceback (most recent call last):
  File "test.py", line 97, in <module>
    triton.testing.do_bench(fn)
  File "/root/triton/python/triton/testing.py", line 100, in do_bench
    fn()
  File "/root/triton/python/triton/runtime/jit.py", line 401, in run
    self.cache[device][key] = compile(
  File "/root/triton/python/triton/compiler/compiler.py", line 268, in compile
    next_module = compile_ir(module, metadata)
  File "/root/triton/python/triton/backends/amd/compiler.py", line 223, in <lambda>
    stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, 90)
  File "/root/triton/python/triton/backends/amd/compiler.py", line 163, in make_llir
    pm.run(mod)
RuntimeError: PassManager::run failed

Operating System

-

CPU

-

GPU

AMD Instinct MI250X

ROCm Version

ROCm 6.0.0

ROCm Component

No response

Steps to Reproduce

No response

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response

zhanglx13 commented 2 months ago

Fixed by upstream top-of-tree (ToT).