triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

breaking change to constexpr in triton 3.0.0 #4321

Open iclementine opened 1 month ago

iclementine commented 1 month ago

In Triton 2.2.0 through 2.3.x, the following code works fine. Assigning math.log2(math.e) to a tl.constexpr inside a jitted function works, as expected. After all, the right-hand side of an assignment to a tl.constexpr is evaluated at compile time, so it should not be subject to the restrictions Triton places on kernel code.

import math
import torch
import triton
from triton import language as tl

# well, exp2 may come from various places at various versions of triton
try:
    from triton.language.extra.cuda.libdevice import exp2
except ImportError:
    try:
        from triton.language.math import exp2
    except ImportError:
        from triton.language.libdevice import exp2

@triton.jit
def sigmoid(in_ptr, out_ptr, n, TILE_SIZE: tl.constexpr):
    tids = tl.program_id(0) * TILE_SIZE + tl.arange(0, TILE_SIZE)
    mask = tids < n
    x = tl.load(in_ptr + tids, mask=mask)
    log2e: tl.constexpr = math.log2(math.e)
    # log2e: tl.constexpr = 1.4426950408889634
    out = 1 / (1 + exp2(-x.to(tl.float32) * log2e))
    tl.store(out_ptr + tids, out, mask=mask)

x = torch.randn(1000, device="cuda")
y = torch.empty_like(x)
tile_size = 512
grid = triton.cdiv(x.numel(), tile_size), 1, 1
sigmoid[grid](x, y, x.numel(), tile_size)
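The kernel calls exp2 rather than exp because of the identity e^(-x) = 2^(-x * log2(e)). That identity is why it needs log2e at all. The plain-Python sketch below needs neither Triton nor a GPU, and the helper names in it are illustrative only. It checks the identity for the sigmoid. It also shows that math.log2(math.e) is an ordinary host-side float, the same value the commented-out literal in the kernel hard-codes.

```python
import math

# The same expression the kernel assigns to its tl.constexpr, evaluated here
# by ordinary CPython on the host (no Triton involved).
LOG2E = math.log2(math.e)


def sigmoid_exp2(x: float) -> float:
    """Sigmoid computed the way the kernel does it: exp(-x) == 2 ** (-x * log2(e))."""
    return 1.0 / (1.0 + 2.0 ** (-x * LOG2E))


def sigmoid_exp(x: float) -> float:
    """Reference sigmoid using math.exp directly."""
    return 1.0 / (1.0 + math.exp(-x))


if __name__ == "__main__":
    # Compare this with the commented-out literal in the kernel.
    print(LOG2E)
    # Both formulations should agree up to floating-point rounding.
    for v in (-5.0, -1.0, 0.0, 1.0, 5.0):
        print(v, sigmoid_exp2(v), sigmoid_exp(v))
```

Because the constant is just a Python float, hard-coding the literal is an equivalent workaround. The commented-out line in the kernel already does exactly that.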

But in Triton 3.0.0, I get an error:

    125     return module.startswith(TRITON_MODULE)
    127 func = self.visit(node.func)
--> 128 assert func is None or is_triton_builtin(func) or isinstance(
    129     func, JITFunction
    130 ), f'Function "{func.__name__}" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this'
    132 # Traverse arguments as well as node.func so we can find JITFunctions
    133 # passed to tl.reduce or tl.associative_scan as the combine_fn
    134 for obj in itertools.chain(
    135     (func, ),
    136         map(self.visit, node.args),
    137     (self.visit(kw.value) for kw in node.keywords),
    138 ):

AssertionError: Function "log2" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this

I think Triton should have a clear language specification. Otherwise, changes like this will keep breaking backward compatibility.

Jokeren commented 1 month ago

I think this problem has been fixed on triton's main branch.

iclementine commented 1 month ago

I will test it. Thank you~