Open amjames opened 6 months ago
NB: I am also working on fixing inductor's codegen for this situation (pytorch-126195).
I think I have a fix for this, but I am not sure the semantics of tl.constexpr require it. If you explicitly "dereference" the constexpr with .value, then this problem arises like you said, because _apply_binary_method (https://github.com/triton-lang/triton/blob/main/python/triton/compiler/code_generator.py#L541) ends up using the plain __eq__ attribute instead of the constexpr __eq__ attribute. I added a conditional check plus a constexpr unwrap for the situation where the rhs is constexpr and the lhs is not. However, when writing tests I couldn't think of a situation where one should explicitly dereference the constexpr type: if _apply_binary_method is being called, then I think both the lhs and rhs should be constexpr by definition. The previous situation only occurs because there are globals which are being treated as constexpr but are not explicitly marked constexpr. Once marked constexpr, the .value "dereference" (is that the right word?) is unnecessary. So perhaps the .value is really the bug and should be removed?
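To make the "conditional check + constexpr unwrap" idea concrete, here is a minimal pure-Python sketch. It does not use Triton's actual code: `Constexpr` and `apply_binary_method` are hypothetical stand-ins for `tl.constexpr` and `_apply_binary_method`, and the real implementation on the branch may well differ.

```python
class Constexpr:
    """Hypothetical stand-in for tl.constexpr: wraps a compile-time value."""

    def __init__(self, value):
        # Avoid nesting wrappers inside wrappers.
        self.value = value.value if isinstance(value, Constexpr) else value

    def __eq__(self, other):
        # The wrapper can compare against wrapped or raw values.
        other = other.value if isinstance(other, Constexpr) else other
        return Constexpr(self.value == other)


def apply_binary_method(method_name, lhs, rhs):
    """Toy dispatch: call the named method on lhs with rhs as the argument."""
    # Conditional unwrap: a raw lhs (e.g. a str) cannot handle a wrapped rhs,
    # so hand it the underlying value instead.
    if isinstance(rhs, Constexpr) and not isinstance(lhs, Constexpr):
        rhs = rhs.value
    return getattr(lhs, method_name)(rhs)


# Without the unwrap, str.__eq__ does not recognise the wrapper type.
raw_result = "abc".__eq__(Constexpr("abc"))
# With the unwrap, the raw lhs compares against the underlying string.
fixed_result = apply_binary_method("__eq__", "abc", Constexpr("abc"))
# A wrapped lhs is unaffected and still goes through Constexpr.__eq__.
wrapped_result = apply_binary_method("__eq__", Constexpr("abc"), "abc")
```

In this toy model `raw_result` is `NotImplemented`, which is the kind of value the `if` conditional in the kernel ends up rejecting.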
I have a branch https://github.com/alexbaden/triton/tree/alex/fix_globals_constexpr that implements this fix, which I would be happy to PR if this is indeed a triton bug and not an inductor bug / invalid dereference of tl.constexpr. I ran your repro script on my branch and it passed:
# Global defines:
STRING_CONSTANT_C: tl.constexpr = '...'
STRING_CONSTANT_OBJ_C = tl.constexpr('...')
# kernel_signature:
def kernel(..., CONSTANT_NAME: tl.constexpr):
With expr if CONSTANT_NAME == STRING_CONSTANT_C: called w/ CONSTANT_NAME='...' -- > WORKS
With expr if CONSTANT_NAME == STRING_CONSTANT_OBJ_C: called w/ CONSTANT_NAME='...' -- > WORKS
With expr if CONSTANT_NAME == STRING_COSNTANT_OBJ_C: called w/ CONSTANT_NAME='...' -- > WORKS
@alexbaden that patch would fix the reproducer I wrote, but that was more of a demonstration than an exhaustive set of tests. The issue I have is really with the asymmetry between the handling of (locals, kernel arguments) and (captured globals) with respect to the annotation. The former group is intercepted at the assignment and constexpr(value) is stored in the local scope dictionary; the latter is not (the annotation is only checked when we dereference the name to see if we are 'allowed' to access the global). Consider this:
import triton
import triton.language as tl

GLOBAL: tl.constexpr = 1

@triton.jit
def kernel(arg: tl.constexpr):
    local_var: tl.constexpr = 1
    if local_var == arg:
        ...
    if local_var == GLOBAL:
        ...
    if arg == local_var:
        ...
    if arg == GLOBAL:
        ...
    if GLOBAL == local_var:
        ...
    if GLOBAL == arg:
        ...
The last two clauses trigger a compile error and you need to use .value on the RHS to work around it. I think that is wrong.
Even worse, accessing a captured global requires that it either carry the type tl.constexpr or the annotation. However, in the latter case there is no guard on assigning to it.
GLOBAL: tl.constexpr = 1

@triton.jit(debug=True)
def kernel(arg: tl.constexpr):
    GLOBAL = 7
    tl.device_assert(local_var == GLOBAL)  # runtime failure
TLDR: After #3762, global variables which are captured by a kernel must be tl.constexpr or annotated as such. It is surprising to me that the kernel argument which has an annotation is actually an object of type constexpr when the CodeGenerator.visit is running, but the captured global is not. Either that should be fixed, or the suggestion in the error message should only recommend globals be defined as VAR = tl.constexpr(<value>).
Details
I had some code that looks like this (actual original is from pytorch tests):
After getting the new error about globals needing to be tl.constexpr, I tried defining STRING_CONSTANT_C globally like this:
Compilation fails w/ if conditionals can only accept values of type {int, NoneType, bool}, not objects of type NotImplementedType. Digging into that a bit, I realized that this is parsing as str.__eq__ comparing a string to a tl.constexpr object. So I modified the kernel source so the conditional uses STRING_CONSTANT_C.value, which works.
The other recommendation from the error message introduced by #3762 is to use an annotation on the captured global. Trying that out:
That fails with the modified conditional and works with the original source.
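The four combinations described above (two ways of defining the global, two forms of the conditional) can be modelled in plain Python. This is only a sketch under assumptions: `Constexpr`, `compare`, and `outcome` are hypothetical helpers, the string `"mode_a"` is a placeholder value, and the exception types are those of this toy model rather than the compile errors Triton reports.

```python
class Constexpr:
    """Hypothetical stand-in for tl.constexpr with no special comparison logic."""

    def __init__(self, value):
        self.value = value


# Models  NAME = tl.constexpr('...'): the global really is a wrapper object.
wrapped_global = Constexpr("mode_a")
# Models  NAME: tl.constexpr = '...': the annotation leaves a plain str behind.
annotated_global = "mode_a"
# The plain string on the left-hand side of the comparison.
kernel_value = "mode_a"


def compare(lhs, rhs):
    """Mimic a conditional that rejects NotImplemented as a condition value."""
    result = lhs.__eq__(rhs)
    if result is NotImplemented:
        raise TypeError("conditional cannot accept NotImplementedType")
    return result


def outcome(fn):
    """Return the comparison result, or the name of the exception it raised."""
    try:
        return fn()
    except (TypeError, AttributeError) as exc:
        return type(exc).__name__


results = {
    "wrapped, original": outcome(lambda: compare(kernel_value, wrapped_global)),
    "wrapped, .value": outcome(lambda: compare(kernel_value, wrapped_global.value)),
    "annotated, original": outcome(lambda: compare(kernel_value, annotated_global)),
    "annotated, .value": outcome(lambda: compare(kernel_value, annotated_global.value)),
}
```

In the model, each definition works with exactly one form of the conditional, so the kernel source has to know how the global was declared.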
Proposal
Why not translate these captured variables to always be tl.constexpr instances, the way that arguments with the annotation are handled?
Reproducer script: https://gist.github.com/amjames/973b378f7c0fa8c92b6c92d05d90547b
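As a rough illustration of the proposal, the sketch below wraps an annotated global at name-resolution time so that it behaves like an annotated kernel argument. Everything here is hypothetical: `Constexpr` stands in for `tl.constexpr`, and `resolve_global` with its two dictionary parameters is an invented helper, not the code generator's actual lookup path.

```python
class Constexpr:
    """Hypothetical stand-in for tl.constexpr: wraps a compile-time value."""

    def __init__(self, value):
        self.value = value.value if isinstance(value, Constexpr) else value

    def __eq__(self, other):
        other = other.value if isinstance(other, Constexpr) else other
        return Constexpr(self.value == other)


def resolve_global(name, module_globals, annotations):
    """Return a captured global as a Constexpr, or refuse to capture it."""
    value = module_globals[name]
    if isinstance(value, Constexpr):
        # Already defined as NAME = Constexpr(...); nothing to do.
        return value
    if annotations.get(name) is Constexpr:
        # Annotated global: wrap it, mirroring how annotated arguments behave.
        return Constexpr(value)
    raise NameError(f"global {name!r} must be constexpr to be captured")


module_globals = {"GLOBAL": 1, "OBJ": Constexpr(1), "PLAIN": 1}
annotations = {"GLOBAL": Constexpr}

arg = Constexpr(1)  # models an annotated kernel argument
g = resolve_global("GLOBAL", module_globals, annotations)

# Both operand orders now go through Constexpr.__eq__.
both_orders = ((g == arg).value, (arg == g).value)
```

With the global wrapped at lookup, the operand order of the comparison no longer matters in this model, and no `.value` is needed on either side.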