tlc-pack / tvm-tensorir

[BUG] variable 'xxx' has been used before definition! #271

Open. junrushao opened this issue 3 years ago

junrushao commented 3 years ago

This happens when a block is computed at a location nested too deep inside a tile, and we probably need more careful error messages. Per discussion with @spectrometerHBH, we think it is not an urgent bug, but the error message needs to be fixed in the future.
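The core of the failing pattern, distilled as a minimal sketch. This is hypothetical, not a verified standalone reproducer: it assumes the same schedule API and block names as the full trace below.

b_pad = sch.get_block(name="PadInput")
b_conv, = sch.get_consumers(block=b_pad)
loops = sch.get_axes(block=b_conv)
# Picking a compute location inside the consumer's reduction tiles makes the
# generated PadInput loop extents depend on outer loop variables (e.g.
# i4_outer in the TIR below), which SplitHostDevice later rejects.
sch.compute_at(block=b_pad, loop=loops[10])

The full schedule that actually triggered the failure is below.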

Example. On workload T2D.

Schedule.

b1 = sch.get_block(name="conv2d_transpose_nhwc")
l2, l3, l4, l5, l6, l7, l8 = sch.get_axes(block=b1)
v9, v10, v11, v12 = sch.sample_perfect_tile(n_splits=4, loop=l2, max_innermost_factor=16, decision=[1, 1, 1, 1])
l13, l14, l15, l16 = sch.split(loop=l2, factors=[v9, v10, v11, v12])
v17, v18, v19, v20 = sch.sample_perfect_tile(n_splits=4, loop=l3, max_innermost_factor=16, decision=[1, 4, 1, 2])
l21, l22, l23, l24 = sch.split(loop=l3, factors=[v17, v18, v19, v20])
v25, v26, v27, v28 = sch.sample_perfect_tile(n_splits=4, loop=l4, max_innermost_factor=16, decision=[1, 4, 1, 2])
l29, l30, l31, l32 = sch.split(loop=l4, factors=[v25, v26, v27, v28])
v33, v34, v35, v36 = sch.sample_perfect_tile(n_splits=4, loop=l5, max_innermost_factor=16, decision=[32, 4, 1, 2])
l37, l38, l39, l40 = sch.split(loop=l5, factors=[v33, v34, v35, v36])
v41, v42 = sch.sample_perfect_tile(n_splits=2, loop=l6, max_innermost_factor=16, decision=[4, 1])
l43, l44 = sch.split(loop=l6, factors=[v41, v42])
v45, v46 = sch.sample_perfect_tile(n_splits=2, loop=l7, max_innermost_factor=16, decision=[2, 2])
l47, l48 = sch.split(loop=l7, factors=[v45, v46])
v49, v50 = sch.sample_perfect_tile(n_splits=2, loop=l8, max_innermost_factor=16, decision=[32, 16])
l51, l52 = sch.split(loop=l8, factors=[v49, v50])
sch.reorder(after_axes=[l13, l21, l29, l37, l14, l22, l30, l38, l43, l47, l51, l15, l23, l31, l39, l44, l48, l52, l16, l24, l32, l40])
b53 = sch.get_block(name="PadInput")
sch.mark_block(block=b53, ann_key="auto_parallel_extent", ann_val=64)
sch.mark_block(block=b53, ann_key="auto_vectorize_extent", ann_val=32)
v54 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=3)
sch.mark_block(block=b53, ann_key="auto_unroll_explicit", ann_val=v54)
b55 = sch.get_block(name="conv2d_transpose_nhwc")
sch.mark_block(block=b55, ann_key="auto_parallel_extent", ann_val=64)
sch.mark_block(block=b55, ann_key="auto_vectorize_extent", ann_val=32)
v56 = sch.sample_categorical(candidates=[0, 16, 64, 512], probs=[0.25, 0.25, 0.25, 0.25], decision=2)
sch.mark_block(block=b55, ann_key="auto_unroll_explicit", ann_val=v56)
b57 = sch.get_block(name="PadInput")
b58, = sch.get_consumers(block=b57)
l59 = sch.sample_compute_location(block=b58, decision=10)
sch.compute_at(block=b57, loop=l59)
# Postprocessing
b60 = sch.get_block(name="PadInput")
l61, l62, l63, l64, l65, l66, l67, l68, l69, l70, l71, l72, l73, l74, l75 = sch.get_axes(block=b60)
l76 = sch.fuse(loops=[l61, l62, l63, l64, l65, l66])
sch.parallel(loop=l76)
l77 = sch.fuse(loops=[l74, l75])
sch.vectorize(loop=l77)
sch.mark_loop(loop=l76, ann_key="pragma_auto_unroll_max_step", ann_val=512)
sch.mark_loop(loop=l76, ann_key="pragma_unroll_explicit", ann_val=1)
b78 = sch.get_block(name="conv2d_transpose_nhwc")
l79, l80, l81, l82, l83, l84, l85, l86, l87, l88, l89, l90, l91, l92, l93, l94, l95 = sch.get_axes(block=b78)
l96 = sch.fuse(loops=[l92, l93, l94, l95])
sch.vectorize(loop=l96)
sch.mark_loop(loop=l79, ann_key="pragma_auto_unroll_max_step", ann_val=64)
sch.mark_loop(loop=l79, ann_key="pragma_unroll_explicit", ann_val=1)
b97 = sch.get_block(name="conv2d_transpose_nhwc")
l98, l99, l100, l101, l102, l103, l104, l105, l106, l107, l108, l109, l110, l111 = sch.get_axes(block=b97)
b112 = sch.decompose_reduction(block=b97, loop=l101)
l113, l114, l115, l116, l117, l118 = sch.get_axes(block=b112)
sch.vectorize(loop=l118)

TIR.

@tvm.script.tir
def func(var_inputs: ty.handle, var_weight: ty.handle, var_conv2d_transpose_nhwc: ty.handle) -> None:
    inputs = tir.match_buffer(var_inputs, [1, 4, 4, 512], elem_offset=0, align=128, offset_factor=1)
    weight = tir.match_buffer(var_weight, [4, 4, 512, 256], elem_offset=0, align=128, offset_factor=1)
    conv2d_transpose_nhwc = tir.match_buffer(var_conv2d_transpose_nhwc, [1, 8, 8, 256], elem_offset=0, align=128, offset_factor=1)
    # body
    with tir.block([], "root") as []:
        tir.reads([])
        tir.writes([])
        PadInput = tir.buffer_allocate([1, 6, 6, 512], elem_offset=0, align=128, offset_factor=1)
        for i0_outer_outer_outer_i1_outer_outer_outer_fused_i2_outer_outer_outer_fused_i3_outer_outer_outer_fused_i0_outer_outer_inner_fused_i1_outer_outer_inner_fused in range(0, 128, annotation = {"loop_type":"parallel", "pragma_auto_unroll_max_step":512, "pragma_unroll_explicit":1}):
            for i2_outer_outer_inner, i3_outer_outer_inner in tir.grid(4, 4):
                for i0_inner_i1_inner_fused_i2_inner_fused_i3_inner_fused_init, i0_inner_i1_inner_fused_i2_inner_fused_i3_inner_fused_init_1 in tir.grid(8, 8):
                    for i0_inner_i1_inner_fused_i2_inner_fused_i3_inner_fused_init_2 in range(0, 8, annotation = {"loop_type":"vectorize"}):
                        with tir.block([1, 8, 8, 256], "conv2d_transpose_nhwc_init") as [n_init, h_init, w_init, co_init]:
                            tir.bind(n_init, 0)
                            tir.bind(h_init, ((tir.floormod(i0_outer_outer_outer_i1_outer_outer_outer_fused_i2_outer_outer_outer_fused_i3_outer_outer_outer_fused_i0_outer_outer_inner_fused_i1_outer_outer_inner_fused, 4)*2) + tir.floordiv(i0_inner_i1_inner_fused_i2_inner_fused_i3_inner_fused_init_2, 4)))
                            tir.bind(w_init, ((i2_outer_outer_inner*2) + tir.floormod(tir.floordiv(i0_inner_i1_inner_fused_i2_inner_fused_i3_inner_fused_init_2, 2), 2)))
                            tir.bind(co_init, (((tir.floordiv(i0_outer_outer_outer_i1_outer_outer_outer_fused_i2_outer_outer_outer_fused_i3_outer_outer_outer_fused_i0_outer_outer_inner_fused_i1_outer_outer_inner_fused, 4)*8) + (i3_outer_outer_inner*2)) + tir.floormod(i0_inner_i1_inner_fused_i2_inner_fused_i3_inner_fused_init_2, 2)))
                            tir.reads([])
                            tir.writes([conv2d_transpose_nhwc[n_init:(n_init + 1), h_init:(h_init + 1), w_init:(w_init + 1), co_init:(co_init + 1)]])
                            conv2d_transpose_nhwc[n_init, h_init, w_init, co_init] = tir.float32(0)
                for i4_outer, i5_outer, i6_outer in tir.grid(4, 2, 32):
                    for ax0, ax1 in tir.grid(1, (tir.floordiv((tir.floormod(i4_outer, 2) + 1), 2) + 1)):
                        for ax2_ax3_fused in range(0, 32, annotation = {"loop_type":"vectorize"}):
                            with tir.block([1, 6, 6, 512], "PadInput") as [i0, i1, i2, i3]:
                                tir.bind(i0, 0)
                                tir.bind(i1, ((tir.floordiv(i4_outer, 2) + tir.floormod(i0_outer_outer_outer_i1_outer_outer_outer_fused_i2_outer_outer_outer_fused_i3_outer_outer_outer_fused_i0_outer_outer_inner_fused_i1_outer_outer_inner_fused, 4)) + ax1))
                                tir.bind(i2, ((i2_outer_outer_inner + i5_outer) + tir.floordiv(ax2_ax3_fused, 16)))
                                tir.bind(i3, ((i6_outer*16) + tir.floormod(ax2_ax3_fused, 16)))
                                tir.reads([inputs[i0:(i0 + 1), (i1 - 1):((i1 - 1) + 1), (i2 - 1):((i2 - 1) + 1), i3:(i3 + 1)]])
                                tir.writes([PadInput[i0:(i0 + 1), i1:(i1 + 1), i2:(i2 + 1), i3:(i3 + 1)]])
                                PadInput[i0, i1, i2, i3] = tir.if_then_else(((((1 <= i1) and (i1 < 5)) and (1 <= i2)) and (i2 < 5)), inputs[i0, (i1 - 1), (i2 - 1), i3], tir.float32(0), dtype="float32")
                    for i0_outer_inner, i1_outer_inner, i2_outer_inner, i3_outer_inner, i4_inner, i5_inner, i6_inner in tir.grid(1, 1, 1, 1, 1, 2, 16):
                        for i0_inner_i1_inner_fused_i2_inner_fused_i3_inner_fused in range(0, 8, annotation = {"loop_type":"vectorize"}):
                            with tir.block([1, 8, 8, 256, tir.reduce_axis(0, 4), tir.reduce_axis(0, 4), tir.reduce_axis(0, 512)], "conv2d_transpose_nhwc_update") as [n, h, w, co, rh, rw, rc]:
                                tir.bind(n, 0)
                                tir.bind(h, ((tir.floormod(i0_outer_outer_outer_i1_outer_outer_outer_fused_i2_outer_outer_outer_fused_i3_outer_outer_outer_fused_i0_outer_outer_inner_fused_i1_outer_outer_inner_fused, 4)*2) + tir.floordiv(i0_inner_i1_inner_fused_i2_inner_fused_i3_inner_fused, 4)))
                                tir.bind(w, ((i2_outer_outer_inner*2) + tir.floormod(tir.floordiv(i0_inner_i1_inner_fused_i2_inner_fused_i3_inner_fused, 2), 2)))
                                tir.bind(co, (((tir.floordiv(i0_outer_outer_outer_i1_outer_outer_outer_fused_i2_outer_outer_outer_fused_i3_outer_outer_outer_fused_i0_outer_outer_inner_fused_i1_outer_outer_inner_fused, 4)*8) + (i3_outer_outer_inner*2)) + tir.floormod(i0_inner_i1_inner_fused_i2_inner_fused_i3_inner_fused, 2)))
                                tir.bind(rh, i4_outer)
                                tir.bind(rw, ((i5_outer*2) + i5_inner))
                                tir.bind(rc, ((i6_outer*16) + i6_inner))
                                tir.reads([conv2d_transpose_nhwc[n:(n + 1), h:(h + 1), w:(w + 1), co:(co + 1)], PadInput[n:(n + 1), tir.floordiv((h + rh), 2):(tir.floordiv((h + rh), 2) + 1), tir.floordiv((w + rw), 2):(tir.floordiv((w + rw), 2) + 1), rc:(rc + 1)], weight[(3 - rh):((3 - rh) + 1), (3 - rw):((3 - rw) + 1), rc:(rc + 1), co:(co + 1)]])
                                tir.writes([conv2d_transpose_nhwc[n:(n + 1), h:(h + 1), w:(w + 1), co:(co + 1)]])
                                conv2d_transpose_nhwc[n, h, w, co] = (conv2d_transpose_nhwc[n, h, w, co] + (tir.if_then_else(((tir.floormod((h + rh), 2) == 0) and (tir.floormod((w + rw), 2) == 0)), PadInput[n, tir.floordiv((h + rh), 2), tir.floordiv((w + rw), 2), rc], tir.float32(0), dtype="float32")*weight[(3 - rh), (3 - rw), rc, co]))

Error message.

Traceback (most recent call last):
  File "/home/jrshao/Projects/tvm-tensorir/python/tvm/meta_schedule/utils.py", line 434, in timed_func
    func = tvm_build(
  File "/home/jrshao/Projects/tvm-tensorir/python/tvm/driver/build_module.py", line 427, in build
    mod_host, mdev = _build_for_device(input_mod, tar, target_host)
  File "/home/jrshao/Projects/tvm-tensorir/python/tvm/driver/build_module.py", line 269, in _build_for_device
    mod_mixed = tvm.transform.Sequential(opt_mixed)(mod_mixed)
  File "/home/jrshao/Projects/tvm-tensorir/python/tvm/ir/transform.py", line 127, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 321, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 256, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 245, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 160, in tvm._ffi._cy3.core.CALL
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (8) /home/jrshao/Projects/tvm-tensorir/build/libtvm.so(tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)+0x167) [0x7f988f7105f7]
  [bt] (7) /home/jrshao/Projects/tvm-tensorir/build/libtvm.so(tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#11}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)+0x28) [0x7f988f705748]
  [bt] (6) /home/jrshao/Projects/tvm-tensorir/build/libtvm.so(tvm::tir::StmtMutator::VisitStmt_(tvm::tir::SeqStmtNode const*)+0x115) [0x7f988fcee175]
  [bt] (5) /home/jrshao/Projects/tvm-tensorir/build/libtvm.so(void tvm::runtime::Array<tvm::tir::Stmt, void>::MutateByApply<tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1}>(tvm::tir::StmtMutator::Internal::Mutate(tvm::tir::StmtMutator*, tvm::runtime::Array<tvm::tir::Stmt, void> const&)::{lambda(tvm::tir::Stmt const&)#1})+0xb6) [0x7f988fcf4386]
  [bt] (4) /home/jrshao/Projects/tvm-tensorir/build/libtvm.so(tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)+0x167) [0x7f988f7105f7]
  [bt] (3) /home/jrshao/Projects/tvm-tensorir/build/libtvm.so(tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#4}::_FUN(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)+0x28) [0x7f988f705518]
  [bt] (2) /home/jrshao/Projects/tvm-tensorir/build/libtvm.so(tvm::tir::VarUseDefAnalysis::VisitStmt_(tvm::tir::ForNode const*)+0x2e) [0x7f988fe858fe]
  [bt] (1) /home/jrshao/Projects/tvm-tensorir/build/libtvm.so(tvm::tir::VarUseDefAnalysis::HandleDef(tvm::tir::VarNode const*)+0xbd) [0x7f988fe8577d]
  [bt] (0) /home/jrshao/Projects/tvm-tensorir/build/libtvm.so(+0xde94b8) [0x7f988fe814b8]
  File "/home/jrshao/Projects/tvm-tensorir/src/tir/transforms/split_host_device.cc", line 152
TVMError: Check failed: !use_count_.count(v): variable i4_outer has been used before definition!
spectrometerHBH commented 3 years ago

Note that if we use a pure TE schedule to reproduce this TIR schedule, TE hits the same bug. Currently, the way to handle this is to reject compute_at locations that would generate irregular loops.
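For context, the irregular loop is visible in the TIR above: the extent of ax1 in the PadInput nest, (tir.floordiv((tir.floormod(i4_outer, 2) + 1), 2) + 1), depends on the outer reduction loop variable i4_outer, the same variable the check complains about. Below is a minimal sketch of such a rejection check; it is not the project's actual implementation, assumes mainline TVM's Python tir API rather than this repo's internals, and treats any loop whose extent references another loop's variable as irregular.

from tvm import tir
from tvm.tir.stmt_functor import post_order_visit

def find_irregular_loops(func):
    """Return the loops in `func` whose extent references a loop variable."""
    loops = []
    loop_vars = set()

    def collect(node):
        # Record every loop and its loop variable in the function body.
        if isinstance(node, tir.For):
            loops.append(node)
            loop_vars.add(node.loop_var)

    post_order_visit(func.body, collect)

    irregular = []
    for loop in loops:
        used = set()
        # Collect the variables appearing in this loop's extent expression.
        post_order_visit(
            loop.extent,
            lambda n: used.add(n) if isinstance(n, tir.Var) else None,
        )
        if used & loop_vars:
            irregular.append(loop)
    return irregular

A schedule postprocessor could then reject a sampled compute location whenever find_irregular_loops(sch.mod["main"]) is non-empty after applying compute_at.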