apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bug] [Relax] Argument type mismatch: expected R.Tensor, given R.Tuple #17223

Open Cookiee235 opened 2 months ago

Cookiee235 commented 2 months ago

The provided Relax IR appears to be valid; however, it unexpectedly crashes when compiled with relax.build().

Actual behavior

Traceback (most recent call last):
  File "test_simp.py", line 26, in <module>
    ex = relax.build(mod, target='llvm')  # crash here!
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/relax/vm_build.py", line 335, in build
    mod = pipeline(mod)
          ^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/_ffi/_ctypes/packed_func.py", line 240, in __call__
    raise_last_ffi_error()
  File "/software/tvm-lunder/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "/software/tvm-lunder/python/tvm/relax/pipeline.py", line 101, in _pipeline
    mod = seq(mod)
          ^^^^^^^^
  File "/software/tvm-lunder/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/_ffi/_ctypes/packed_func.py", line 240, in __call__
    raise_last_ffi_error()
  File "/software/tvm-lunder/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  33: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  32: tvm::transform::Pass::operator()(tvm::IRModule) const
  31: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  30: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  29: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  28: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  27: _ZN3tvm7runtime13PackedFuncObj
  26: tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relax::transform::CallTIRRewrite()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relax::transform::CallTIRRewrite()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
  25: tvm::relax::CallTIRMutator::Run()
  24: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  23: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  22: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::FunctionNode const*)
  21: tvm::relax::ExprMutator::VisitWithNewScope(tvm::RelayExpr const&, tvm::runtime::Optional<tvm::runtime::Array<tvm::relax::Var, void> >)
  20: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  19: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  18: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::SeqExprNode const*)
  17: tvm::relax::ExprMutator::VisitBindingBlock(tvm::relax::BindingBlock const&)
  16: tvm::relax::ExprMutator::VisitBindingBlock_(tvm::relax::BindingBlockNode const*)
  15: tvm::relax::ExprMutator::VisitBinding(tvm::relax::Binding const&)
  14: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*)
  13: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*, tvm::relax::ConstantNode const*)
  12: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  11: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  10: tvm::relax::CallTIRMutator::VisitExpr_(tvm::relax::CallNode const*)
  9: tvm::relax::BlockBuilderImpl::Emit(tvm::RelayExpr, tvm::runtime::String)
  8: tvm::relax::BlockBuilderImpl::Emit(tvm::RelayExpr, bool, tvm::runtime::String)
  7: tvm::relax::Normalizer::Normalize(tvm::RelayExpr const&)
  6: tvm::relax::Normalizer::VisitExpr(tvm::RelayExpr const&)
  5: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  4: tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)
  3: tvm::relax::Normalizer::InferStructInfo(tvm::relax::Call const&)
  2: tvm::relax::DeriveCallRetStructInfo(tvm::relax::FuncStructInfo const&, tvm::relax::Call const&, tvm::relax::BlockBuilder const&, tvm::arith::Analyzer*)
  1: tvm::relax::CallRetStructInfoDeriver::Derive(tvm::relax::FuncStructInfo const&, tvm::relax::Call const&, tvm::relax::BlockBuilder const&)
  0: tvm::relax::BlockBuilderImpl::ReportFatal(tvm::Diagnostic const&)
  File "/software/tvm-lunder/src/relax/ir/block_builder.cc", line 158
TVMError: Argument 0 type mismatch: expected R.Tensor((16,), dtype="float32"), given R.Tuple(R.Tensor((16,), dtype="float32"))

Steps to reproduce

import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module(check_well_formed=True)
class Module:
    @T.prim_func(private=True)
    def multiply_by_two(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
        for i in range(16):
            B[i] = A[i] * T.float32(2)

    @R.function
    def main(A: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"):
        cls = Module
        args: R.Tuple(R.Tensor((16,), dtype="float32")) = (A,)
        gv1 = R.call_tir(cls.multiply_by_two, (args,), out_sinfo=R.Tensor((16,), dtype="float32"))
        return gv1

mod = Module
mod.show()

mod = relax.transform.FuseTIR()(mod)
mod = relax.transform.LambdaLift()(mod)
ex = relax.build(mod, target='llvm')  # crash here!

cc @Lunderberg @junrushao

Lunderberg commented 2 months ago

I can run the test case and reproduce the error, but the error message seems correct for this test case. The first argument to Module.multiply_by_two is a tensor, while the first item of R.call_tir's argument tuple is itself a tuple. This could be caught earlier by the well-formed checker, once it is updated to validate the R.call_tir arguments.

(As a side note, replacing (args,) with args would have the correct struct info, but it wouldn't be an in-line Relax tuple as required by R.call_tir. See the discussion in https://github.com/apache/tvm/pull/15916 for more detail on the in-line tuple requirement.)