triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Inconsistency between constants as arguments and captured globals #3924

Open amjames opened 6 months ago

amjames commented 6 months ago

TLDR: After #3762 global variables which are captured by a kernel must be tl.constexpr or annotated as such. It is surprising to me that the kernel argument which has an annotation is actually an object of type constexpr when the CodeGenerator.visit is running, but the captured global is not. Either that should be fixed, or the suggestion in the error message should only recommend globals be defined as VAR = tl.constexpr(<value>).

Details

I had some code that looks like this (the actual original is from the pytorch tests):

STRING_CONSTANT_C = 'value'

@triton.jit
def kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: "tl.constexpr", CONSTANT_NAME: "tl.constexpr"):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr + offsets, mask=mask)
    if CONSTANT_NAME.value == STRING_CONSTANT_C:
        output = 2 * x
    tl.store(out_ptr + offsets, output, mask=mask)

After getting the new error about globals needing to be tl.constexpr, I tried defining STRING_CONSTANT_C globally like this:

STRING_CONSTANT_C = tl.constexpr('value')

Compilation fails with: if conditionals can only accept values of type {int, NoneType, bool}, not objects of type NotImplementedType. Digging into that a bit, I realized that this is parsing as str.__eq__ comparing a string to a tl.constexpr object. So I modified the kernel source so the conditional uses STRING_CONSTANT_C.value, which works.
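The NotImplemented result can be reproduced in plain Python without Triton. The sketch below uses a hypothetical ConstWrapper class as a stand-in for tl.constexpr (it is not Triton's implementation) and assumes only that the wrapper keeps the raw value in .value, as described above. It also assumes the comparison is made by calling the __eq__ method directly, which skips the reflected-operand fallback that the == operator would normally apply.

```python
class ConstWrapper:
    """Hypothetical stand-in for tl.constexpr: wraps a raw Python value."""

    def __init__(self, value):
        self.value = value


wrapped = ConstWrapper("value")

# Calling the dunder directly bypasses Python's reflected-operand fallback,
# so str.__eq__ simply reports that it does not know this operand type.
direct = "value".__eq__(wrapped)

# Unwrapping first turns this into an ordinary str-to-str comparison.
unwrapped = "value".__eq__(wrapped.value)
```

Here direct is the NotImplemented singleton rather than a bool, which matches the type named in the error message, while unwrapped is an ordinary bool.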

The other recommendation from the error message introduced by #3762 is to use an annotation on the captured global. Trying that out:

STRING_CONSTANT_C: tl.constexpr = 'value'

That fails with the modified conditional and works with the original source.

Proposal

Why not translate these captured variables to always be tl.constexpr instances the way that arguments with the annotation are handled?

Reproducer script: https://gist.github.com/amjames/973b378f7c0fa8c92b6c92d05d90547b

amjames commented 6 months ago

NB: Also working on fixing inductor's codegen for this situation pytorch-126195

alexbaden commented 6 months ago

I think I have a fix for this, but I am not sure the semantics of tl.constexpr require it. If you explicitly "dereference" the constexpr with .value, then this problem arises as you said, because _apply_binary_method (https://github.com/triton-lang/triton/blob/main/python/triton/compiler/code_generator.py#L541) picks up the __eq__ attribute of the unwrapped value instead of the constexpr __eq__ attribute. I added a conditional check plus a constexpr unwrap for the situation where the rhs is constexpr and the lhs is not. However, when writing tests I couldn't think of a situation where one should explicitly dereference the constexpr type - if _apply_binary_method is being called, then I think both the lhs and rhs should be constexpr by definition. The previous situation only occurs because there are globals which are being treated as constexpr but are not explicitly marked constexpr. Once they are marked constexpr, the .value "dereference" (is that the right word?) is unnecessary. So perhaps the .value is really the bug and should be removed?
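Roughly, the check-and-unwrap idea looks like the sketch below. This is a plain-Python illustration, not the code on the branch: ConstWrapper is a hypothetical stand-in for tl.constexpr, and apply_binary_method is a simplified, assumed version of the dispatcher that only looks up the method on the lhs and calls it with the rhs.

```python
class ConstWrapper:
    """Hypothetical stand-in for tl.constexpr; not Triton's class."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # Unwrap the other operand if it is wrapped, then re-wrap the result.
        other_value = other.value if isinstance(other, ConstWrapper) else other
        return ConstWrapper(self.value == other_value)


def apply_binary_method(method_name, lhs, rhs):
    """Illustrative dispatcher: look up the dunder on lhs and call it with rhs."""
    if isinstance(rhs, ConstWrapper) and not isinstance(lhs, ConstWrapper):
        # The added check: lhs is a raw value, so unwrap rhs before dispatching.
        rhs = rhs.value
    return getattr(lhs, method_name)(rhs)


raw_vs_wrapped = apply_binary_method("__eq__", "value", ConstWrapper("value"))
wrapped_vs_raw = apply_binary_method("__eq__", ConstWrapper("value"), "value")
```

Without the isinstance check, the first call would return NotImplemented. With it, the raw-vs-wrapped case yields a plain bool, while the wrapped-vs-raw case still goes through the wrapper's own __eq__.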

I have a branch https://github.com/alexbaden/triton/tree/alex/fix_globals_constexpr that implements this fix that I would be happy to PR, if indeed this is a triton bug and not an inductor bug / invalid dereference of tl.constexpr. I ran your repro script on my branch and it passed:

# Global defines:
STRING_CONSTANT_C: tl.constexpr = '...'
STRING_CONSTANT_OBJ_C = tl.constexpr('...')
# kernel_signature: 
def kernel(..., CONSTANT_NAME: tl.constexpr):
With expr if CONSTANT_NAME == STRING_CONSTANT_C: called w/ CONSTANT_NAME='...' -- > WORKS
With expr if CONSTANT_NAME == STRING_CONSTANT_OBJ_C: called w/ CONSTANT_NAME='...' -- > WORKS
With expr if CONSTANT_NAME == STRING_COSNTANT_OBJ_C: called w/ CONSTANT_NAME='...' -- > WORKS

amjames commented 6 months ago

@alexbaden that patch would fix the reproducer I wrote, but that was more of a demonstration than an exhaustive set of tests. The issue I have is really with the asymmetry between the handling of (locals, kernel arguments) and (captured globals) with respect to the annotation. The former group is intercepted at the assignment and constexpr(value) is stored in the local scope dictionary; the latter is not (the annotation is only checked when we dereference the name to see if we are 'allowed' to access the global). Consider this:

GLOBAL: tl.constexpr = 1

@triton.jit
def kernel(arg: tl.constexpr):
    local_var: tl.constexpr = 1

    if local_var == arg:
        ...
    if local_var == GLOBAL:
        ...
    if arg == local_var:
        ...
    if arg == GLOBAL:
        ...
    if GLOBAL == local_var:
        ...
    if GLOBAL == arg:
        ...

The last two clauses trigger a compile error and you need to use .value on the RHS to work around it. I think that is wrong.
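To make the proposal concrete, here is a minimal plain-Python sketch of wrapping annotated globals at lookup time so that they come back in the same form as locals and arguments. Everything in it is hypothetical: ConstWrapper stands in for tl.constexpr, and lookup_name, the scope dictionaries, and the annotations mapping are illustrative names, not Triton's internals.

```python
class ConstWrapper:
    """Hypothetical stand-in for tl.constexpr; not Triton's class."""

    def __init__(self, value):
        self.value = value


def lookup_name(name, local_scope, global_scope, global_annotations):
    """Illustrative name resolution: locals first, then captured globals."""
    if name in local_scope:
        # Locals and arguments are assumed to be wrapped at assignment time.
        return local_scope[name]
    value = global_scope[name]
    if isinstance(value, ConstWrapper):
        # Global defined as VAR = ConstWrapper(...): already in wrapped form.
        return value
    if global_annotations.get(name) is ConstWrapper:
        # Global defined as VAR: ConstWrapper = ...: wrap it on access.
        return ConstWrapper(value)
    raise NameError(f"global {name!r} is not marked as a constant")


local_scope = {"local_var": ConstWrapper(1)}
global_scope = {"GLOBAL": 1, "GLOBAL_OBJ": ConstWrapper(1), "PLAIN": 1}
global_annotations = {"GLOBAL": ConstWrapper}

resolved = lookup_name("GLOBAL", local_scope, global_scope, global_annotations)
```

With this shape, resolved is a wrapper whichever of the two spellings was used for the global, so either side of a comparison would see the same type.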

Even worse, accessing a captured global requires that it either carry the type tl.constexpr or have the annotation. However, in the latter case there is no guard on assigning to it.

GLOBAL: tl.constexpr = 1

@triton.jit(debug=True)
def kernel(arg: tl.constexpr):
    GLOBAL = 7
    tl.device_assert(arg == GLOBAL) # runtime failure