triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Should @core.extern be part of the libdevice interface? #4509

Open int3 opened 1 month ago

int3 commented 1 month ago

I've been thinking about how to support some libdevice operations that don't map cleanly to existing vector math libraries on the CPU. It seems to me that things like isnan could be implemented reasonably cleanly in Triton itself. E.g. for the fp32 case:

@jit
def isnan(arg0):
    # NaN: all exponent bits set and a nonzero mantissa, i.e. |bits| > bits(+inf)
    return (arg0.to(core.dtype("uint32"), bitcast=True) & 0x7FFFFFFF) > 0x7F800000

However, libdevice.isnan as defined in the interface file triton/language/extra/libdevice.py is not marked as @jit but rather as @core.extern. So we can't use the above syntax; instead we need to explicitly specify the _builder argument:

@core.extern
def isnan(arg0, _builder):
    # every tensor method (including the operators) needs _builder passed explicitly
    return (arg0.to(core.dtype("uint32"), bitcast=True, _builder=_builder).__and__(0x7FFFFFFF, _builder=_builder)).__gt__(0x7F800000, _builder=_builder)

which is pretty ugly.
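A Triton-free sanity check of the fp32 bit test, using only Python's `struct` module (the helper names `f32_bits`, `isnan_f32` and `exponent_all_ones` are just for illustration). In IEEE 754 binary32 a NaN has all eight exponent bits set (`0xFF << 23`, i.e. 0x7F800000) and a nonzero mantissa, so checking the exponent field alone also matches ±inf, and a mask of `127 << 23` (0x3F800000) is the exponent pattern of 1.0f rather than of NaN.

```python
import math
import struct

# IEEE 754 binary32: 1 sign bit, 8 exponent bits, 23 mantissa bits.
EXP_MASK = 0xFF << 23   # 0x7F800000: all exponent bits set (also the bits of +inf)
ABS_MASK = 0x7FFFFFFF   # everything except the sign bit


def f32_bits(x: float) -> int:
    """Round x to fp32 and return its bit pattern as an unsigned 32-bit int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def isnan_f32(x: float) -> bool:
    # NaN <=> exponent all ones AND mantissa nonzero <=> |bits| > bits(+inf)
    return (f32_bits(x) & ABS_MASK) > EXP_MASK


def exponent_all_ones(x: float) -> bool:
    # Matches +/-inf as well as NaN, so on its own it is "isinf or isnan".
    return (f32_bits(x) & EXP_MASK) == EXP_MASK


if __name__ == "__main__":
    for v in (0.0, 1.0, -2.5, math.inf, -math.inf, math.nan):
        print(f"{v!r:>6} {f32_bits(v):#010x} "
              f"isnan={isnan_f32(v)} exp_all_ones={exponent_all_ones(v)}")
    # 127 << 23 (0x3F800000) is the exponent pattern of 1.0f, so a test
    # built on that mask flags 1.0 instead of NaN.
    print((f32_bits(1.0) & (127 << 23)) == (127 << 23))  # prints True
```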

So I'm wondering whether @core.extern vs @jit should be part of the libdevice interface. On the one hand, it seems like an implementation detail; on the other hand, the "calling convention" of a function is traditionally part of its interface. I think making it part of the implementation is quite doable; we would simply have to replace the invoke-time dispatch function with something that does the mapping at libdevice-module-creation time.
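A toy, Triton-free sketch of the difference between the two dispatch styles, under one reading of the proposal. Everything here (`IMPLS`, `dispatching_isnan`, `make_libdevice`, and the string-returning stand-ins for the CUDA and CPU implementations) is hypothetical and does not mirror Triton's actual libdevice code; it only illustrates why a single invoke-time wrapper pins down one calling convention, while binding names at module-creation time lets each backend keep its own.

```python
import types


# Hypothetical stand-ins for per-backend implementations (not Triton's code).
# They return strings so the sketch runs without Triton or a GPU.
def cuda_isnan(x, _builder):
    # extern-style: the builder is an explicit parameter
    return f"cuda.isnan({x}) via {_builder}"


def cpu_isnan(x):
    # jit-style: no builder in the signature
    return f"cpu.isnan({x})"


IMPLS = {"cuda": {"isnan": cuda_isnan}, "cpu": {"isnan": cpu_isnan}}


def dispatching_isnan(x, _builder, backend):
    # Invoke-time dispatch: one shared wrapper with a fixed (extern-style)
    # signature forwards to the backend on every call, so every backend
    # implementation is forced to accept that same signature.
    return IMPLS[backend]["isnan"](x, _builder)


def make_libdevice(backend):
    # Creation-time mapping: names are bound once per backend, and callers get
    # the backend's own function objects with their own calling conventions.
    return types.SimpleNamespace(**IMPLS[backend])


if __name__ == "__main__":
    print(dispatching_isnan("x", "b", "cuda"))  # works: signatures agree
    try:
        dispatching_isnan("x", "b", "cpu")      # jit-style impl rejects _builder
    except TypeError as e:
        print("TypeError:", e)
    print(make_libdevice("cpu").isnan("x"))
    print(make_libdevice("cuda").isnan("x", _builder="b"))
```

In this sketch, whether a backend's `isnan` takes a builder is visible only inside that backend's table entry, not in a shared wrapper.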

Another approach would be to write a simple ast transform that inserts these _builder arguments. Declaring an extern function with an implicit builder would then look something like:

@auto_builder
@core.extern
def isnan(arg0):
    ...

This might be simpler / less invasive. Would love to hear your thoughts.
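A rough sketch of the `@auto_builder` idea as a source-to-source rewrite with Python's standard `ast` module (Python 3.9+ for `ast.unparse`), assuming the function's source is available as a string. `InsertBuilder`, `auto_builder_source` and `BUILDER_METHODS` are hypothetical names, and the sketch adds `_builder` as a keyword-only parameter defaulting to `None`. It also exposes a wrinkle: the transform cannot blindly add `_builder` to every call, since `core.dtype("uint32")` does not take one, so this sketch only rewrites operators, comparisons, and calls to an allowlist of method names.

```python
import ast

# Hypothetical allowlist of method names that take `_builder`. A real transform
# would need a reliable way to recognize builtins; names alone are a guess.
BUILDER_METHODS = {"to"}

# Operators rewritten into the dunder methods Triton tensors implement.
# Shifts are left out so constant expressions like 127 << 23 stay plain ints;
# note that `&`, `+`, etc. between two Python ints would still be rewritten.
BINOPS = {ast.BitAnd: "__and__", ast.BitOr: "__or__", ast.Add: "__add__",
          ast.Sub: "__sub__", ast.Mult: "__mul__"}
CMPOPS = {ast.Eq: "__eq__", ast.NotEq: "__ne__", ast.Lt: "__lt__",
          ast.LtE: "__le__", ast.Gt: "__gt__", ast.GtE: "__ge__"}


def _builder_kw():
    return ast.keyword(arg="_builder", value=ast.Name(id="_builder", ctx=ast.Load()))


def _method_call(obj, name, arg):
    # Build the expression: obj.name(arg, _builder=_builder)
    func = ast.Attribute(value=obj, attr=name, ctx=ast.Load())
    return ast.Call(func=func, args=[arg], keywords=[_builder_kw()])


class InsertBuilder(ast.NodeTransformer):
    def visit_FunctionDef(self, node):
        self.generic_visit(node)
        params = node.args.args + node.args.kwonlyargs
        if all(a.arg != "_builder" for a in params):
            # keyword-only with a None default, so existing defaults stay valid
            node.args.kwonlyargs.append(ast.arg(arg="_builder", annotation=None))
            node.args.kw_defaults.append(ast.Constant(value=None))
        return node

    def visit_BinOp(self, node):
        self.generic_visit(node)
        name = BINOPS.get(type(node.op))
        return _method_call(node.left, name, node.right) if name else node

    def visit_Compare(self, node):
        self.generic_visit(node)
        if len(node.ops) != 1 or type(node.ops[0]) not in CMPOPS:
            return node  # chained comparisons are left alone in this sketch
        return _method_call(node.left, CMPOPS[type(node.ops[0])], node.comparators[0])

    def visit_Call(self, node):
        self.generic_visit(node)
        if (isinstance(node.func, ast.Attribute)
                and node.func.attr in BUILDER_METHODS
                and all(k.arg != "_builder" for k in node.keywords)):
            node.keywords.append(_builder_kw())
        return node


def auto_builder_source(src: str) -> str:
    """Return `src` with `_builder` threaded through (requires Python 3.9+)."""
    tree = InsertBuilder().visit(ast.parse(src))
    return ast.unparse(ast.fix_missing_locations(tree))


if __name__ == "__main__":
    print(auto_builder_source(
        "def isnan(arg0):\n"
        "    bits = arg0.to(core.dtype('uint32'), bitcast=True)\n"
        "    return (bits & 0x7FFFFFFF) > 0x7F800000\n"
    ))
```

Deciding from method names alone which calls need the builder is fragile, so treat the allowlist as a placeholder.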

int3 commented 1 month ago

> Another approach would be to write a simple ast transform that inserts these _builder arguments

I'm actually not sure if this is correct / safe; are we guaranteed that the function will get JIT-compiled? Or do we have to actually mark it with @jit? Looking at the MLIR, it looks like things get JIT-compiled regardless of whether we add @jit, but I'm not sure if it's just working by accident.