apache/tvm

[Bug] [Relax] InternalError: Check failed: (it_group != obj2group_.end()) is false #17210

Open · Cookiee235 opened this issue 2 months ago

Cookiee235 commented 2 months ago

Actual behavior

Traceback (most recent call last):
  File "/share_container/optfuzz/res/bugs/3_simple.py", line 36, in <module>
    mod = relax.transform.MergeCompositeFunctions()(mod)  # crash here
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/_ffi/_ctypes/packed_func.py", line 240, in __call__
    raise_last_ffi_error()
  File "/software/tvm-lunder/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm.error.InternalError: Traceback (most recent call last):
  17: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  16: tvm::transform::Pass::operator()(tvm::IRModule) const
  15: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  14: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  13: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relax::transform::MergeCompositeFunctions()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relax::transform::MergeCompositeFunctions()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  12: tvm::relax::MergeCompositeFunctions(tvm::IRModule)
  11: tvm::relax::MakeGroupedFunctions(tvm::IRModule, std::unordered_map<tvm::runtime::Object const*, tvm::relay::GraphPartitioner::Group*, std::hash<tvm::runtime::Object const*>, std::equal_to<tvm::runtime::Object const*>, std::allocator<std::pair<tvm::runtime::Object const* const, tvm::relay::GraphPartitioner::Group*> > > const&, bool, tvm::runtime::Array<tvm::runtime::String, void> const&)
  10: tvm::relax::OperatorFusor::Transform(tvm::runtime::Array<tvm::runtime::String, void> const&)
  9: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  8: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  7: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::FunctionNode const*)
  6: tvm::relax::ExprMutator::VisitWithNewScope(tvm::RelayExpr const&, tvm::runtime::Optional<tvm::runtime::Array<tvm::relax::Var, void> >)
  5: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  4: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  3: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::SeqExprNode const*)
  2: tvm::relax::ExprMutator::VisitBindingBlock(tvm::relax::BindingBlock const&)
  1: tvm::relax::OperatorFusor::VisitBindingBlock_(tvm::relax::DataflowBlockNode const*)
  0: tvm::relax::OperatorFusor::CollectFuncBindings(tvm::runtime::Array<tvm::relax::Binding, void> const&)
  File "/software/tvm-lunder/src/relax/transform/fuse_ops.cc", line 941
InternalError: Check failed: (it_group != obj2group_.end()) is false: Variable gv could not be found in any group

Environment

Any environment details, such as operating system, TVM version, etc.

Steps to reproduce

from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function(private=True)
    def fused_relax_nn_conv2d_relax_nn_relu(data1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight11: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
        R.func_attr({"Composite": "dnnl.conv2d_relu", "Primitive": 1})
        cls = Module
        with R.dataflow():
            lv1 = R.nn.conv2d(data1, weight11)
            gv1 = R.nn.relu(lv1)
            R.output(gv1)
        return gv1

    @R.function(private=False)
    def main2(data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), weight2: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
        cls = Module
        with R.dataflow():
            gv: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d_relax_nn_relu(data, weight1)
            R.output(gv)
        return gv

    @R.function
    def main(data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), weight2: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((1, 64, 54, 54), dtype="float32"):
        cls = Module
        with R.dataflow():
            gv22: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d_relax_nn_relu(data, weight1)
            R.output(gv22)
        return gv22

mod = Module
mod.show()
mod = relax.transform.MergeCompositeFunctions()(mod)  # crash here

cc @Lunderberg @junrushao @tqchen

Cookiee235 commented 2 months ago

@Lunderberg Could you help me analyze if this crash is a TVM bug? Thank you very much!

Lunderberg commented 2 months ago

Looks like it is a bug in TVM. The CompositeGroupBuilder is only run on the "main" function, but the later MakeGroupedFunctions step runs on all Relax functions in the module. As a result, grouping information for the variables in the "main2" function is missing, which triggers the check failure.

I don't have a fix at the moment, but the fix should be to have both the CompositeGroupBuilder and MakeGroupedFunctions steps operate on exactly the same list of functions. While the crash itself could be avoided by passing {"main"} to MakeGroupedFunctions, that still wouldn't produce the correct behavior for modules that contain multiple Relax functions (or whose main function is named something other than "main"). Instead, we should collect all Relax functions in the module that have neither the kComposite nor the kCodegen attribute, and then run both CompositeGroupBuilder and MakeGroupedFunctions across that list of functions.
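For reference, here is a minimal Python-side sketch of that filtering, assuming the "Composite" and "Codegen" function attributes used by the BYOC flow; it is only an illustration, not the actual C++ fix inside MergeCompositeFunctions.

# Hypothetical helper, not part of TVM's API: list the Relax functions that
# the pass should consider, i.e. those without "Composite"/"Codegen" attrs.
from tvm import relax

def collect_partitionable_functions(mod):
    targets = []
    for gvar, func in mod.functions.items():
        if not isinstance(func, relax.Function):
            continue  # skip PrimFuncs and other non-Relax functions
        attrs = func.attrs
        if attrs is not None and ("Composite" in attrs or "Codegen" in attrs):
            continue  # composite/offloaded functions are grouped, not partitioned
        targets.append(gvar)
    return targets

On the reproducer above, this should return the GlobalVars for main and main2 while excluding fused_relax_nn_conv2d_relax_nn_relu, which carries the "Composite" attribute.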

Cookiee235 commented 2 months ago

@Lunderberg Thank you for your deep analysis!

Lunderberg commented 2 months ago

@Cookiee235 Can you test with https://github.com/apache/tvm/pull/17212 applied, to verify that it has resolved the issue?

Cookiee235 commented 2 months ago

@Lunderberg Thank you for your quick fix. A new crash appears with the patch (https://github.com/apache/tvm/pull/17212) applied.

Traceback (most recent call last):
  File "//test.py", line 35, in <module>
    mod = relax.transform.MergeCompositeFunctions()(mod)  # crash here
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/_ffi/_ctypes/packed_func.py", line 240, in __call__
    raise_last_ffi_error()
  File "/software/tvm-lunder/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm.error.InternalError: Traceback (most recent call last):
  5: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  4: tvm::transform::Pass::operator()(tvm::IRModule) const
  3: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  2: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  1: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relax9transform23MergeCompositeFunctionsEvEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SF_SJ_
  0: tvm::relax::MergeCompositeFunctions(tvm::IRModule)
  File "/software/tvm-lunder/src/relax/transform/merge_composite_functions.cc", line 405
InternalError: Check failed: (!group_map.count(obj)) is false:

Lunderberg commented 2 months ago

Ooh, that ended up being a really tricky one. My first implementation implicitly assumed that every group would contain distinct objects. This is usually true, because each relax::Var must be unique across an IRModule, and so the expressions containing a relax::Var must also be unique within each function. However, relax expressions that do not depend on any variables, such as static shapes, may re-use the same underlying C++ object across multiple functions.

In the test case, both main and main2 inferred the return type of fused_relax_nn_conv2d_relax_nn_relu, using the same ShapeExpr. This was then assigned to its own group during a PostOrderVisit, but that group was never actually used.

For now, I've added a fix to avoid over-collection of ShapeExpr, by removing the use of PostOrderVisit. The bug could still be triggered for shape expressions that occur explicitly within relax::Call arguments (e.g. for R.full arguments), and which are re-used across multiple functions in the same IRModule. That should be a rarer case, and since it will be a larger refactor to avoid those edge cases (to have caching on a per-Var basis, rather than a per-const Object* basis), I think I'll wait until that becomes an issue before tackling it.
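To make the failure mode concrete, the following is an illustrative Python sketch (not TVM's actual grouping code) of how a group map keyed on object identity, analogous to the per-const Object* keying in the pass, collides when the same ShapeExpr object is reachable from two functions.

# Illustration only: the ShapeExpr stands in for the shared struct-info shape
# that both main and main2 end up pointing at.
from tvm import relax

shared_shape = relax.ShapeExpr([1, 64, 56, 56])  # one object, referenced by both functions

group_map = {}
for fn_name in ["main", "main2"]:
    key = id(shared_shape)  # keyed on object identity, like const Object* in C++
    if key in group_map:
        # This is the situation the ICHECK(!group_map.count(obj)) guards against.
        print(f"{fn_name}: shape object was already grouped while visiting {group_map[key]}")
    else:
        group_map[key] = fn_name

Keying the cache on relax::Var instead, as suggested above, would sidestep this collision because variables are unique across the module.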