Open nautasolva opened 6 months ago
I think you are supposed to call `decompose_reduction` before `blockize`:
```python
import tvm
from tvm.script import ir as I
from tvm.script import tir as T


@I.ir_module
class Module:
    @T.prim_func
    def main():
        # with T.block("root"):
        A_sum = T.alloc_buffer((1,), "float32")
        A = T.alloc_buffer((1, 16), "float32")
        for nn, ff in T.grid(1, 16):
            with T.block("A"):
                v_nn, v_ff = T.axis.remap("SR", [nn, ff])
                T.reads(A[v_nn, v_ff])
                T.writes(A_sum[v_nn])
                with T.init():
                    A_sum[v_nn] = T.float32(0)
                A_sum[v_nn] = A_sum[v_nn] + A[v_nn, v_ff]


sch = tvm.tir.Schedule(Module)
a = sch.get_block("A")
loop_n, loop_f = sch.get_loops(a)
sch.decompose_reduction("A", loop_n)
sch.blockize(loop_f)
init_block = sch.get_block("A_init")
print(sch.mod)  # <-- A_init exists
```
For my usage scenario I need to keep the `T.init()` statement, so `decompose_reduction` is not an option. Also, the fact that the `A_init` block is present in the associated module but not discoverable through the schedule accessors clearly indicates a bug, IMO.
When used on a block with an init statement, `blockize` creates a separate init block that cannot be discovered by any schedule accessor. This hinders further scheduling, such as tensorizing the init block.
Expected behavior
When `blockize` is used on a loop that contains an init statement, the init is moved to a new block `<block>_init`, which should be discoverable with `get_block` or with `get_child_blocks` on the newly created outer block.
Actual behavior
The init block exists in the TIR module but does not appear to be registered by the schedule. `get_block("<block>_init")` fails with `InternalError: Check failed: (it != self_->stmt2ref.end()) is false`.
Stacktrace
```
Traceback (most recent call last):
  File "/home/dev/tvm_upstream/../tvm/playground/blockize_init_bug.py", line 31, in
    a_init = sch.get_block("A_init")
  File "/home/dev/tvm_upstream/python/tvm/tir/schedule/_type_checker.py", line 340, in wrap
    return func(*args, **kwargs)
  File "/home/dev/tvm_upstream/python/tvm/tir/schedule/schedule.py", line 499, in get_block
    return _ffi_api.ScheduleGetBlock(  # type: ignore # pylint: disable=no-member
  File "/home/dev/tvm_upstream/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/home/dev/tvm_upstream/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "/home/dev/tvm_upstream/src/tir/schedule/traced_schedule.cc", line 128, in tvm::tir::TracedScheduleNode::GetBlock(tvm::runtime::String const&, tvm::runtime::Optional const&)
    BlockRV result = ConcreteScheduleNode::GetBlock(name, func_name);
  File "/home/dev/tvm_upstream/src/tir/schedule/concrete_schedule.cc", line 321, in tvm::tir::ConcreteScheduleNode::GetBlock(tvm::runtime::String const&, tvm::runtime::Optional const&)
    Array blocks = tir::GetBlocks(this->state_, name, gv);
  File "/home/dev/tvm_upstream/src/tir/schedule/primitive/get_block_loop.cc", line 46, in tvm::tir::GetBlocks(tvm::tir::ScheduleState const&, tvm::runtime::String const&, tvm::GlobalVar const&)
    finder(prim_func->body);
  File "/home/dev/tvm_upstream/src/tir/ir/stmt_functor.cc", line 142, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::BlockNode const*)
    this->VisitStmt(op->init.value());
  File "/home/dev/tvm_upstream/src/tir/schedule/primitive/get_block_loop.cc", line 29, in VisitStmt_
    void VisitStmt_(const BlockNode* block) override {
  File "/home/dev/tvm_upstream/src/tir/schedule/primitive/get_block_loop.cc", line 32, in VisitStmt_
    ICHECK(it != self_->stmt2ref.end());
tvm.error.InternalError: Traceback (most recent call last):
  5: tvm::tir::TracedScheduleNode::GetBlock(tvm::runtime::String const&, tvm::runtime::Optional const&)
        at /home/dev/tvm_upstream/src/tir/schedule/traced_schedule.cc:128
  4: tvm::tir::ConcreteScheduleNode::GetBlock(tvm::runtime::String const&, tvm::runtime::Optional const&)
        at /home/dev/tvm_upstream/src/tir/schedule/concrete_schedule.cc:321
  3: tvm::tir::GetBlocks(tvm::tir::ScheduleState const&, tvm::runtime::String const&, tvm::GlobalVar const&)
        at /home/dev/tvm_upstream/src/tir/schedule/primitive/get_block_loop.cc:46
  2: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::BlockNode const*)
        at /home/dev/tvm_upstream/src/tir/ir/stmt_functor.cc:142
  1: VisitStmt_
        at /home/dev/tvm_upstream/src/tir/schedule/primitive/get_block_loop.cc:29
  0: VisitStmt_
        at /home/dev/tvm_upstream/src/tir/schedule/primitive/get_block_loop.cc:32
  File "/home/dev/tvm_upstream/src/tir/schedule/primitive/get_block_loop.cc", line 32
```
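The trace suggests the mismatch: the `GetBlocks` finder descends into a block's `init` field (`stmt_functor.cc:142`), finds a block named `A_init`, and then fails the `stmt2ref` lookup because no sref was ever created for it. The toy model below is a pure-Python sketch of that mismatch, not TVM code. It assumes that sref registration skips blocks nested inside `init`, which is inferred from the trace rather than verified in the sources. All names (`Block`, `build_stmt2ref`, `get_blocks`, and the outer block name `A_o`) are hypothetical.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass(eq=False)  # identity-based hashing, like pointer keys in stmt2ref
class Block:
    name: str
    body: List["Block"] = field(default_factory=list)
    init: Optional["Block"] = None


def build_stmt2ref(root: Block, visit_init: bool) -> Dict[Block, str]:
    """Register a fake 'sref' for every block reachable from root.

    With visit_init=False, blocks nested in an init field are skipped,
    mimicking the assumed gap that leaves A_init without an sref.
    """
    refs: Dict[Block, str] = {}

    def walk(block: Block) -> None:
        refs[block] = f"sref({block.name})"
        if visit_init and block.init is not None:
            walk(block.init)
        for child in block.body:
            walk(child)

    walk(root)
    return refs


def get_blocks(root: Block, name: str, stmt2ref: Dict[Block, str]) -> List[str]:
    """Mimic the finder: it visits init blocks as well as body blocks."""
    results: List[str] = []

    def visit(block: Block) -> None:
        if block.name == name:
            # Toy counterpart of ICHECK(it != self_->stmt2ref.end())
            if block not in stmt2ref:
                raise RuntimeError(f"Check failed: no sref for block {name!r}")
            results.append(stmt2ref[block])
        if block.init is not None:
            visit(block.init)
        for child in block.body:
            visit(child)

    visit(root)
    return results


if __name__ == "__main__":
    # Hypothetical shape after blockize: outer block whose init holds A_init
    root = Block("root", body=[Block("A_o", body=[Block("A")], init=Block("A_init"))])

    try:
        get_blocks(root, "A_init", build_stmt2ref(root, visit_init=False))
    except RuntimeError as err:
        print(err)  # Check failed: no sref for block 'A_init'

    print(get_blocks(root, "A_init", build_stmt2ref(root, visit_init=True)))  # ['sref(A_init)']
```

In this model, the lookup succeeds once registration and lookup traverse the same fields. That matches the expected behavior described above, but it says nothing about how an actual fix in TVM should look.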
Environment
Reproducible on main (d4056ca79571d4265a12beeedd1b1565953df936)
Steps to reproduce
Triage