microsoft / triton-shared

Shared Middle-Layer for Triton Compilation

[Bug]: Failed to lower tt.bitcast of !tt.ptr<i1> #192

Open Nullkooland opened 5 days ago

Nullkooland commented 5 days ago

Triton python code

# (imports added for context; `triton_heuristics`, `DeviceProperties`, and
# `AttrsDescriptor` come from the TorchInductor-generated preamble, elided here)
import triton
import triton.language as tl

@triton_heuristics.pointwise(
    size_hints=[16384], 
    filename=__file__,
    triton_meta={'signature': {0: '*i16', 1: '*i16', 2: '*i1', 3: 'i32'}, 'device': DeviceProperties(type='cpu', index=None, cc='', major=None, regs_per_multiprocessor=None, max_threads_per_multi_processor=None, multi_processor_count=None), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(2, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_isin_0', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 4, 'num_reduction': 0, ...}
)
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel: int, XBLOCK: tl.constexpr):
    xnumel = 16384
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[16384], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]))
    tmp1 = tl.load(in_ptr1 + (0))
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK])
    tmp4 = tl.load(in_ptr1 + (1))
    tmp5 = tl.broadcast_to(tmp4, [XBLOCK])
    tmp8 = tl.load(in_ptr1 + (2))
    tmp9 = tl.broadcast_to(tmp8, [XBLOCK])
    tmp3 = tmp0 == tmp2
    tmp6 = tmp0 == tmp5
    tmp7 = tmp3 | tmp6
    tmp10 = tmp0 == tmp9
    tmp11 = tmp7 | tmp10
    tl.store(tl.make_block_ptr(out_ptr0, shape=[16384], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.broadcast_to(tmp11, [XBLOCK]).to(tl.int8))
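
The inductor scaffolding is not essential to the failure. As a minimal sketch (hypothetical kernel, not from the report; assuming a Triton build wired to the triton-shared CPU backend), any kernel that stores a tl.int1 result through a torch.bool output should hit the same pattern, because the frontend bitcasts the !tt.ptr<i1> argument to !tt.ptr<i8> before the store:

import torch
import triton
import triton.language as tl

@triton.jit
def bool_store_kernel(in_ptr, out_ptr, XBLOCK: tl.constexpr):
    xindex = tl.program_id(0) * XBLOCK + tl.arange(0, XBLOCK)
    x = tl.load(in_ptr + xindex)
    flag = x == 0  # tensor of tl.int1
    # storing through a *i1 pointer emits tt.bitcast !tt.ptr<i1> -> !tt.ptr<i8>
    tl.store(out_ptr + xindex, flag.to(tl.int8))

x = torch.randint(-2, 2, (256,), dtype=torch.int16)
out = torch.empty(256, dtype=torch.bool)  # passed as *i1
bool_store_kernel[(1,)](x, out, XBLOCK=256)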

Triton IR

module {
  tt.func public @triton_(%arg0: !tt.ptr<i16>, %arg1: !tt.ptr<i16>, %arg2: !tt.ptr<i1> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %c2_i32 = arith.constant 2 : i32
    %c1_i32 = arith.constant 1 : i32
    %c0_i32 = arith.constant 0 : i32
    %c1_i64 = arith.constant 1 : i64
    %c16384_i64 = arith.constant 16384 : i64
    %c256_i32 = arith.constant 256 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c256_i32 : i32
    %2 = tt.make_tensor_ptr %arg0, [%c16384_i64], [%c1_i64], [%1] {order = array<i32: 0>} : <tensor<256xi16>>
    %3 = tt.load %2 : !tt.ptr<tensor<256xi16>>
    %4 = tt.addptr %arg1, %c0_i32 : !tt.ptr<i16>, i32
    %5 = tt.load %4 : !tt.ptr<i16>
    %6 = tt.splat %5 : i16 -> tensor<256xi16>
    %7 = tt.addptr %arg1, %c1_i32 : !tt.ptr<i16>, i32
    %8 = tt.load %7 : !tt.ptr<i16>
    %9 = tt.splat %8 : i16 -> tensor<256xi16>
    %10 = tt.addptr %arg1, %c2_i32 : !tt.ptr<i16>, i32
    %11 = tt.load %10 : !tt.ptr<i16>
    %12 = tt.splat %11 : i16 -> tensor<256xi16>
    %13 = arith.cmpi eq, %3, %6 : tensor<256xi16>
    %14 = arith.cmpi eq, %3, %9 : tensor<256xi16>
    %15 = arith.ori %13, %14 : tensor<256xi1>
    %16 = arith.cmpi eq, %3, %12 : tensor<256xi16>
    %17 = arith.ori %15, %16 : tensor<256xi1>
    %18 = tt.bitcast %arg2 : !tt.ptr<i1> -> !tt.ptr<i8>
    %19 = tt.make_tensor_ptr %18, [%c16384_i64], [%c1_i64], [%1] {order = array<i32: 0>} : <tensor<256xi8>>
    %20 = arith.extui %17 : tensor<256xi1> to tensor<256xi8>
    tt.store %19, %20 : !tt.ptr<tensor<256xi8>>
    tt.return
  }
}
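
Note the op that later fails to lower: %18 = tt.bitcast %arg2 : !tt.ptr<i1> -> !tt.ptr<i8>. The frontend inserts this pointer bitcast for the boolean store at the end of the kernel, presumably because i1 tensors are stored one byte per element.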

Crash log

❯ triton-shared-opt --triton-to-linalg-experimental /.../triton_.ttir
gen_triton_kernel.py:25:19: remark: PtrAnalysis: scalar loadOp will not be rewritten
    tmp1 = tl.load(in_ptr1 + (0))
                  ^
gen_triton_kernel.py:25:19: note: see current operation: %11 = tt.load %10 : !tt.ptr<i16>
gen_triton_kernel.py:25:19: remark: PtrAnalysis: Failed to rewrite LoadOp
    tmp1 = tl.load(in_ptr1 + (0))
                  ^
gen_triton_kernel.py:25:19: note: see current operation: %11 = tt.load %10 : !tt.ptr<i16>
gen_triton_kernel.py:27:19: remark: PtrAnalysis: scalar loadOp will not be rewritten
    tmp4 = tl.load(in_ptr1 + (1))
                  ^
gen_triton_kernel.py:27:19: note: see current operation: %15 = tt.load %14 : !tt.ptr<i16>
gen_triton_kernel.py:27:19: remark: PtrAnalysis: Failed to rewrite LoadOp
    tmp4 = tl.load(in_ptr1 + (1))
                  ^
gen_triton_kernel.py:27:19: note: see current operation: %15 = tt.load %14 : !tt.ptr<i16>
gen_triton_kernel.py:29:19: remark: PtrAnalysis: scalar loadOp will not be rewritten
    tmp8 = tl.load(in_ptr1 + (2))
                  ^
gen_triton_kernel.py:29:19: note: see current operation: %19 = tt.load %18 : !tt.ptr<i16>
gen_triton_kernel.py:29:19: remark: PtrAnalysis: Failed to rewrite LoadOp
    tmp8 = tl.load(in_ptr1 + (2))
                  ^
gen_triton_kernel.py:29:19: note: see current operation: %19 = tt.load %18 : !tt.ptr<i16>
gen_triton_kernel.py:36:31: error: 'arith.bitcast' op operand #0 must be signless-integer-or-float-like or memref of signless-integer or float, but got '!tt.ptr<i1>'
    tl.store(tl.make_block_ptr(out_ptr0, shape=[16384], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.broadcast_to(tmp11, [XBLOCK]).to(tl.int8))
                              ^
gen_triton_kernel.py:36:31: note: see current operation: %30 = "arith.bitcast"(%arg2) : (!tt.ptr<i1>) -> !tt.ptr<i8>

Additional information

The triton kernel is the codegen result of TorchInductor, the src torch program is:

import torch

device = "cpu"  # the generated kernel above targets the CPU backend (see triton_meta)

a = torch.randint(low=-100, high=100, size=(128, 128), dtype=torch.int16, device=device)
b = torch.tensor(data=[-1, 0, 1], dtype=torch.int16, device=device)

@torch.compile(fullgraph=True)
def test_func(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    out = torch.isin(a, b)
    return out

out = test_func(a, b)
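
For a quick sanity check of the compiled result: with b = [-1, 0, 1] as above, torch.isin reduces to the three OR-ed equality tests that the generated kernel computes, so the output can be compared against eager mode directly:

expected = (a == -1) | (a == 0) | (a == 1)
assert torch.equal(out, expected)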

It looks like triton-shared's triton-to-linalg lowering pass cannot properly handle the i1 (boolean) type? As the crash log shows, the tt.bitcast on the !tt.ptr<i1> argument is rewritten to arith.bitcast, which rejects pointer operands.

parsifal-47 commented 5 days ago

I think this PR should address the issue: https://github.com/microsoft/triton-shared/pull/171

Nullkooland commented 2 days ago

> I think this PR should address the issue: #171

Thanks, when will this PR be upstreamed?

parsifal-47 commented 1 day ago

> I think this PR should address the issue: #171
>
> Thanks, when will this PR be upstreamed?

It needs a review. I pinged @nhat-nguyen, but he could be busy; if you know somebody else with write permissions, let me know.