tlc-pack / tvm-tensorir

Apache License 2.0

[BUG] 'xxx' is not bound to any variables #364

Open ZihengJiang opened 3 years ago

ZihengJiang commented 3 years ago

I found that lowering fails when the extent of some itervars is 1.

The IR is:

@tvm.script.tir
class Module:
    def main(a: ty.handle, b: ty.handle, c: ty.handle, M: ty.int32, N: ty.int32) -> None:
        C = tir.match_buffer(c, [M, N], elem_offset=0, align=128, offset_factor=1)
        A = tir.match_buffer(a, [M, 1024], elem_offset=0, align=128, offset_factor=1)
        B = tir.match_buffer(b, [1024, N], elem_offset=0, align=128, offset_factor=1)
        # body
        with tir.block([], "root") as []:
            tir.reads([])
            tir.writes([])
            for i0_outer_outer_outer in tir.serial(0, M, annotation = {"pragma_auto_unroll_max_step":64, "pragma_unroll_explicit":1}):
                for i1_outer_outer_outer, i0_outer_outer_inner, i1_outer_outer_inner in tir.grid(((((((N + 1) - 1) + 1) - 1) + 1) - 1), 1, 1):
                    with tir.block([M, N], "matmul_init") as [vi_init, vj_init]:
                        tir.where((((((((((i0_outer_outer_outer + i0_outer_outer_inner) + i0_outer_inner) + i0_inner) < M) and (((i0_outer_outer_outer + i0_outer_outer_inner) + i0_outer_inner) < M)) and ((i0_outer_outer_outer + i0_outer_outer_inner) < M)) and ((((i1_outer_outer_outer + i1_outer_outer_inner) + i1_outer_inner) + i1_inner) < N)) and (((i1_outer_outer_outer + i1_outer_outer_inner) + i1_outer_inner) < N)) and ((i1_outer_outer_outer + i1_outer_outer_inner) < N)))
                        tir.bind(vi_init, i0_outer_outer_outer)
                        tir.bind(vj_init, i1_outer_outer_outer)
                        tir.reads([])
                        tir.writes([C[vi_init:(vi_init + 1), vj_init:(vj_init + 1)]])
                        C[vi_init, vj_init] = tir.float32(0)
                    for i2_outer, i0_outer_inner, i1_outer_inner, i2_inner, i0_inner, i1_inner in tir.grid(1024, 1, 1, 1, 1, 1):
                        with tir.block([M, N, tir.reduce_axis(0, 1024)], "matmul_update") as [vi, vj, vk]:
                            tir.where((((((((((i0_outer_outer_outer + i0_outer_outer_inner) + i0_outer_inner) + i0_inner) < M) and (((i0_outer_outer_outer + i0_outer_outer_inner) + i0_outer_inner) < M)) and ((i0_outer_outer_outer + i0_outer_outer_inner) < M)) and ((((i1_outer_outer_outer + i1_outer_outer_inner) + i1_outer_inner) + i1_inner) < N)) and (((i1_outer_outer_outer + i1_outer_outer_inner) + i1_outer_inner) < N)) and ((i1_outer_outer_outer + i1_outer_outer_inner) < N)))
No valid schedule found
                            tir.bind(vi, i0_outer_outer_outer)
                            tir.bind(vj, i1_outer_outer_outer)
                            tir.bind(vk, i2_outer)
                            tir.reads([C[vi:(vi + 1), vj:(vj + 1)], A[vi:(vi + 1), vk:(vk + 1)], B[vk:(vk + 1), vj:(vj + 1)]])
                            tir.writes([C[vi:(vi + 1), vj:(vj + 1)]])
                            C[vi, vj] = (C[vi, vj] + (A[vi, vk]*B[vk, vj]))

Schedule is:

b1 = sch.get_block(name="matmul")
l2, l3, l4 = sch.get_axes(block=b1)
v5, v6, v7, v8 = sch.sample_perfect_tile(n=4, loop=l2, max_innermost_factor=16, decision=[-1, 1, 1, 1])
l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])
v13, v14, v15, v16 = sch.sample_perfect_tile(n=4, loop=l3, max_innermost_factor=16, decision=[-1, 1, 1, 1])
l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])
v21, v22 = sch.sample_perfect_tile(n=2, loop=l4, max_innermost_factor=16, decision=[1024, 1])
l23, l24 = sch.split(loop=l4, factors=[v21, v22])
sch.reorder(after_axes=[l9, l17, l10, l18, l23, l11, l19, l24, l12, l20])
b25 = sch.get_block(name="matmul")
sch.mark_block(block=b25, ann_key="auto_parallel_extent", ann_val=96)
sch.mark_block(block=b25, ann_key="auto_vectorize_extent", ann_val=32)
v26 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=2)
sch.mark_block(block=b25, ann_key="auto_unroll_explicit", ann_val=v26)
# Postprocessing
b27 = sch.get_block(name="matmul")
l28, l29, l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_axes(block=b27)
sch.mark_loop(loop=l28, ann_key="pragma_auto_unroll_max_step", ann_val=64)
sch.mark_loop(loop=l28, ann_key="pragma_unroll_explicit", ann_val=1)
b38 = sch.get_block(name="matmul")
l39, l40, l41, l42, l43, l44, l45, l46, l47, l48 = sch.get_axes(block=b38)
b49 = sch.decompose_reduction(block=b38, loop=l43)

Error message:

  [bt] (6) /home/ziheng/projects/auto-dyn/build/libtvm.so(TVMFuncCall+0x5b) [0x7f3507bbe8ab]
  [bt] (5) /home/ziheng/projects/auto-dyn/build/libtvm.so(+0x106b70f) [0x7f3506e2d70f]
  [bt] (4) /home/ziheng/projects/auto-dyn/build/libtvm.so(tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x313) [0x7f3506e2c473]
  [bt] (3) /home/ziheng/projects/auto-dyn/build/libtvm.so(tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const+0x1b7) [0x7f3506e2cea7]
  [bt] (2) /home/ziheng/projects/auto-dyn/build/libtvm.so(+0x15f792c) [0x7f35073b992c]
  [bt] (1) /home/ziheng/projects/auto-dyn/build/libtvm.so(tvm::tir::MakePackedAPI(tvm::tir::PrimFunc&&, int)+0x42db) [0x7f35073b6ecb]
  [bt] (0) /home/ziheng/projects/auto-dyn/build/libtvm.so(+0x15f0302) [0x7f35073b2302]
  File "/home/ziheng/projects/auto-dyn/src/tir/transforms/make_packed_api.cc", line 273
TVMError: Not all Vars are passed in api_args:  'i0_outer_inner'  'i0_inner'  'i1_outer_inner'  'i1_inner'  is not bound to any variables

junrushao commented 3 years ago

@Hzfengsy It is a bit weird that the error is thrown on the lowered TIR. Could you take a look at whether it is related to our buffer flatten work?

junrushao commented 3 years ago

@ZihengJiang would you like to provide the IR being scheduled? Thanks a lot!

ZihengJiang commented 3 years ago

Here it is:

@tvm.script.tir
def dyn_mm(a: ty.handle, b: ty.handle, c: ty.handle, M: ty.int32, N: ty.int32) -> None:
    A = tir.match_buffer(a, (M, 1024), "float32")
    B = tir.match_buffer(b, (1024, N), "float32")
    C = tir.match_buffer(c, (M, N), "float32")
    with tir.block([M, N, tir.reduce_axis(0, 1024)], "matmul") as [vi, vj, vk]:
        with tir.init():
            C[vi, vj] = 0.0
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

Hzfengsy commented 3 years ago

This is a bug in predicate handling. In the scheduled TIR, i0_outer_inner is used in the predicate of the first block_realize but is only defined by a loop that comes after it. I admit that we did not consider predicates carefully during scheduling. It would be great if you could find the exact primitive that causes the problem. Sorry for the trouble.
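
To make the scoping problem concrete, here is a minimal sketch in plain Python. It does not use the TVM API: the tuple-based loop nest and the helper `find_unbound_predicate_vars` are hypothetical, written only for illustration. The loop structure and predicate variables are copied from the scheduled IR above, and the check reports every predicate variable that no enclosing loop defines.

```python
# Hypothetical mini-IR (NOT the TVM API). A statement is either
#   ("for", [loop_vars], body)  or  ("block", name, predicate_vars).

def find_unbound_predicate_vars(stmts, bound=frozenset()):
    """Map each block name to the predicate vars that no enclosing loop binds."""
    problems = {}
    for stmt in stmts:
        if stmt[0] == "for":
            _, loop_vars, body = stmt
            # Loop vars are visible only inside the loop body.
            problems.update(
                find_unbound_predicate_vars(body, bound | set(loop_vars))
            )
        else:
            _, name, predicate_vars = stmt
            missing = sorted(set(predicate_vars) - bound)
            if missing:
                problems[name] = missing
    return problems


# Variables that appear in the tir.where(...) predicate of both blocks.
PRED_VARS = [
    "i0_outer_outer_outer", "i0_outer_outer_inner", "i0_outer_inner", "i0_inner",
    "i1_outer_outer_outer", "i1_outer_outer_inner", "i1_outer_inner", "i1_inner",
]

# Loop nest transcribed from the scheduled IR in this issue.
scheduled = [
    ("for", ["i0_outer_outer_outer"], [
        ("for", ["i1_outer_outer_outer", "i0_outer_outer_inner",
                 "i1_outer_outer_inner"], [
            ("block", "matmul_init", PRED_VARS),
            ("for", ["i2_outer", "i0_outer_inner", "i1_outer_inner",
                     "i2_inner", "i0_inner", "i1_inner"], [
                ("block", "matmul_update", PRED_VARS),
            ]),
        ]),
    ]),
]

unbound = find_unbound_predicate_vars(scheduled)
print(unbound)
```

In this sketch only matmul_init is reported, with the same four variables that the TVMError lists; matmul_update sits inside all ten loops, so its predicate is fully bound.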

tqchen commented 3 years ago

Looking again, I think this is due to the reduction split. We need to detect the predicates that relate to the loops of the init block and remove the predicates that touch the reduction var (which is not in the init). If a predicate touches both, we cannot do the reduction split.
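
The rule described above could be sketched roughly as follows. This is plain Python under stated assumptions, not the actual fix and not the TVM API: the function `predicates_for_init`, the representation of a predicate as a list of (text, variables) conjuncts, and the loop names are all hypothetical.

```python
def predicates_for_init(conjuncts, init_loops, other_loops):
    """Pick the predicate conjuncts that may stay on the init block.

    conjuncts:   list of (text, set_of_loop_vars) pairs (hypothetical format)
    init_loops:  loop vars that enclose the init block after the split
    other_loops: loop vars that do not (e.g. the reduction loops)
    """
    init_loops, other_loops = set(init_loops), set(other_loops)
    kept = []
    for text, used_vars in conjuncts:
        used_vars = set(used_vars)
        touches_init = bool(used_vars & init_loops)
        touches_other = bool(used_vars & other_loops)
        if touches_init and touches_other:
            # A conjunct mixing both kinds of loops cannot be separated.
            raise ValueError(
                f"cannot split reduction: predicate {text!r} touches both "
                "init loops and non-init loops"
            )
        if not touches_other:
            kept.append(text)  # safe to evaluate at the init block
        # Conjuncts that only touch non-init loops are dropped from the init.
    return kept


# Hypothetical loop names, not taken from the IR in this issue.
conjuncts = [
    ("i_outer * 4 + i_inner < M", {"i_outer", "i_inner"}),
    ("k_outer * 8 + k_inner < K", {"k_outer", "k_inner"}),
]
init_predicates = predicates_for_init(
    conjuncts,
    init_loops=["i_outer", "i_inner"],
    other_loops=["k_outer", "k_inner"],
)
print(init_predicates)
```

Here the spatial bound stays on the init block and the reduction bound is dropped. A conjunct such as `i_inner + k_inner < M` would raise instead, which corresponds to the case where the reduction split has to be rejected.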

junrushao commented 3 years ago

Is this bug fixed on mainline?