apache / tvm

Open deep learning compiler stack for CPU, GPU, and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bug] InternalError: non-normalized expression R.memory.kill_tensor(metadata["relax.expr.Constant"][0]) #17340

Open Cookiee235 opened 1 week ago

Cookiee235 commented 1 week ago

Actual behavior

```
Traceback (most recent call last):
  File "/share_container/optfuzz/res/bugs/reduced/complete/328_test.py", line 162, in <module>
    ex = relax.build(mod, target='llvm')
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/relax/vm_build.py", line 335, in build
    mod = pipeline(mod)
          ^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/ir/transform.py", line 270, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "/software/tvm/python/tvm/relax/pipeline.py", line 101, in _pipeline
    mod = seq(mod)
          ^^^^^^^^
  File "/software/tvm/python/tvm/ir/transform.py", line 270, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm.error.InternalError: Traceback (most recent call last):
  26: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  25: tvm::transform::Pass::operator()(tvm::IRModule) const
  24: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  23: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  22: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  21: tvm::relax::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  20: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_1
  19: tvm::runtime::TypedPackedFunc<tvm::relax::Function (tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relax::transform::KillAfterLastUse()::{lambda(tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relax::transform::KillAfterLastUse()::{lambda(tvm::relax::Function, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
  18: tvm::relax::KillAfterLastUse(tvm::RelayExpr)
  17: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  16: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  15: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
  14: tvm::relax::KillInserter::VisitExpr_(tvm::relax::FunctionNode const*)
  13: tvm::relax::CollectLastUsage::Collect(tvm::RelayExpr const&)
  12: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
  11: _ZZN3tvm5relax11ExprFunctorIFvRKNS_9RelayExp
  10: tvm::relax::ExprVisitor::VisitExpr_(tvm::relax::FunctionNode const*)
  9: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
  8: _ZZN3tvm5relax11ExprFunctorIFvRKNS_9RelayExp
  7: tvm::relax::ExprVisitor::VisitExpr_(tvm::relax::SeqExprNode const*)
  6: tvm::relax::ExprVisitor::VisitBindingBlock(tvm::relax::BindingBlock const&)
  5: tvm::relax::ExprVisitor::VisitBindingBlock_(tvm::relax::BindingBlockNode const*)
  4: tvm::relax::CollectLastUsage::VisitBinding(tvm::relax::Binding const&)
  3: tvm::relax::ExprVisitor::VisitBinding(tvm::relax::Binding const&)
  2: tvm::relax::ExprVisitor::VisitBinding_(tvm::relax::VarBindingNode const*)
  1: _ZZN3tvm5relax11ExprVisitor22InitVisitBindingVTabl
  0: tvm::relax::CollectLastUsage::VisitBinding_(tvm::relax::VarBindingNode const*, tvm::relax::CallNode const*)
  File "/software/tvm/src/relax/transform/kill_after_last_use.cc", line 171
InternalError: Check failed: (killed_object) is false: Internal error: non-normalized expression R.memory.kill_tensor(metadata["relax.expr.Constant"][0])
```

Steps to reproduce

Reproducible script:

```python
import tvm
from tvm import relax

metadata = tvm.ir.load_json("""{ \"root\": 1, \"nodes\": [ { \"type_key\": \"\" }, { \"type_key\": \"Map\", \"keys\": [ \"relax.expr.Constant\" ], \"data\": [2] }, { \"type_key\": \"Array\", \"data\": [3] }, { \"type_key\": \"relax.expr.Constant\", \"attrs\": { \"_checked_type_\": \"11\", \"data\": \"0\", \"span\": \"0\", \"struct_info_\": \"4\" } }, { \"type_key\": \"relax.TensorStructInfo\", \"attrs\": { \"dtype\": \"float32\", \"ndim\": \"2\", \"shape\": \"5\", \"span\": \"0\", \"vdevice\": \"0\" } }, { \"type_key\": \"relax.expr.ShapeExpr\", \"attrs\": { \"_checked_type_\": \"10\", \"span\": \"0\", \"struct_info_\": \"9\", \"values\": \"6\" } }, { \"type_key\": \"Array\", \"data\": [7, 8] }, { \"type_key\": \"IntImm\", \"attrs\": { \"dtype\": \"int64\", \"span\": \"0\", \"value\": \"16\" } }, { \"type_key\": \"IntImm\", \"attrs\": { \"dtype\": \"int64\", \"span\": \"0\", \"value\": \"16\" } }, { \"type_key\": \"relax.ShapeStructInfo\", \"attrs\": { \"ndim\": \"2\", \"span\": \"0\", \"values\": \"6\" } }, { \"type_key\": \"relax.ShapeType\", \"attrs\": { \"ndim\": \"2\", \"span\": \"0\" } }, { \"type_key\": \"relax.DynTensorType\", \"attrs\": { \"dtype\": \"float32\", \"ndim\": \"2\", \"span\": \"0\" } } ], \"b64ndarrays\": [ \"P6G0lvBAXt0AAAAAAAAAAAEAAAAAAAAAAgAAAAIgAQAQAAAAAAAAABAAAAAAAAAAAAQAAAAAAAAAAAAAAACAPwAAAEAAAEBAAACAQAAAoEAAAMBAAADgQAAAAEEAABBBAAAgQQAAMEEAAEBBAABQQQAAYEEAAHBBAACAQQAAiEEAAJBBAACYQQAAoEEAAKhBAACwQQAAuEEAAMBBAADIQQAA0EEAANhBAADgQQAA6EEAAPBBAAD4QQAAAEIAAARCAAAIQgAADEIAABBCAAAUQgAAGEIAABxCAAAgQgAAJEIAAChCAAAsQgAAMEIAADRCAAA4QgAAPEIAAEBCAABEQgAASEIAAExCAABQQgAAVEIAAFhCAABcQgAAYEIAAGRCAABoQgAAbEIAAHBCAAB0QgAAeEIAAHxCAACAQgAAgkIAAIRCAACGQgAAiEIAAIpCAACMQgAAjkIAAJBCAACSQgAAlEIAAJZCAACYQgAAmkIAAJxCAACeQgAAoEIAAKJCAACkQgAApkIAAKhCAACqQgAArEIAAK5CAACwQgAAskIAALRCAAC2QgAAuEIAALpCAAC8QgAAvkIAAMBCAADCQgAAxEIAAMZCAADIQgAAykIAAMxCAADOQgAA0EIAANJCAADUQgAA1kIAANhCAADaQgAA3EIAAN5CAADgQgAA4kIAAORCAADmQgAA6EIAAOpCAADsQgAA7kIAAPBCAADyQgAA9EIAAPZCAAD4QgAA+kIAAPxCAAD+QgAAAEMAAAFDAAACQwAAA0MAAARDAAAFQwAABkMAAAdDAAAIQwAACUMAAApDAAALQwAADEMAAA1DAAAOQwAAD0MAABBDAAARQwAAEkMAABNDAAAUQwAAFUMAABZDAAAXQwAAGEMAABlDAAAaQwAAG0MAABxDAAAdQwAAHkMAAB9DAAAgQwAAIUMAACJDAAAjQwAAJEMAACVDAAAmQwAAJ0MAAChDAAApQwAAKkMAACtDAAAsQwAALUMAAC5DAAAvQwAAMEMAADFDAAAyQwAAM0MAADRDAAA1QwAANkMAADdDAAA4QwAAOUMAADpDAAA7QwAAPEMAAD1DAAA+QwAAP0MAAEBDAABBQwAAQkMAAENDAABEQwAARUMAAEZDAABHQwAASEMAAElDAABKQwAAS0MAAExDAABNQwAATkMAAE9DAABQQwAAUUMAAFJDAABTQwAAVEMAAFVDAABWQwAAV0MAAFhDAABZQwAAWkMAAFtDAABcQwAAXUMAAF5DAABfQwAAYEMAAGFDAABiQwAAY0MAAGRDAABlQwAAZkMAAGdDAABoQwAAaUMAAGpDAABrQwAAbEMAAG1DAABuQwAAb0MAAHBDAABxQwAAckMAAHNDAAB0QwAAdUMAAHZDAAB3QwAAeEMAAHlDAAB6QwAAe0MAAHxDAAB9QwAAfkMAAH9D\" ], \"attrs\": {\"tvm_version\": \"0.17.dev0\"} }""")

from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R


@I.ir_module
class Module:
    @T.prim_func
    def add(rxplaceholder: T.Buffer((T.int64(8),), "float32"), rxplaceholder_1: T.Buffer((), "float32"), T_add: T.Buffer((T.int64(8),), "float32")):
        T.evaluate(0)

    @T.prim_func(private=True)
    def add_2(A: T.Buffer((T.int64(16), T.int64(16)), "float32"), B: T.Buffer((T.int64(16), T.int64(16)), "float32"), T_add: T.Buffer((T.int64(16), T.int64(16)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(16), T.int64(16)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(A[v_ax0, v_ax1], B[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[v_ax0, v_ax1]

    @T.prim_func(private=True)
    def cast1(gv: T.Buffer((T.int64(16), T.int64(16)), "float32"), compute: T.Buffer((T.int64(16), T.int64(16)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1 in T.grid(T.int64(16), T.int64(16)):
            with T.block("compute"):
                v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                T.reads(gv[v_i0, v_i1])
                T.writes(compute[v_i0, v_i1])
                compute[v_i0, v_i1] = T.Cast("float16", gv[v_i0, v_i1])

    @T.prim_func
    def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4)), "float32")):
        T.evaluate(0)

    @T.prim_func
    def log(rxplaceholder: T.Buffer((T.int64(10),), "float32"), compute: T.Buffer((T.int64(10),), "float32")):
        T.evaluate(0)

    @T.prim_func
    def pad(rxplaceholder: T.Buffer((T.int64(8),), "float32"), PadInput: T.Buffer((T.int64(10),), "float32")):
        T.evaluate(0)

    @T.prim_func
    def relu(rxplaceholder: T.Buffer((T.int64(8),), "float32"), compute: T.Buffer((T.int64(8),), "float32")):
        T.evaluate(0)

    @T.prim_func
    def reshape(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), T_reshape: T.Buffer((T.int64(8),), "float32")):
        T.evaluate(0)

    @R.function
    def main() -> R.Tensor((16, 16), dtype="float16"):
        cls = Module
        with R.dataflow():
            gv = R.call_tir(cls.add_2, (metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][0]), out_sinfo=R.Tensor((16, 16), dtype="float32"))
            gv_1 = R.call_tir(cls.cast1, (gv,), out_sinfo=R.Tensor((16, 16), dtype="float16"))
            R.output(gv_1)
        return gv_1


mod = Module
seq = tvm.transform.Sequential([relax.transform.KillAfterLastUse(), relax.transform.FoldConstant()])  # only this sequence can trigger the bug
mod = seq(mod)
ex = relax.build(mod, target='llvm')
```

CC @Lunderberg

Cookiee235 commented 1 week ago

This exception can only be triggered with the given pass sequence (i.e., [KillAfterLastUse(), FoldConstant()]). Is it legal to run the pass KillAfterLastUse before the pass FoldConstant?

@Lunderberg Can you help check if this uncovers a bug? Thank you!

Lunderberg commented 1 week ago

Yup, this is definitely a bug.

If I instrument the passes with tvm.relax.ir.instrument.WellFormedInstrument, the output of KillAfterLastUse is ill-formed, because the impure functions used to drop the object are inserted into a dataflow block. This doesn't come up as an issue in normal use, because KillAfterLastUse is applied after both ToNonDataflow and RemovePurityChecking. However, I'd still consider this a bug in KillAfterLastUse, because every pass that is given well-formed IR should produce well-formed IR. (This isn't a smoking gun for the root cause, but ill-formed IR can cause downstream passes to make incorrect assumptions, and I could easily see that being the case here.)
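
For reference, a minimal sketch of that check, assuming the `Module` defined in the reproduction script above and the default `WellFormedInstrument` constructor:

```python
import tvm
from tvm import relax
from tvm.relax.ir.instrument import WellFormedInstrument

# Run KillAfterLastUse under the well-formedness instrument, which
# flags a pass that turns well-formed IR into ill-formed IR.
with tvm.transform.PassContext(instruments=[WellFormedInstrument()]):
    killed = relax.transform.KillAfterLastUse()(Module)

# Or check the output directly; this is expected to print False here,
# since the impure R.memory.kill_tensor calls end up inside the
# dataflow block of `main`.
print(relax.analysis.well_formed(killed))
```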

I suspect that the fix for this will be to have KillAfterLastUse insert the R.memory.kill_tensor at the first legal location after the last usage, rather than at the first syntactically-allowed location after the last usage. Since the impure R.memory.kill_tensor call isn't allowed in a dataflow block, R.memory.kill_tensor can't be generated until after the dataflow block.
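
In the meantime, a sketch of the normal-use ordering described above, where ToNonDataflow and RemovePurityChecking run first so there is no dataflow block for the impure call to land in (again assuming `Module` from the reproduction script):

```python
import tvm
from tvm import relax

# Mirror the default build pipeline's ordering: eliminate dataflow
# blocks and purity checking before kill_tensor calls are inserted.
seq = tvm.transform.Sequential(
    [
        relax.transform.ToNonDataflow(),
        relax.transform.RemovePurityChecking(),
        relax.transform.KillAfterLastUse(),
    ]
)
mod = seq(Module)
```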