apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bug] Cannot add loops on top of the root block #17245

Open · Cookiee235 opened this issue 1 month ago

Cookiee235 commented 1 month ago

Actual behavior

Traceback (most recent call last):
  File "/share_container/optfuzz/res/bugs/simple/bug_add_loop.py", line 51, in <module>
    mod = tvm.tir.transform.DefaultGPUSchedule()(mod)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/_ffi/_ctypes/packed_func.py", line 240, in __call__
    raise_last_ffi_error()
  File "/software/tvm-lunder/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
ValueError: Traceback (most recent call last):
  8: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  7: tvm::transform::Pass::operator()(tvm::IRModule) const
  6: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  5: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  4: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_3tir9transform18DefaultGPUScheduleEvEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SF_SJ_
  3: tvm::tir::transform::ThreadBind(tvm::tir::Schedule, tvm::tir::BlockRV const&, long, long)
  2: tvm::tir::TracedScheduleNode::AddUnitLoop(tvm::tir::BlockRV const&)
  1: tvm::tir::ConcreteScheduleNode::AddUnitLoop(tvm::tir::BlockRV const&)
  0: tvm::tir::AddUnitLoop(tvm::tir::ScheduleState, tvm::tir::StmtSRef)
  File "/software/tvm-lunder/src/tir/schedule/primitive/loop_transformation.cc", line 1153
ValueError: Check failed: (sref->parent != nullptr) is false: Cannot add loops on top of the root block

Steps to reproduce

import tvm
from tvm import relax

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def min(v3_0: T.Buffer((T.int64(63), T.int64(1)), "float16"), v3_0_red: T.Buffer((T.int64(63),), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, k1 in T.grid(T.int64(63), T.int64(1)):
            with T.block("v3_0_red"):
                v_ax0, v_k1 = T.axis.remap("SR", [ax0, k1])
                T.reads(v3_0[v_ax0, v_k1])
                T.writes(v3_0_red[v_ax0])
                with T.init():
                    v3_0_red[v_ax0] = T.float16(65504)
                v3_0_red[v_ax0] = T.min(v3_0_red[v_ax0], v3_0[v_ax0, v_k1])

    @T.prim_func(private=True)
    def scatter_elements(var_x: T.handle, var_indices: T.handle, var_updates: T.handle, out_buf: T.Buffer((T.int64(4), T.int64(4)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        x = T.match_buffer(var_x, (T.int64(4), T.int64(4)), offset_factor=1)
        indices = T.match_buffer(var_indices, (T.int64(2), T.int64(2)), "int64", offset_factor=1)
        updates = T.match_buffer(var_updates, (T.int64(2), T.int64(2)), offset_factor=1)
        with T.block("scatter_elements_generic"):
            T.reads()
            T.writes()
            for i in T.parallel(T.int64(16)):
                out_buf[i // T.int64(4), i % T.int64(4)] = x[i // T.int64(4), i % T.int64(4)]
            for fused in T.parallel(T.int64(2)):
                for k in range(T.int64(2)):
                    out_buf[(fused * T.int64(4) + (indices[(fused * T.int64(2) + k) // T.int64(2), (fused * T.int64(2) + k) % T.int64(2)] + T.Cast("int64", indices[(fused * T.int64(2) + k) // T.int64(2), (fused * T.int64(2) + k) % T.int64(2)] < T.int64(0)) * T.int64(4))) // T.int64(4), (fused * T.int64(4) + (indices[(fused * T.int64(2) + k) // T.int64(2), (fused * T.int64(2) + k) % T.int64(2)] + T.Cast("int64", indices[(fused * T.int64(2) + k) // T.int64(2), (fused * T.int64(2) + k) % T.int64(2)] < T.int64(0)) * T.int64(4))) % T.int64(4)] = updates[(fused * T.int64(2) + k) // T.int64(2), (fused * T.int64(2) + k) % T.int64(2)]

    @R.function
    def main(v3_0: R.Tensor((63, 1), dtype="float16")) -> R.Tensor((4, 4), dtype="float32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.min, (v3_0,), out_sinfo=R.Tensor((63,), dtype="float16"))
            R.output(lv)
        return lv

mod = Module
#mod = tvm.relax.transform.DeadCodeElimination()(mod)
mod.show()
with tvm.target.Target("cuda"):
    # The crash occurs here: DefaultGPUSchedule fails with
    # "Cannot add loops on top of the root block".
    mod = tvm.tir.transform.DefaultGPUSchedule()(mod)
ex = relax.build(mod, target='cuda')

cc @Lunderberg @junrushao

Cookiee235 commented 1 month ago

If we keep the dead code (i.e., def scatter_elements()) together with the call mod = tvm.tir.transform.DefaultGPUSchedule()(mod) and then execute the Relax IR, the test case crashes unexpectedly. However, if we remove the dead function, the test case runs fine.
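
For reference, here is a minimal sketch of the workaround, reusing the Module and imports from the script above. It assumes tvm.relax.transform.DeadCodeElimination actually drops the unused private scatter_elements PrimFunc; if it does not, the function can simply be deleted from the module by hand.

mod = Module
# Drop the dead scatter_elements function before scheduling (assumption: the
# pass removes unused private functions in this TVM version).
mod = tvm.relax.transform.DeadCodeElimination()(mod)
with tvm.target.Target("cuda"):
    # Only cls.min remains, so DefaultGPUSchedule never sees scatter_elements
    # and the build goes through.
    mod = tvm.tir.transform.DefaultGPUSchedule()(mod)
ex = relax.build(mod, target="cuda")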

@Lunderberg Can you help me review this issue? Thanks!

Lunderberg commented 1 month ago

Hmm. The PrimFunc definition is a bit odd. The presence of the with T.block means that it is schedulable TIR, but there aren't any with T.block annotations inside the loops themselves. So the body looks like it has already gone through the ConvertBlocksToOpaque transform, but DefaultGPUSchedule requires the annotations from before that transform.

The reason it works when the dead function is removed is that DefaultGPUSchedule attempts to schedule all TIR functions, regardless of whether they are actually used.
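
To illustrate the difference, here is a minimal, unverified sketch with two hypothetical PrimFuncs (not taken from the module above). The first keeps a T.block inside the loop, which is the schedulable form DefaultGPUSchedule expects; the second mirrors the structure of scatter_elements, a single block wrapping plain loops, i.e. what a body looks like after ConvertBlocksToOpaque.

from tvm.script import tir as T

@T.prim_func(private=True)
def schedulable_copy(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
    # Block annotation inside the loop: ThreadBind can bind this loop to GPU threads.
    for i in range(16):
        with T.block("copy"):
            vi = T.axis.spatial(16, i)
            B[vi] = A[vi]

@T.prim_func(private=True)
def opaque_copy(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
    # Same shape as scatter_elements: one block around otherwise plain loops,
    # so the loops themselves are opaque to the scheduler.
    with T.block("opaque"):
        T.reads()
        T.writes()
        for i in range(16):
            B[i] = A[i]

Based on the traceback, the second shape is the one that ends up in AddUnitLoop on the root block and trips the check; this tiny example has not been verified to reproduce the exact crash and is only meant to show the structural difference.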