triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Setting the environment variable TRITON_INTERPRET prevents the kernel function from receiving reserved keyword arguments. #5164

Closed 0Addicted0 closed 1 week ago

0Addicted0 commented 1 week ago

Describe the bug

Thanks to the Triton team for their excellent work 👍

When I declare reserved keyword arguments (here, num_stages) as parameters of the kernel function and set TRITON_INTERPRET=1, I get an error saying those arguments are missing. The error does not occur without this environment variable. I'm not sure whether this is a bug, a trivial problem, or simply an unsupported usage.

Minimal complete example:

import os
os.environ["TRITON_INTERPRET"] = "1"
import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({}, num_stages=2, num_warps=1),
    ],
    key=["BLOCK_SIZE"],
)
@triton.jit
def simple_kernel(
    a_ptr,
    out_ptr,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    num_stages: tl.constexpr,
):
    _sum = 0.0
    for idx in tl.range(0, n_cols, BLOCK_SIZE, num_stages=num_stages):
        off = idx + tl.arange(0, BLOCK_SIZE)
        a_ptrs = a_ptr + off
        a = tl.load(a_ptrs, mask=off < n_cols, other=0.0)
        _sum += tl.sum(a)
    tl.store(out_ptr, _sum)

N = 12
a = torch.randn((N,), device="cuda", dtype=torch.float16)
triton_out = torch.zeros((1,), device="cuda", dtype=torch.float16)
BLOCK_SIZE = 4
simple_kernel[(1, 1, 1)](
    a_ptr=a,
    out_ptr=triton_out,
    n_cols=N,
    BLOCK_SIZE=BLOCK_SIZE,
)
torch_out = torch.sum(a)
print(torch.allclose(triton_out, torch_out, atol=1e-2, rtol=1e-2))

Error message

Traceback (most recent call last):
  File "/user/test/patch.py", line 39, in <module>
    simple_kernel[(1, 1, 1)](
  File "/user/miniconda3/lib/python3.12/site-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/user/miniconda3/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 171, in run
    ret = self.fn.run(
  File "/user/miniconda3/lib/python3.12/site-packages/triton/runtime/interpreter.py", line 1108, in run
    return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs)
  File "/user/miniconda3/lib/python3.12/site-packages/triton/runtime/interpreter.py", line 1082, in __call__
    args = inspect.getcallargs(self.fn, *args_hst, **kwargs_hst)
  File "/user/miniconda3/lib/python3.12/inspect.py", line 1583, in getcallargs
    _missing_arguments(f_name, req, True, arg2value)
  File "/user/miniconda3/lib/python3.12/inspect.py", line 1512, in _missing_arguments
    raise TypeError("%s() missing %i required %s argument%s: %s" %
TypeError: simple_kernel() missing 1 required positional argument: 'num_stages'

Possible direct cause

https://github.com/triton-lang/triton/blob/1cf06c5e1982eba8f17062e1c6c3d3fa458597b2/python/triton/runtime/interpreter.py#L1081 removes all reserved keyword arguments from kwargs, even when one of them (such as num_stages) is also a declared parameter of the kernel function, so the subsequent inspect.getcallargs call fails.
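The mechanism can be reproduced without Triton. This is a minimal sketch, assuming the interpreter drops every kwarg whose name is in a reserved-keyword list before binding arguments with inspect.getcallargs; the RESERVED_KWS list and the simple_kernel stub below are illustrative, not Triton's actual code.

```python
import inspect

# Illustrative reserved-keyword list (stands in for the interpreter's RESERVED_KWS)
RESERVED_KWS = ["num_warps", "num_stages"]

# Stub with the same signature shape as the kernel in the repro above
def simple_kernel(a_ptr, out_ptr, n_cols, BLOCK_SIZE, num_stages):
    pass

kwargs = {"a_ptr": 0, "out_ptr": 1, "n_cols": 12, "BLOCK_SIZE": 4, "num_stages": 2}

# What the interpreter effectively does: drop every reserved keyword,
# even though num_stages is also a declared kernel parameter
filtered = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}

try:
    inspect.getcallargs(simple_kernel, **filtered)
except TypeError as e:
    msg = str(e)
    print(msg)  # simple_kernel() missing 1 required positional argument: 'num_stages'
```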

Possible solution

In the interpreter's __call__ path, keep a reserved keyword argument when it is also a declared kernel parameter:

        # Look up the kernel's declared parameter names
        req_args = inspect.getfullargspec(self.fn)[0]
        # Drop reserved keywords only if the kernel does not declare them itself
        kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS or k in req_args}
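A standalone sketch of this fix, using the same illustrative RESERVED_KWS list and kernel stub as above (not Triton's actual code), shows that argument binding now succeeds:

```python
import inspect

# Illustrative reserved-keyword list and kernel stub
RESERVED_KWS = ["num_warps", "num_stages"]

def simple_kernel(a_ptr, out_ptr, n_cols, BLOCK_SIZE, num_stages):
    pass

kwargs = {"a_ptr": 0, "out_ptr": 1, "n_cols": 12, "BLOCK_SIZE": 4, "num_stages": 2}

# Proposed filter: drop a reserved keyword only if the kernel
# does not declare it as one of its own parameters
req_args = inspect.getfullargspec(simple_kernel)[0]
filtered = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS or k in req_args}

# Binding now succeeds and num_stages reaches the kernel
bound = inspect.getcallargs(simple_kernel, **filtered)
print(bound["num_stages"])  # 2
```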

Environment details

Triton: 3.0.0
GPU: Tesla V100-PCIE-32GB
Python: 3.12