triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Error with int16 type after using associative_scan #3063

Closed. kuviki closed this issue 8 months ago.

kuviki commented 8 months ago

Issue Description

Adding a new axis ([:, None]) to the result of associative_scan raises an error when the index tensor has type int16. If the commented-out line, which adds the axis before the scan, is used instead of the problematic line, the code runs correctly. Changing the dtype of index to torch.int32 also allows both lines to run correctly.
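For reference, the two variants in the reproduction code are expected to compute the same values: scanning the 1-D block and then adding a trailing axis yields the same (BLOCK_SIZE, 1) result as adding the axis first and scanning along axis 0. The pure-Python sketch below models this on the host, using itertools.accumulate as a stand-in for tl.associative_scan (it is not Triton's implementation). end_combine mirrors the kernel's _end_combine, and the helper names are hypothetical.

```python
from itertools import accumulate


def end_combine(a, b):
    # Mirrors the kernel's _end_combine: once -1 appears, it propagates.
    return -1 if a == -1 else b


def scan_then_expand(xs):
    # Problematic variant: 1-D inclusive scan, then add a trailing axis.
    # Shape (N,) -> (N, 1), represented as a list of 1-element rows.
    return [[v] for v in accumulate(xs, end_combine)]


def expand_then_scan(xs):
    # Workaround variant: add the trailing axis first, then scan along
    # axis 0, applying the combine function elementwise to each column.
    rows = [[v] for v in xs]
    return list(accumulate(
        rows,
        lambda acc, row: [end_combine(a, b) for a, b in zip(acc, row)],
    ))


xs = [5, 3, -1, 7, 2]
print(scan_then_expand(xs))  # [[5], [3], [-1], [-1], [-1]]
print(expand_then_scan(xs))  # [[5], [3], [-1], [-1], [-1]]
```

Since both orderings agree on the values, the failure with the first ordering points at the compiler rather than at the kernel's logic.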

Additionally, when running locally, device_print shows no output at first; the kernel has to be called several times before any output appears. Is an operation such as flushing the device_print buffer missing?

Triton version is 2.2.0.

Reproduction Code

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _end_combine(a, b):
    if a == -1:
        return -1
    return b

@triton.jit
def kernel(index_ptr, value_ptr, BLOCK_SIZE: tl.constexpr):
    ids = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    index = tl.load(index_ptr + ids).to(tl.int32)

    index = tl.associative_scan(index, axis=0, combine_fn=_end_combine)[:, None]  # Problematic line
    # index = tl.associative_scan(index[:, None], axis=0, combine_fn=_end_combine)

    value = tl.load(value_ptr + index, mask=index != -1, other=0.)
    tl.device_print('value: ', value)

if __name__ == '__main__':
    torch.manual_seed(0)
    index = torch.randint(128, (128,), dtype=torch.int16, device='cuda')
    index[8] = -1
    value = torch.rand(128, device='cuda')

    kernel[(1,)](index, value, 32, num_warps=1)
```

Expected Output

```
pid (0, 0, 0) idx ( 0, 0) value: 0.535563
pid (0, 0, 0) idx ( 1, 0) value: 0.577269
pid (0, 0, 0) idx ( 2, 0) value: 0.825175
pid (0, 0, 0) idx ( 3, 0) value: 0.628812
pid (0, 0, 0) idx ( 4, 0) value: 0.359353
pid (0, 0, 0) idx ( 5, 0) value: 0.609255
pid (0, 0, 0) idx ( 6, 0) value: 0.151971
pid (0, 0, 0) idx ( 7, 0) value: 0.932502
pid (0, 0, 0) idx ( 8, 0) value: 0.000000
pid (0, 0, 0) idx ( 9, 0) value: 0.000000
pid (0, 0, 0) idx (10, 0) value: 0.000000
pid (0, 0, 0) idx (11, 0) value: 0.000000
pid (0, 0, 0) idx (12, 0) value: 0.000000
pid (0, 0, 0) idx (13, 0) value: 0.000000
pid (0, 0, 0) idx (14, 0) value: 0.000000
pid (0, 0, 0) idx (15, 0) value: 0.000000
pid (0, 0, 0) idx (16, 0) value: 0.000000
pid (0, 0, 0) idx (17, 0) value: 0.000000
pid (0, 0, 0) idx (18, 0) value: 0.000000
pid (0, 0, 0) idx (19, 0) value: 0.000000
pid (0, 0, 0) idx (20, 0) value: 0.000000
pid (0, 0, 0) idx (21, 0) value: 0.000000
pid (0, 0, 0) idx (22, 0) value: 0.000000
pid (0, 0, 0) idx (23, 0) value: 0.000000
pid (0, 0, 0) idx (24, 0) value: 0.000000
pid (0, 0, 0) idx (25, 0) value: 0.000000
pid (0, 0, 0) idx (26, 0) value: 0.000000
pid (0, 0, 0) idx (27, 0) value: 0.000000
pid (0, 0, 0) idx (28, 0) value: 0.000000
pid (0, 0, 0) idx (29, 0) value: 0.000000
pid (0, 0, 0) idx (30, 0) value: 0.000000
pid (0, 0, 0) idx (31, 0) value: 0.000000
```
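The zeros from idx 8 onward follow from the combine function: _end_combine keeps returning -1 once it has seen -1, so after index[8] = -1 every later scanned index is -1, and the masked load (mask=index != -1, other=0.) yields 0.0 for those rows. A minimal host-side sketch of that logic in plain Python, with a hypothetical helper name and toy inputs rather than the issue's random tensors:

```python
from itertools import accumulate


def end_combine(a, b):
    # Mirrors the kernel's _end_combine: once -1 appears, it propagates.
    return -1 if a == -1 else b


def reference_values(index, value):
    # Inclusive scan of the index block, as tl.associative_scan(axis=0) would do.
    scanned = accumulate(index, end_combine)
    # Masked gather: read value[i] where i != -1, otherwise use 0.0
    # (mirrors mask=index != -1, other=0. in the kernel).
    return [value[i] if i != -1 else 0.0 for i in scanned]


index = [2, 0, 1, -1, 3, 1]
value = [0.5, 0.25, 0.75, 1.0]
print(reference_values(index, value))  # [0.75, 0.5, 0.25, 0.0, 0.0, 0.0]
```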

Actual Output

```
python: /root/.triton/llvm/llvm-5e5a22ca-centos-x64/include/llvm/Support/Casting.h:566: decltype(auto) llvm::cast(const From&) [with To = mlir::triton::gpu::BlockedEncodingAttr; From = mlir::Attribute]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.
/tmp/tmp8c8vfbcd: line 3: 16639 Aborted                 python /mnt/c/Users/kuviki/PycharmProjects/thinking-ml/testbed.py
ERROR conda.cli.main_run:execute(49): `conda run python /mnt/c/Users/kuviki/PycharmProjects/thinking-ml/testbed.py` failed. (See above for error)
```

jlebar commented 8 months ago

Thanks, this is a real bug. cc @ThomasRaoux

ThomasRaoux commented 8 months ago

Thanks for pointing that out. Yes, tl.associative_scan currently only works on the blocked layout, so we need to prevent other layouts from reaching it. I'll send a patch.

ThomasRaoux commented 8 months ago

https://github.com/openai/triton/pull/3065