apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bug] [Relax] InternalError: Check failed: (*it).second == var #17200

Closed. MellowArtisan closed this issue 4 days ago.

MellowArtisan commented 1 month ago

Actual behavior

Traceback (most recent call last):
  File "/share_container/optfuzz/res/bugs/18_check.py", line 35, in <module>
    mod = tvm.transform.Sequential([relax.transform.LiftTransformParams(), relax.transform.LiftTransformParams()])(mod)  # crash here
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm.error.InternalError: Traceback (most recent call last):
  11: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  10: tvm::transform::Pass::operator()(tvm::IRModule) const
  9: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  8: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  7: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  6: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  5: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  4: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relax::transform::PartitionTransformParams(tvm::runtime::Variant<tvm::Bool, tvm::runtime::Array<tvm::runtime::String, void> >)::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relax::transform::PartitionTransformParams(tvm::runtime::Variant<tvm::Bool, tvm::runtime::Array<tvm::runtime::String, void> >)::{lambda(tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  2: tvm::relax::transform::PartitionTransformParams(tvm::runtime::Variant<tvm::Bool, tvm::runtime::Array<tvm::runtime::String, void> >)::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}::operator()(tvm::IRModule, tvm::transform::PassContext) const [clone .isra.0]
  1: tvm::IRModuleNode::Add(tvm::GlobalVar const&, tvm::BaseFunc const&, bool)
  0: tvm::IRModuleNode::AddUnchecked(tvm::GlobalVar const&, tvm::BaseFunc const&)
  File "/software/tvm/src/ir/module.cc", line 233
InternalError: Check failed: (*it).second == var (I.GlobalVar("main_transform_params") vs. I.GlobalVar("main_transform_params")) :

Environment

Steps to reproduce

import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def matmul1(x: T.Buffer((T.int64(256), T.int64(256)), "float32"), w1_t: T.Buffer((T.int64(256), T.int64(256)), "float32"), matmul: T.Buffer((T.int64(256), T.int64(256)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for i0, i1, k in T.grid(T.int64(256), T.int64(256), T.int64(256)):
            with T.block("matmul"):
                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
                T.reads(x[v_i0, v_k], w1_t[v_k, v_i1])
                T.writes(matmul[v_i0, v_i1])
                with T.init():
                    matmul[v_i0, v_i1] = T.float32(0)
                matmul[v_i0, v_i1] = matmul[v_i0, v_i1] + x[v_i0, v_k] * w1_t[v_k, v_i1]

    @R.function
    def main(x: R.Tensor((256, 256), dtype="float32"), ) -> R.Tensor((256, 256), dtype="float32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            y1 = R.call_tir(cls.matmul1, (x, x), out_sinfo=R.Tensor((256, 256), dtype="float32"))
            R.output(y1)
        return y1

mod = Module
mod.show()
mod = tvm.relax.transform.LegalizeOps()(mod)
# mod = tvm.transform.Sequential([relax.transform.LiftTransformParams()])(mod)  # runs fine
mod = tvm.transform.Sequential([relax.transform.LiftTransformParams(), relax.transform.LiftTransformParams()])(mod)  # crash here

cc @junrushao

MellowArtisan commented 1 month ago

Hi all, I accidentally applied the LiftTransformParams pass twice, and the transformation failed unexpectedly. Is it allowed to include the same pass more than once in tvm.transform.Sequential()? @Lunderberg @tqchen @junrushao @vinx13

Lunderberg commented 1 month ago

It looks like each application of LiftTransformParams attempts to create a function named "main_transform_params". This causes an error on the second application, because function names within an IRModule must be unique.
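
To make this failure mode concrete, here is a minimal pure-Python sketch. It assumes, as the traceback suggests, that the module keeps a single name-to-GlobalVar map and rejects a second, distinct GlobalVar registered under an existing name. `GlobalVar`, `ToyIRModule`, and `toy_lift_transform_params` are hypothetical stand-ins, not TVM's API. Because the two handles are different objects that share a name, both sides of the error print identically, which is presumably why the log reads `(I.GlobalVar("main_transform_params") vs. I.GlobalVar("main_transform_params"))`.

```python
class GlobalVar:
    """Hypothetical stand-in for a function handle; compared by identity."""

    def __init__(self, name_hint: str):
        self.name_hint = name_hint

    def __repr__(self) -> str:
        return f'GlobalVar("{self.name_hint}")'


class ToyIRModule:
    """Hypothetical model: each function name maps to exactly one GlobalVar."""

    def __init__(self):
        self.global_var_map = {}  # name -> GlobalVar
        self.functions = {}       # GlobalVar -> function body

    def add(self, var: GlobalVar, func) -> None:
        existing = self.global_var_map.get(var.name_hint)
        # Mirrors the spirit of the failing `(*it).second == var` check:
        # an already-registered name must refer to the very same object.
        if existing is not None and existing is not var:
            raise RuntimeError(
                f"Check failed: existing == var ({existing!r} vs. {var!r})"
            )
        self.global_var_map[var.name_hint] = var
        self.functions[var] = func


def toy_lift_transform_params(mod: ToyIRModule) -> ToyIRModule:
    """Hypothetical non-idempotent pass: always creates a fresh GlobalVar."""
    mod.add(GlobalVar("main_transform_params"), lambda params: params)
    return mod


mod = ToyIRModule()
mod.add(GlobalVar("main"), lambda x: x)
toy_lift_transform_params(mod)      # first application succeeds
try:
    toy_lift_transform_params(mod)  # second application collides on the name
except RuntimeError as err:
    # prints: Check failed: existing == var (GlobalVar("main_transform_params") vs. GlobalVar("main_transform_params"))
    print(err)
```

Re-adding the same `GlobalVar` object is accepted in this model; only a new object under an old name is rejected, which is why a pass that unconditionally mints a fresh `main_transform_params` cannot run twice.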

I think it would be good to make LiftTransformParams idempotent, to prevent this type of error. This would require two main changes:

  1. When LiftTransformParams lifts out a new function, it would check whether a function with that same name already exists. If so, the two transformations would be composed together, producing a function that is equivalent to running new_params = new_transform(old_transform(params)). If LiftTransformParams is applied twice in a row, new_transform would be the identity function. This would also be useful when LiftTransformParams has been applied, but additional pre-processing steps have been generated by later optimizations.
  2. LiftTransformParams would not remove R.builtin.stop_lift_params. It would instead be removed via the FLowerBuiltin attribute (once https://github.com/apache/tvm/pull/17145 lands). This would ensure that a second application of LiftTransformParams follows the same restrictions as the first.
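
The composition proposed in item 1 can be sketched in plain Python. This is a toy model under stated assumptions, not TVM's implementation: `ToyModule`, `toy_lift`, and `double` are hypothetical, a module is reduced to the preprocessing still embedded in `main` plus an optional lifted `main_transform_params`, and parameters are a dict of floats.

```python
from dataclasses import dataclass, replace
from typing import Callable, Dict, Optional

Params = Dict[str, float]
Transform = Callable[[Params], Params]


def identity(params: Params) -> Params:
    return dict(params)


@dataclass(frozen=True)
class ToyModule:
    # Preprocessing still embedded in "main" that a lifting pass could extract.
    pending: Transform = identity
    # The lifted "main_transform_params" function, if one exists yet.
    transform_params: Optional[Transform] = None


def toy_lift(mod: ToyModule) -> ToyModule:
    """Hypothetical idempotent lifting pass (not TVM's implementation)."""
    new_transform = mod.pending
    old_transform = mod.transform_params
    if old_transform is None:
        lifted = new_transform
    else:
        # Compose instead of creating a second "main_transform_params":
        # new_params = new_transform(old_transform(params))
        def lifted(params: Params) -> Params:
            return new_transform(old_transform(params))
    # Everything liftable has been extracted, so "main" keeps only the identity.
    return replace(mod, pending=identity, transform_params=lifted)


def double(params: Params) -> Params:
    return {name: 2 * value for name, value in params.items()}


params = {"w1": 3.0}
once = toy_lift(ToyModule(pending=double))
twice = toy_lift(once)  # second application lifts only the identity
print(once.transform_params(params), twice.transform_params(params))
# prints {'w1': 6.0} {'w1': 6.0}
```

Because the second application only lifts the identity, `twice` produces the same parameters as `once`. Setting `pending` to a new transform before re-running `toy_lift` models the case where a later optimization generated additional preprocessing; the lifted function is then the composition of both steps.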

MellowArtisan commented 1 month ago

@Lunderberg Thank you very much for your detailed explanation and proposed fix!

Lunderberg commented 1 week ago

Circling back to this issue: it should be resolved by PR #17314.