Open wyc1997 opened 3 hours ago
Can you provide me with an end-to-end example?
import torch
import triton
import triton.language as tl


def example(x):
    b, l = x.shape
    grid = (b, l)
    with torch.cuda.device(x.device.index):
        _example_kernel[grid](
            x,
            x.stride(0), x.stride(1),
        )
    return x


@triton.jit
def _example_kernel(
    X,
    stride_x_b, stride_x_l,
):
    pid_b = tl.program_id(0)
    pid_l = tl.program_id(1)
    x_ptr = X + pid_b * stride_x_b + pid_l * stride_x_l
    tl.store(x_ptr, 1, mask=(0==1))


if __name__ == "__main__":
    a = -torch.ones((10, 10), device='cuda')
    a = example(a)
    print(a)
I am running with triton==3.0.0, python==3.10.0, torch=='2.3.1+cu121'.
The above code produces different results when run with TRITON_INTERPRET=1 and with TRITON_INTERPRET=0.
With TRITON_INTERPRET=1, the values of a become all 1s; with TRITON_INTERPRET=0, the values of a remain -1.
Please check whether my PR fixes the problem. I will add tests before merging.
tl.store(store_ptr, 1, mask=(0==1))
store_ptr here is a single-element pointer. When running with TRITON_INTERPRET=1, the line above seems to ignore the mask and always stores, even though the mask is false. When running normally with TRITON_INTERPRET=0, the behavior is correct and nothing gets stored. There seems to be a bug in the interpreter. Any idea what could be wrong?
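To pin down the behavior being expected here, the following is a minimal plain-Python sketch of masked-store semantics. It is a reference model only, not Triton's implementation and not the interpreter's code: `masked_store` is a hypothetical helper, memory is modeled as a flat Python list, and a scalar value or mask is assumed to broadcast over all offsets.

```python
def masked_store(memory, offsets, values, mask):
    """Write values[i] to memory[offsets[i]] only where mask[i] is true.

    Hypothetical reference model of a masked store; memory is a flat list.
    """
    n = len(offsets)
    # Broadcast a scalar value or scalar mask to one entry per offset.
    if not isinstance(values, (list, tuple)):
        values = [values] * n
    if not isinstance(mask, (list, tuple)):
        mask = [bool(mask)] * n
    for off, val, keep in zip(offsets, values, mask):
        if keep:  # a false mask entry must leave memory untouched
            memory[off] = val
    return memory


buf = [-1.0] * 4

# Single-element store with an always-false mask: nothing is written.
masked_store(buf, [2], 1.0, mask=(0 == 1))
print(buf)  # [-1.0, -1.0, -1.0, -1.0]

# The same store with a true mask writes exactly one element.
masked_store(buf, [2], 1.0, mask=True)
print(buf)  # [-1.0, -1.0, 1.0, -1.0]
```

Under this model, the report above amounts to the TRITON_INTERPRET=1 run behaving like the second call even though the mask matches the first.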