triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Runtime error when I try to compile LLVM IR using the compile function. #5214

Open Ghjk94522 opened 1 day ago

Ghjk94522 commented 1 day ago

Describe the bug

When I use python/triton/compiler/compiler.py::compile to compile LLVM IR, I get a RuntimeError:

loc("/workspace/triton/python/test/single_dot.llir":1:1): error: unexpected character
Traceback (most recent call last):
  File "/workspace/triton/python/test/ir_runner.py", line 56, in <module>
    ms = triton.testing.do_bench(fn)
  File "/workspace/triton/python/triton/testing.py", line 117, in do_bench
    fn()
  File "/workspace/triton/python/test/ir_runner.py", line 55, in <lambda>
    fn = lambda: runner.run(non_constexpr_vals)
  File "/workspace/triton/python/test/ir_runner.py", line 29, in run
    kernel = self.compile(self.src_path, target=target)
  File "/workspace/triton/python/triton/compiler/compiler.py", line 227, in compile
    src = IRSource(src, context, backend)
  File "/workspace/triton/python/triton/compiler/compiler.py", line 111, in __init__
    self.module = ir.parse_mlir_module(self.path, context)
RuntimeError: Parse MLIR file failed.

The Python script is:

import torch
import triton
from collections import defaultdict

class IRRunner:
    def __init__(self, src_path):
        from triton.compiler import CompiledKernel, compile, IRSource, make_backend

        self.CompiledKernel = CompiledKernel
        self.compile = compile
        self.IRSource = IRSource
        self.make_backend = make_backend
        self.src_path = src_path
        self.cache = defaultdict(dict)  # per-device cache of compiled kernels

    def run(self, non_constexpr_vals):
        from triton.runtime.driver import driver

        device = driver.active.get_current_device()
        stream = driver.active.get_current_stream(device)
        target = driver.active.get_current_target()

        key = self.src_path
        kernel = self.cache[device].get(key, None)

        if kernel is None:
            kernel = self.compile(self.src_path, target=target)
            self.cache[device][key] = kernel

        kernel.run(
            1, # grid0
            1, # grid1
            1, # grid2
            stream,
            kernel.function,
            kernel.packed_metadata,
            None,
            self.CompiledKernel.launch_enter_hook,
            self.CompiledKernel.launch_exit_hook,
            *non_constexpr_vals
        )

        return kernel

if __name__ == "__main__":
    path = "/workspace/triton/python/test/store_mfma.ttgir"
    runner = IRRunner(path)
    tensor_rand = torch.randn((128, 128), dtype=torch.float16, device='cuda')
    non_constexpr_vals = (tensor_rand, 128, 128, 128)
    fn = lambda: runner.run(non_constexpr_vals)
    ms = triton.testing.do_bench(fn)
    print(f'triton.testing.do_bench time:{ms*1000:.2f} us')

and the LLVM IR is:

; ModuleID = 'LLVMDialectModule'
source_filename = "LLVMDialectModule"
target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7"
target triple = "amdgcn-amd-amdhsa"

@global_smem = external local_unnamed_addr addrspace(3) global [0 x i8], align 16

; Function Attrs: mustprogress nofree nounwind willreturn
define amdgpu_kernel void @dot_kernel(ptr addrspace(1) inreg nocapture readonly %0, i32 inreg %1, ptr addrspace(1) inreg nocapture readonly %2, i32 inreg %3, ptr addrspace(1) inreg nocapture readnone %4, i32 inreg %5, ptr addrspace(1) inreg nocapture writeonly %6, i32 inreg %7, ptr addrspace(1) inreg nocapture readnone %8) local_unnamed_addr #0 !dbg !3 {
  %10 = tail call i32 @llvm.amdgcn.workitem.id.x(), !dbg !6
  ...
  ret void, !dbg !21
}

And I have checked the __init__ function of the IRSource class:

class IRSource:

    def __init__(self, path, context, backend):
        self.path = path
        path = Path(path)
        self.ext = path.suffix[1:]
        self.src = path.read_text()
        ir.load_dialects(context)
        backend.load_dialects(context)

        # We don't have an easy-to-use PTX parser that we can use, so keep that regex for now.
        # TODO - replace with a proper parser
        if self.ext == "ptx":
            match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
            self.name = match.group(1)
            signature = match.group(2)
            types = re.findall(arg_type_pattern[self.ext], signature)
            self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)}
        else:
            self.module = ir.parse_mlir_module(self.path, context)
            fn_name = self.module.get_entry_func_name()
            self.name = "@" + fn_name
            funcOp = self.module.get_function(fn_name)
            func_ty = self.module.get_function_signature(funcOp)
            self.signature = {k: ty for k, ty in enumerate(func_ty)}

Maybe ir.parse_mlir_module cannot parse LLVM IR? Or does it just fail on the AMD backend?

I have no NVIDIA GPU to test on; maybe someone can help me or tell me whether my code is wrong?
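
A minimal check that points in that direction (a sketch only: it calls Triton's internal ir bindings the same way compiler.py does, and those are not a public API):

from triton._C.libtriton import ir

# parse_mlir_module is an MLIR parser, so feeding it textual LLVM IR
# ("; ModuleID = ..." syntax) should fail on the very first character,
# matching the "unexpected character" error in the traceback above.
context = ir.context()
ir.load_dialects(context)
ir.parse_mlir_module("/workspace/triton/python/test/single_dot.llir", context)  # RuntimeError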

Environment details

Triton: main branch commit id: d5ba6ac

GPU: AMD MI100

Jokeren commented 23 hours ago

cc @zhanglx13

peterbell10 commented 23 hours ago

Does it work with an MLIR file using the LLVMIR dialect? We're parsing the input as MLIR here, so I wouldn't expect LLVM IR to parse correctly.
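
One way to produce such a file (a sketch, assuming an LLVM/MLIR build with mlir-translate on PATH; this is not a Triton API, and the output path is hypothetical) is to import the textual LLVM IR into the MLIR LLVM dialect first:

import subprocess

# mlir-translate --import-llvm reads textual LLVM IR and emits the
# equivalent module in MLIR LLVM-dialect syntax, which an MLIR parser
# can then load.
subprocess.run(
    ["mlir-translate", "--import-llvm",
     "/workspace/triton/python/test/single_dot.llir",
     "-o", "/workspace/triton/python/test/single_dot.mlir"],
    check=True,
)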

Ghjk94522 commented 14 hours ago

> Does it work with an MLIR file using the LLVMIR dialect? We're parsing the input as MLIR here, so I wouldn't expect LLVM IR to parse correctly.

I can compile ttir and ttgir with the IRSource class in compiler.py, but here I want to check whether the llir is correct or not. So how can I test it? Maybe there is no API for this (directly compiling and running an llir)?
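
If the goal is just to inspect the LLVM IR Triton emits, one sketch that sidesteps parsing .llir files entirely (assuming the usual CompiledKernel.asm dict; the kernel below is a made-up example, not the dot kernel from this issue):

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

x = torch.randn(1024, device="cuda")
y = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
# Launching a @triton.jit kernel returns the CompiledKernel; its asm dict
# holds each compilation stage, including the textual LLVM IR under "llir".
kernel = add_kernel[(4,)](x, y, out, 1024, BLOCK=256)
print(kernel.asm["llir"])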