manbearian opened 9 months ago
Add a Python script:

```python
import triton
import triton.language as tl

@triton.jit
def triton_(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 2688
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex % 7
    x1 = (xindex // 7)
    tmp0 = 0.0
    tl.store(out_ptr0 + (x0 + (72*x1)), tmp0, xmask)

ret = triton.compile(triton_, signature="*fp32,i32", constants={"XBLOCK": 32}, device_type="cpu")
print(ret.asm["ttir"])
print(ret.asm["ttsharedir"])
print(ret.asm["llir"])
print(ret.asm["cpuasm"])
```
```mlir
module {
  tt.func public @triton__01(%arg0: !tt.ptr<f32, 1>, %arg1: i32) attributes {noinline = false} {
    %cst = arith.constant dense<0.000000e+00> : tensor<32xf32>
    %cst_0 = arith.constant dense<72> : tensor<32xi32>
    %cst_1 = arith.constant dense<7> : tensor<32xi32>
    %cst_2 = arith.constant dense<2688> : tensor<32xi32>
    %c32_i32 = arith.constant 32 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c32_i32 : i32
    %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32>
    %3 = tt.splat %1 : (i32) -> tensor<32xi32>
    %4 = arith.addi %3, %2 : tensor<32xi32>
    %5 = arith.cmpi slt, %4, %cst_2 : tensor<32xi32>
    %6 = arith.remsi %4, %cst_1 : tensor<32xi32>
    %7 = arith.divsi %4, %cst_1 : tensor<32xi32>
    %8 = arith.muli %7, %cst_0 : tensor<32xi32>
    %9 = arith.addi %6, %8 : tensor<32xi32>
    %10 = tt.splat %arg0 : (!tt.ptr<f32, 1>) -> tensor<32x!tt.ptr<f32, 1>>
    %11 = tt.addptr %10, %9 : tensor<32x!tt.ptr<f32, 1>>, tensor<32xi32>
    tt.store %11, %cst, %5 {cache = 1 : i32, evict = 1 : i32} : tensor<32xf32>
    tt.return
  }
}
```
```
%20 = "arith.constant"() <{value = 7 : index}> : () -> index
%15 = "arith.divsi"(%12, %5) {MetaUse} : (tensor<32xi32>, tensor<32xi32>) -> tensor<32xi32>
encountered addptr operand produced by an unsupported operation
UNREACHABLE executed at /workspace/hongjing/triton2/third_party/triton_shared/lib/Analysis/PtrAnalysis.cpp:640!
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.  Program arguments: /workspace/hongjing/triton2/python/build/cmake.linux-x86_64-cpython-3.10/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt /tmp/tmpoizjes5t/tt.mlir --triton-to-linalg -o /tmp/tmpoizjes5t/ttshared.mlir
Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
 0  triton-shared-opt 0x000055e5bb21f2e7
 1  triton-shared-opt 0x000055e5bb21ce0e
 2  triton-shared-opt 0x000055e5bb21f99f
 3  libc.so.6         0x00007ffad8065520
 4  libc.so.6         0x00007ffad80b99fc pthread_kill + 300
 5  libc.so.6         0x00007ffad8065476 raise + 22
 6  libc.so.6         0x00007ffad804b7f3 abort + 211
 7  triton-shared-opt 0x000055e5bb1d84a1
 8  triton-shared-opt 0x000055e5ba2038bb
 9  triton-shared-opt 0x000055e5ba2017ff
10  triton-shared-opt 0x000055e5ba203608
11  triton-shared-opt 0x000055e5ba20168b
12  triton-shared-opt 0x000055e5ba2035ac
13  triton-shared-opt 0x000055e5ba202c3b
14  triton-shared-opt 0x000055e5ba203da8
15  triton-shared-opt 0x000055e5ba12f280
16  triton-shared-opt 0x000055e5b9f6775d
17  triton-shared-opt 0x000055e5bad5af10
18  triton-shared-opt 0x000055e5bad8c9a4
19  triton-shared-opt 0x000055e5bad894bf
20  triton-shared-opt 0x000055e5bad684d5
21  triton-shared-opt 0x000055e5bad5e104
22  triton-shared-opt 0x000055e5bad61403
23  triton-shared-opt 0x000055e5ba124689
24  triton-shared-opt 0x000055e5ba172c26
25  triton-shared-opt 0x000055e5ba1733c1
26  triton-shared-opt 0x000055e5ba1757fb
27  triton-shared-opt 0x000055e5ba16fb69
28  triton-shared-opt 0x000055e5ba16ed8d
29  triton-shared-opt 0x000055e5bb1a40d9
30  triton-shared-opt 0x000055e5ba16a27a
31  triton-shared-opt 0x000055e5ba16a781
32  triton-shared-opt 0x000055e5b8e5493b
33  libc.so.6         0x00007ffad804cd90
34  libc.so.6         0x00007ffad804ce40 __libc_start_main + 128
35  triton-shared-opt 0x000055e5b8e54805
```
```
Traceback (most recent call last):
  File "/workspace/hongjing/triton-shared2/python/examples/divsi.py", line 18, in <module>
    ret = triton.compile(triton_, signature="*fp32,i32", constants={"XBLOCK": 32}, device_type="cpu")
  File "/workspace/hongjing/triton2/python/triton/compiler/compiler.py", line 527, in compile
    next_module = compile_kernel(module)
  File "/workspace/hongjing/triton2/python/triton/third_party/cpu/__init__.py", line 287, in <lambda>
    lambda src: _optimize_ttsharedir(_ttir_to_ttsharedir(src)))
  File "/workspace/hongjing/triton2/python/triton/third_party/cpu/__init__.py", line 34, in _ttir_to_ttsharedir
    subprocess.check_call([triton_shared_opt_path, src_path, "--triton-to-linalg", "-o", dst_path])
  File "/usr/lib/python3.10/subprocess.py", line 369, in check_call
    raise CalledProcessError(retcode, cmd)
subprocess.CalledProcessError: Command '['/workspace/hongjing/triton2/python/build/cmake.linux-x86_64-cpython-3.10/third_party/triton_shared/tools/triton-shared-opt/triton-shared-opt', '/tmp/tmpoizjes5t/tt.mlir', '--triton-to-linalg', '-o', '/tmp/tmpoizjes5t/ttshared.mlir']' died with <Signals.SIGABRT: 6>.
```
I am currently looking at this issue; please assign it to me.
@yuanfz98 @nhat-nguyen I wanted to check on this. There's some discussion of this pattern being supported in #62. Does it work in the latest version?
@yuanfz98 @nhat-nguyen, I want to follow up on this as well. @blaine-rister, the issue still exists in the latest version.
@fhossein-quic Sorry for the delayed response. We can't statically determine the shape of the memory loads when there are div
ops in the pointer arithmetic sequence. @haishanzzzz and his team at Meta have begun to think about a potential fallback mode for these cases.
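To make the difficulty concrete, here is a pure-Python sketch (illustrative, not from the issue) of the address pattern the repro kernel produces. The linear index is split as `x0 = i % 7`, `x1 = i // 7`, and the store offset `x0 + 72*x1` is exactly a dense 2D block of shape `(384, 7)` with strides `(72, 1)`; the analysis would need to recover that structure from the `remsi`/`divsi` ops.

```python
# Sketch of the offsets produced by the repro kernel's mod/div indexing.
def store_offsets(xnumel=2688, width=7, stride=72):
    # x0 = i % width, x1 = i // width, offset = x0 + stride * x1
    return [(i % width) + stride * (i // width) for i in range(xnumel)]

# The same offsets written as a strided 2D block of shape (384, 7) with
# strides (72, 1) -- the structure a pointer analysis would need to recover.
block = [r * 72 + c for r in range(2688 // 7) for c in range(7)]
assert store_offsets() == block
```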
Was this triton kernel generated by Torch Inductor?
@fhossein-quic @nhat-nguyen I thought you might be interested in https://github.com/pytorch/pytorch/issues/125077. We are beefing up Inductor's codegen so it won't use mod/div to compute indices nearly as often.
@blaine-rister This code is indeed generated from torch-inductor! Thank you for the link, and I appreciate the improvements in the codegen! Do you know if the improvements will help in this case https://github.com/microsoft/triton-shared/discussions/138? We have rather complex codegen from torch-inductor throughout most of the basic operations, such as singleton broadcasting, reshape, etc. I will definitely go through your chain of PRs this week to understand more about the improvements. Thanks again!
@nhat-nguyen that PR should work for broadcasts. I'm not sure about reshape; I think it mostly depends on what you do with the result of the reshape. AFAIK reshape by itself doesn't change the underlying data, but it affects the semantics of subsequent operations on that data.
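As a small illustration of that point (using NumPy rather than Torch): a reshape is just a new view over the same buffer, so the data is untouched, but later indexing is interpreted through the new shape.

```python
import numpy as np

a = np.arange(6)
b = a.reshape(2, 3)      # no data movement: b is a view over a's buffer
assert b.base is a       # same underlying storage
b[1, 0] = 99             # writing through the view...
assert a[3] == 99        # ...mutates the original buffer
```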
The basic pattern seems similar to what's described in #138. Basically, we pattern match on mod/div indexing expressions, trying to determine that this is the same iteration order as some ND block. Then we solve for the shape of that block.
There are some complicated scenarios where this analysis fails, e.g. `torch.tile`.
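A toy version of that shape-solving step might look like the following (hypothetical; not the actual Inductor or triton-shared code): an index expression of the form `(i // d) * s + (i % d)` iterates a 2D block of shape `(numel // d, d)` with strides `(s, 1)`, provided the rows tile the range cleanly.

```python
# Hypothetical sketch of the mod/div pattern matching described above.
def solve_block_shape(numel, d, s):
    # valid only when numel is a multiple of d and the row stride is at
    # least the row width, so rows neither overlap nor leave a ragged tail
    if numel % d != 0 or s < d:
        return None
    return (numel // d, d), (s, 1)

# The repro kernel's x0 + 72*x1 pattern over 2688 elements:
assert solve_block_shape(2688, 7, 72) == ((384, 7), (72, 1))
```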
Right now, to take advantage of that PR you have to follow several restrictions:
1. Set `config.triton.use_block_ptr=True`
2. Stick to certain shapes (powers of 2, multiples of the maximum block size, a few other cases)
We have some ideas on how to expand it beyond those so you'll see the benefits on all shapes, with or without block pointers.
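For reference, a hypothetical usage sketch of the first restriction; the exact config path (`torch._inductor.config.triton.use_block_ptr`) is assumed from the flag name quoted above, not confirmed in this thread:

```python
# Hypothetical: enable block-pointer codegen in Inductor before compiling.
import torch._inductor.config as inductor_config

inductor_config.triton.use_block_ptr = True
```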
Created from #7.
I don't know what the original Triton code that created this looked like, but there is a division in the address expression.
repros.zip
```shell
triton-shared-opt -triton-to-linalg 5.mlir
triton-shared-opt -triton-to-linalg 32.mlir
triton-shared-opt -triton-to-linalg 35.mlir
triton-shared-opt -triton-to-linalg 41.mlir
triton-shared-opt -triton-to-linalg 72.mlir
triton-shared-opt -triton-to-linalg 88.mlir
```
Error output: