triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

tl.store with single element pointer seems to ignore mask when running on interpreter #5048

Status: Open. wyc1997 opened this issue 3 hours ago.

wyc1997 commented 3 hours ago

tl.store(store_ptr, 1, mask=(0==1)), where store_ptr is a single-element pointer, seems to ignore the mask when running with TRITON_INTERPRET=1: the value is always stored even though the mask is False. When running normally with TRITON_INTERPRET=0, the behavior is correct and nothing is stored.

There seems to be a bug in the interpreter here. Any idea what could be wrong?
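For reference, a masked store should write only where the mask is true, and a scalar mask applies to every address it is paired with, including a single-element pointer. The sketch below models that contract in plain Python. masked_store and its list-as-memory representation are hypothetical illustrations, not Triton's interpreter implementation.

```python
from typing import List, Sequence, Union


def masked_store(memory: List[float],
                 offsets: Union[int, Sequence[int]],
                 value: float,
                 mask: Union[bool, Sequence[bool]]) -> None:
    """Hypothetical model of masked-store semantics on a flat list."""
    # Treat a single-element (scalar) pointer as a block of one address.
    offs = [offsets] if isinstance(offsets, int) else list(offsets)
    # Broadcast a scalar mask so it governs every address.
    masks = [mask] * len(offs) if isinstance(mask, bool) else list(mask)
    if len(masks) != len(offs):
        raise ValueError("mask must be a scalar or match the number of offsets")
    for off, m in zip(offs, masks):
        if m:  # masked-off addresses must never be written
            memory[off] = value


memory = [-1.0] * 4
masked_store(memory, 2, 1.0, 0 == 1)  # scalar pointer, mask False: no write
masked_store(memory, [0, 1], 1.0, [True, False])  # block store, per-lane mask
print(memory)  # [1.0, -1.0, -1.0, -1.0]
```

Under this model, the call from the issue (scalar pointer, mask False) leaves memory untouched, which matches the TRITON_INTERPRET=0 result.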

Jokeren commented 2 hours ago

Can you provide me with an end-to-end example?

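wyc1997 commented 2 hours ago

import torch
import triton
import triton.language as tl


def example(x):
    b, l = x.shape
    # Launch one program per element: axis 0 over rows, axis 1 over columns.
    grid = (b, l)
    with torch.cuda.device(x.device.index):
        _example_kernel[grid](
            x,
            x.stride(0), x.stride(1),
        )
    return x


@triton.jit
def _example_kernel(
    X,
    stride_x_b, stride_x_l,
):
    pid_b = tl.program_id(0)
    pid_l = tl.program_id(1)

    # x_ptr is a single-element (scalar) pointer, not a block of pointers.
    x_ptr = X + pid_b * stride_x_b + pid_l * stride_x_l
    # The mask is always False, so this store should never write anything.
    tl.store(x_ptr, 1, mask=(0 == 1))


if __name__ == "__main__":
    a = -torch.ones((10, 10), device='cuda')

    a = example(a)
    # Expected: a is unchanged (all -1) because the store is fully masked.
    print(a)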

I am running with triton==3.0.0, python==3.10.0, and torch==2.3.1+cu121. The code above produces different results with TRITON_INTERPRET=1 and TRITON_INTERPRET=0: with TRITON_INTERPRET=1, every element of a becomes 1, while with TRITON_INTERPRET=0, every element of a stays -1.
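To make the expected output concrete, here is a hypothetical pure-Python model of what the reproducer computes, assuming a contiguous row-major tensor (strides (l, 1), which is what -torch.ones((10, 10)) produces). It is not Triton's interpreter. The honor_mask flag is an invented switch that mimics the reported TRITON_INTERPRET=1 result by ignoring the mask.

```python
from typing import List


def emulate_example_kernel(b: int, l: int,
                           honor_mask: bool = True) -> List[List[float]]:
    """Hypothetical CPU model of _example_kernel on a b x l tensor of -1s."""
    stride_b, stride_l = l, 1  # assumed contiguous row-major strides
    flat = [-1.0] * (b * l)
    for pid_b in range(b):        # grid axis 0
        for pid_l in range(l):    # grid axis 1
            offset = pid_b * stride_b + pid_l * stride_l
            mask = (0 == 1)       # always False, as in the issue
            # honor_mask=False mimics the reported bug: the mask is ignored.
            if mask or not honor_mask:
                flat[offset] = 1.0
    # Reshape the flat buffer back into b rows of length l.
    return [flat[r * l:(r + 1) * l] for r in range(b)]


print(emulate_example_kernel(2, 3))                    # all -1.0
print(emulate_example_kernel(2, 3, honor_mask=False))  # all 1.0
```

The first call corresponds to the TRITON_INTERPRET=0 output and the second to the reported TRITON_INTERPRET=1 output.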

Jokeren commented 1 hour ago

Please check whether my PR fixes the problem. I will add tests before merging.