apache / tvm

Open deep learning compiler stack for CPU, GPU and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bug] Init block not discoverable after sch.blockize #16889

Open nautasolva opened 6 months ago

nautasolva commented 6 months ago

When used on a block with an init statement, blockize creates a separate init block that is not discoverable by any means. This hinders further scheduling, such as tensorizing the init block.

Expected behavior

When using blockize on a loop that contains a block with an init statement, the init is moved to a new block <block>_init, which should be discoverable with get_block or get_child_blocks on the newly created outer block.

Actual behavior

The init block exists in the TIR module but does not seem to be registered with the schedule. get_block("<block>_init") fails with InternalError: Check failed: (it != self_->stmt2ref.end()) is false

Stacktrace

Traceback (most recent call last):
  File "/home/dev/tvm_upstream/../tvm/playground/blockize_init_bug.py", line 31, in <module>
    a_init = sch.get_block("A_init")
  File "/home/dev/tvm_upstream/python/tvm/tir/schedule/_type_checker.py", line 340, in wrap
    return func(*args, **kwargs)
  File "/home/dev/tvm_upstream/python/tvm/tir/schedule/schedule.py", line 499, in get_block
    return _ffi_api.ScheduleGetBlock(  # type: ignore # pylint: disable=no-member
  File "/home/dev/tvm_upstream/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/home/dev/tvm_upstream/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "/home/dev/tvm_upstream/src/tir/schedule/traced_schedule.cc", line 128, in tvm::tir::TracedScheduleNode::GetBlock(tvm::runtime::String const&, tvm::runtime::Optional const&)
    BlockRV result = ConcreteScheduleNode::GetBlock(name, func_name);
  File "/home/dev/tvm_upstream/src/tir/schedule/concrete_schedule.cc", line 321, in tvm::tir::ConcreteScheduleNode::GetBlock(tvm::runtime::String const&, tvm::runtime::Optional const&)
    Array blocks = tir::GetBlocks(this->state_, name, gv);
  File "/home/dev/tvm_upstream/src/tir/schedule/primitive/get_block_loop.cc", line 46, in tvm::tir::GetBlocks(tvm::tir::ScheduleState const&, tvm::runtime::String const&, tvm::GlobalVar const&)
    finder(prim_func->body);
  File "/home/dev/tvm_upstream/src/tir/ir/stmt_functor.cc", line 142, in tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::BlockNode const*)
    this->VisitStmt(op->init.value());
  File "/home/dev/tvm_upstream/src/tir/schedule/primitive/get_block_loop.cc", line 29, in VisitStmt_
    void VisitStmt_(const BlockNode* block) override {
  File "/home/dev/tvm_upstream/src/tir/schedule/primitive/get_block_loop.cc", line 32, in VisitStmt_
    ICHECK(it != self_->stmt2ref.end());
tvm.error.InternalError: Traceback (most recent call last):
  5: tvm::tir::TracedScheduleNode::GetBlock(tvm::runtime::String const&, tvm::runtime::Optional const&)
        at /home/dev/tvm_upstream/src/tir/schedule/traced_schedule.cc:128
  4: tvm::tir::ConcreteScheduleNode::GetBlock(tvm::runtime::String const&, tvm::runtime::Optional const&)
        at /home/dev/tvm_upstream/src/tir/schedule/concrete_schedule.cc:321
  3: tvm::tir::GetBlocks(tvm::tir::ScheduleState const&, tvm::runtime::String const&, tvm::GlobalVar const&)
        at /home/dev/tvm_upstream/src/tir/schedule/primitive/get_block_loop.cc:46
  2: tvm::tir::StmtVisitor::VisitStmt_(tvm::tir::BlockNode const*)
        at /home/dev/tvm_upstream/src/tir/ir/stmt_functor.cc:142
  1: VisitStmt_
        at /home/dev/tvm_upstream/src/tir/schedule/primitive/get_block_loop.cc:29
  0: VisitStmt_
        at /home/dev/tvm_upstream/src/tir/schedule/primitive/get_block_loop.cc:32
  File "/home/dev/tvm_upstream/src/tir/schedule/primitive/get_block_loop.cc", line 32
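The failing check can be illustrated with a minimal, self-contained Python sketch of the visitor's invariant (hypothetical names, not TVM's actual data structures): the schedule keeps a stmt2ref map from every block node to its schedule ref, and the block finder asserts membership for every block it visits, including nested init blocks. If blockize registers the outer block but not the init block it created, the assertion fires exactly as in the stacktrace above.

    # Sketch only: mimics the stmt2ref membership check from
    # get_block_loop.cc with plain Python objects (names hypothetical).

    class Block:
        def __init__(self, name, init=None):
            self.name = name
            self.init = init  # optional nested init block

    def find_blocks(root, name, stmt2ref):
        """Collect blocks matching `name`, asserting every visited
        block is registered in stmt2ref (the ICHECK in the trace)."""
        found = []

        def visit(block):
            assert block in stmt2ref, "Check failed: block missing from stmt2ref"
            if block.name == name:
                found.append(block)
            if block.init is not None:
                visit(block.init)

        visit(root)
        return found

    init = Block("A_init")
    outer = Block("A_o", init=init)

    # After blockize, only the outer block was registered:
    stmt2ref = {outer: "ref0"}
    try:
        find_blocks(outer, "A_init", stmt2ref)
    except AssertionError as e:
        print("InternalError:", e)

Registering the init block in the map (the presumed fix) would make the same lookup succeed.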

Environment

Reproducible on main (d4056ca79571d4265a12beeedd1b1565953df936)

Steps to reproduce

import tvm

from tvm.script import ir as I
from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main():
        # with T.block("root"):
        A_sum = T.alloc_buffer((1,), "float32")
        A = T.alloc_buffer((1, 16), "float32")
        for nn, ff in T.grid(1, 16):
            with T.block("A"):
                v_nn, v_ff = T.axis.remap("SR", [nn, ff])
                T.reads(A[v_nn, v_ff])
                T.writes(A_sum[v_nn])
                with T.init():
                    A_sum[v_nn] = T.float32(0)
                A_sum[v_nn] = A_sum[v_nn] + A[v_nn, v_ff]

sch = tvm.tir.Schedule(Module)

a = sch.get_block("A")

loop_n, loop_f = sch.get_loops(a)
sch.blockize(loop_f)

print(sch.mod) # <-- A_init exists

a_init = sch.get_block("A_init") # <-- fails with InternalError: Check failed: (it != self_->stmt2ref.end()) is false

Triage

patschmidt2 commented 6 months ago

I think you are supposed to call decompose_reduction before blockize:

import tvm

from tvm.script import ir as I
from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main():
        # with T.block("root"):
        A_sum = T.alloc_buffer((1,), "float32")
        A = T.alloc_buffer((1, 16), "float32")
        for nn, ff in T.grid(1, 16):
            with T.block("A"):
                v_nn, v_ff = T.axis.remap("SR", [nn, ff])
                T.reads(A[v_nn, v_ff])
                T.writes(A_sum[v_nn])
                with T.init():
                    A_sum[v_nn] = T.float32(0)
                A_sum[v_nn] = A_sum[v_nn] + A[v_nn, v_ff]

sch = tvm.tir.Schedule(Module)

a = sch.get_block("A")

loop_n, loop_f = sch.get_loops(a)

sch.decompose_reduction("A", loop_n)
sch.blockize(loop_f)

init_block = sch.get_block("A_init")

print(sch.mod) # <-- A_init exists

nautasolva commented 6 months ago

For my use case I need to keep the T.init() statement, so decompose_reduction is not an option. Also, the fact that the A_init block is present in the associated module but not discoverable through schedule accessors clearly indicates a bug, IMO.