Cookiee235 opened 2 months ago
@Lunderberg Could you help me analyze if this crash is a TVM bug? Thank you very much!
Looks like it is a bug in TVM. The `CompositeGroupBuilder` is only called for the "main" function, but the later `MakeGroupedFunctions` is run on all Relax functions in the module. As a result, information about variables in the "main2" function is absent, which raises an error.

I don't have a fix at the moment, but the fix should be to have both the `CompositeGroupBuilder` and the `MakeGroupedFunctions` steps operate on the exact same list of functions. While the crash itself could be avoided by passing `{"main"}` to `MakeGroupedFunctions`, that still wouldn't have the correct behavior for modules that have multiple Relax functions (or have their main function named something other than "main"). Instead, we should collect all Relax functions in the module that have neither the `kComposite` nor the `kCodegen` attributes, and then run both `CompositeGroupBuilder` and `MakeGroupedFunctions` across that list of functions.
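The proposed collection step can be sketched as follows. This is a hypothetical Python illustration with toy data structures, not the actual C++ pass implementation; the attribute names and the `collect_target_functions` helper are stand-ins for illustration only.

```python
# Toy sketch: gather every function that has neither a "Composite" nor a
# "Codegen" attribute, so that both pipeline steps (the group builder and
# the grouped-function construction) iterate over the identical list.

def collect_target_functions(module):
    """Return names of functions lacking both attributes (hypothetical helper)."""
    return [
        name
        for name, func in module.items()
        if "Composite" not in func.get("attrs", {})
        and "Codegen" not in func.get("attrs", {})
    ]

# Toy module: previously only "main" was seen by the group builder,
# while the later step also visited "main2".
module = {
    "main":  {"attrs": {}},
    "main2": {"attrs": {}},
    "fused": {"attrs": {"Composite": "dnnl.conv2d_relu"}},
}

targets = collect_target_functions(module)
# Both steps now see the same list, so no variable information can be
# missing in the second step.
assert targets == ["main", "main2"]
```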
@Lunderberg Thank you for your deep analysis!
@Cookiee235 Can you test with https://github.com/apache/tvm/pull/17212 applied, to verify that it has resolved the issue?
@Lunderberg Thank you for your quick fix. A new crash appears under the patch (https://github.com/apache/tvm/pull/17212).
Traceback (most recent call last):
File "//test.py", line 35, in <module>
mod = relax.transform.MergeCompositeFunctions()(mod) # crash here
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm-lunder/python/tvm/ir/transform.py", line 238, in __call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm-lunder/python/tvm/_ffi/_ctypes/packed_func.py", line 240, in __call__
raise_last_ffi_error()
File "/software/tvm-lunder/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
tvm.error.InternalError: Traceback (most recent call last):
5: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
4: tvm::transform::Pass::operator()(tvm::IRModule) const
3: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
2: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
1: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relax9transform23MergeCompositeFunctionsEvEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SF_SJ_
0: tvm::relax::MergeCompositeFunctions(tvm::IRModule)
File "/software/tvm-lunder/src/relax/transform/merge_composite_functions.cc", line 405
InternalError: Check failed: (!group_map.count(obj)) is false:
Ooh, that ended up being a real tricky one. There was an implicit assumption in my first implementation that every group would contain distinct objects. This is usually true, because each `relax::Var` must be unique across an `IRModule`, and so the expressions containing a `relax::Var` must also be unique in each function. However, relax expressions that do not depend on any variables, such as static shapes, may re-use the same underlying C++ object across multiple functions.

In the test case, both `main` and `main2` inferred the return type of `fused_relax_nn_conv2d_relax_nn_relu`, using the same `ShapeExpr`. This was then assigned to its own group during a `PostOrderVisit`, but that group was never actually used.
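The failing invariant can be mimicked in a few lines. This is a hypothetical Python sketch, not the real C++ code (the actual check lives in merge_composite_functions.cc): a single immutable object standing in for the shared `ShapeExpr` is visited from two functions, and a group map keyed by object identity trips the equivalent of the `!group_map.count(obj)` check on the second visit.

```python
# A shared immutable object plays the role of the ShapeExpr that both
# "main" and "main2" ended up referencing.
shared_shape = (1, 64, 56, 56)

group_map = {}

def assign_group(obj, group):
    # Mirrors the ICHECK(!group_map.count(obj)) assertion: each object may
    # be assigned to a group at most once.
    if id(obj) in group_map:
        raise RuntimeError("Check failed: (!group_map.count(obj)) is false")
    group_map[id(obj)] = group

assign_group(shared_shape, "group_main")       # visiting "main": fine

caught = None
try:
    assign_group(shared_shape, "group_main2")  # visiting "main2": collides
except RuntimeError as err:
    caught = err

assert caught is not None   # the shared object triggers the check
```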
For now, I've added a fix to avoid over-collection of `ShapeExpr`, by removing the use of `PostOrderVisit`. The bug could still be triggered for shape expressions that occur explicitly within `relax::Call` arguments (e.g. for `R.full` arguments), and which are re-used across multiple functions in the same `IRModule`. That should be a rarer case, and since it will be a larger refactor to avoid those edge cases (to have caching on a per-`Var` basis, rather than a per-`const Object*` basis), I think I'll wait until that becomes an issue before tackling it.
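The larger refactor mentioned above could look something like the following toy sketch. It is an assumption-laden Python illustration (function names and keys are invented, not TVM's), showing only the idea: keying the cache by a (function, object) pair instead of by the bare object pointer, so a shape expression shared between functions no longer collides.

```python
shared_shape = (1, 64, 56, 56)   # stands in for a ShapeExpr shared by two functions

group_map = {}

def assign_group(func_name, obj, group):
    # Key by (owning function, object identity) rather than identity alone,
    # so re-use of the same object across functions is allowed.
    key = (func_name, id(obj))
    assert key not in group_map, "duplicate group assignment within one function"
    group_map[key] = group

assign_group("main",  shared_shape, "g0")
assign_group("main2", shared_shape, "g1")   # no collision: different function
assert len(group_map) == 2
```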
cc @Lunderberg @junrushao @tqchen