Creating a new axis after calling `tl.associative_scan` on an int16 input raises an error. If the commented-out line below is used instead of the problematic line (i.e., the new axis is added before the scan rather than after it), the code runs correctly. Also, changing the `dtype` of `index` to `torch.int32` allows both variants to run correctly.
Additionally, when running locally, `tl.device_print` does not display any output on the first launch and only shows output after the kernel function is called several times. Is there a missing operation, such as flushing a buffer, needed for `device_print`?
Triton version is 2.2.0.
Reproduction Code
import torch
import triton
import triton.language as tl


@triton.jit
def _end_combine(a, b):
    if a == -1:
        return -1
    return b


@triton.jit
def kernel(index_ptr, value_ptr, BLOCK_SIZE: tl.constexpr):
    ids = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    index = tl.load(index_ptr + ids).to(tl.int32)
    index = tl.associative_scan(index, axis=0, combine_fn=_end_combine)[:, None]  # Problematic line
    # index = tl.associative_scan(index[:, None], axis=0, combine_fn=_end_combine)
    value = tl.load(value_ptr + index, mask=index != -1, other=0.)
    tl.device_print('value: ', value)


if __name__ == '__main__':
    torch.manual_seed(0)
    index = torch.randint(128, (128,), dtype=torch.int16, device='cuda')
    index[8] = -1
    value = torch.rand(128, device='cuda')
    kernel[(1,)](index, value, 32, num_warps=1)
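On the `device_print` question: this is an assumption rather than a confirmed diagnosis, but Triton kernel launches are asynchronous from the host's point of view, and device-side printf output is generally only flushed once the host synchronizes with the device. A minimal sketch of a launch helper that forces that flush (`launch_with_flush` is a hypothetical name, and it assumes the `kernel`, `index`, and `value` objects from the reproduction above):

```python
def launch_with_flush(kernel, index, value, block_size=32):
    """Hypothetical helper: launch the Triton kernel, then block until it
    finishes so that any device_print output is flushed to the host."""
    import torch  # deferred import so the sketch parses without torch installed

    kernel[(1,)](index, value, block_size, num_warps=1)  # asynchronous launch
    torch.cuda.synchronize()  # wait for completion; flushes the device printf buffer
```

If the output really does only appear after several launches, calling `torch.cuda.synchronize()` before the process exits is worth trying before concluding that `device_print` itself is dropping output.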