apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bug][Relax] Shape Mismatch for function argument #17310

Open · Cookiee235 opened this issue 2 months ago

Cookiee235 commented 2 months ago

Actual behavior

Traceback (most recent call last):
  File "/share_container/optfuzz/res/res_ut/res_executions/30_test.py", line 50, in <module>
    ex = relax.build(mod, target='llvm')
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/relax/vm_build.py", line 335, in build
    mod = pipeline(mod)
          ^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/ir/transform.py", line 270, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "/software/tvm/python/tvm/relax/pipeline.py", line 101, in _pipeline
    mod = seq(mod)
          ^^^^^^^^
  File "/software/tvm/python/tvm/ir/transform.py", line 270, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  38: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  37: tvm::transform::Pass::operator()(tvm::IRModule) const
  36: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  35: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  34: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  33: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  32: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_1
  31: tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relax::transform::CallTIRRewrite()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relax::transform::CallTIRRewrite()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
  30: tvm::relax::CallTIRMutator::Run()
  29: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  28: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  27: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
  26: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::FunctionNode const*)
  25: tvm::relax::ExprMutator::VisitWithNewScope(tvm::RelayExpr const&, tvm::runtime::Optional<tvm::runtime::Array<tvm::relax::Var, void> >)
  24: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  23: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  22: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
  21: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::SeqExprNode const*)
  20: tvm::relax::ExprMutator::VisitBindingBlock(tvm::relax::BindingBlock const&)
  19: tvm::relax::ExprMutator::VisitBindingBlock_(tvm::relax::BindingBlockNode const*)
  18: tvm::relax::ExprMutator::VisitBinding(tvm::relax::Binding const&)
  17: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*)
  16: _ZZN3tvm5relax11ExprMutator22InitVisitBindingVTabl
  15: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*, tvm::relax::CallNode const*)
  14: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  13: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  12: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
  11: tvm::relax::CallTIRMutator::VisitExpr_(tvm::relax::CallNode const*)
  10: tvm::relax::BlockBuilderImpl::Emit(tvm::RelayExpr, tvm::runtime::String)
  9: tvm::relax::BlockBuilderImpl::Emit(tvm::RelayExpr, bool, tvm::runtime::String)
  8: tvm::relax::Normalizer::Normalize(tvm::RelayExpr const&)
  7: tvm::relax::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  6: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRef
  5: non-virtual thunk to tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)
  4: tvm::relax::Normalizer::VisitExpr_(tvm::relax::CallNode const*)
  3: tvm::relax::Normalizer::InferStructInfo(tvm::relax::Call const&)
  2: tvm::relax::DeriveCallRetStructInfo(tvm::relax::FuncStructInfo const&, tvm::relax::Call const&, tvm::relax::BlockBuilder const&, tvm::arith::Analyzer*)
  1: tvm::relax::CallRetStructInfoDeriver::Derive(tvm::relax::FuncStructInfo const&, tvm::relax::Call const&, tvm::relax::BlockBuilder const&)
  0: tvm::relax::BlockBuilderImpl::ReportFatal(tvm::Diagnostic const&)
  File "/software/tvm/src/relax/ir/block_builder.cc", line 159
TVMError: Argument 0 type mismatch: expected R.Tensor((64, 64, 56, 56), dtype="float32"), given R.Tensor((1, 64, 56, 56), dtype="float32")

Steps to reproduce

import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def conv2d(data: T.Buffer((T.int64(1), T.int64(64), T.int64(56), T.int64(56)), "float32"), weight1: T.Buffer((T.int64(64), T.int64(64), T.int64(3), T.int64(3)), "float32"), conv2d_nchw: T.Buffer((T.int64(1), T.int64(64), T.int64(56), T.int64(56)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        pad_temp = T.alloc_buffer((T.int64(1), T.int64(64), T.int64(58), T.int64(58)))
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(64), T.int64(58), T.int64(58)):
            with T.block("pad_temp"):
                v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(data[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)])
                T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
                pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(57) and T.int64(1) <= v_i3 and v_i3 < T.int64(57), data[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)], T.float32(0))
        for nn, ff, yy, xx, rc, ry, rx in T.grid(T.int64(1), T.int64(64), T.int64(56), T.int64(56), T.int64(64), T.int64(3), T.int64(3)):
            with T.block("conv2d_nchw"):
                v_nn, v_ff, v_yy, v_xx, v_rc, v_ry, v_rx = T.axis.remap("SSSSRRR", [nn, ff, yy, xx, rc, ry, rx])
                T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], weight1[v_ff, v_rc, v_ry, v_rx])
                T.writes(conv2d_nchw[v_nn, v_ff, v_yy, v_xx])
                with T.init():
                    conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = T.float32(0)
                conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw[v_nn, v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * weight1[v_ff, v_rc, v_ry, v_rx]

    @T.prim_func
    def relu(data: T.Buffer((64, 64, 56, 56), "float32"), out: T.Buffer((64, 64, 56, 56), "float32")):
        # with T.block("root"):
        for ax0, ax1, ax2, ax3 in T.grid(64, 64, 56, 56):
            with T.block("root"):
                i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(data[i, j, k, l])
                T.writes(out[i, j, k, l])
                out[i, j, k, l] = T.max(data[i, j, k, l], T.float32(0))

    @R.function
    def main(data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32")) -> R.Tensor((64, 64, 56, 56), dtype="float32"):
        cls = Module
        with R.dataflow():
            conv1 = R.call_tir(cls.conv2d, (data, weight1), out_sinfo=R.Tensor((1, 64, 56, 56), dtype="float32"))
            relu1 = R.call_tir(cls.relu, (conv1,), out_sinfo=R.Tensor((64, 64, 56, 56), dtype="float32"))
            R.output(relu1)
        return relu1

mod = Module
mod.show()
ex = relax.build(mod, target='llvm')

The given Relax IR passes the IR validity check but crashes when I build it. Could you help me review it? Thanks a lot!

CC @Lunderberg @junrushao
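
As a minimal sketch of the observation above (assuming the reproduction script is in scope): the structural well-formedness check accepts the module, and the shape mismatch only surfaces once relax.build runs its lowering pipeline.

import tvm
from tvm import relax

mod = Module  # the IRModule from the reproduction script above
# The structural check reportedly accepts the module despite the
# mismatched call_tir shapes ...
assert relax.analysis.well_formed(mod)
# ... while building it raises the TVMError shown in the traceback.
try:
    relax.build(mod, target="llvm")
except tvm.TVMError as err:
    print("caught only at build time:", err)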

xhmelon commented 1 month ago

Hi @Cookiee235, the error is caused by a mismatch between the output shape of conv2d and the input shape of relu, which are (1, 64, 56, 56) and (64, 64, 56, 56), respectively. Changing the shape of relu from (64, 64, 56, 56) to (1, 64, 56, 56) makes the module build successfully.
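
For illustration, a minimal sketch of that fix (based on the description above, not the exact patch applied; the out_sinfo of the second call_tir and the return annotation of main would need the same (1, 64, 56, 56) shape for the module to stay consistent):

@T.prim_func
def relu(data: T.Buffer((1, 64, 56, 56), "float32"), out: T.Buffer((1, 64, 56, 56), "float32")):
    # Batch dimension changed from 64 to 1 so the input matches the
    # (1, 64, 56, 56) output of conv2d; the loop extents are adjusted to match.
    for ax0, ax1, ax2, ax3 in T.grid(1, 64, 56, 56):
        with T.block("root"):
            i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            T.reads(data[i, j, k, l])
            T.writes(out[i, j, k, l])
            out[i, j, k, l] = T.max(data[i, j, k, l], T.float32(0))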

Cookiee235 commented 1 month ago

@xhmelon Thanks for your investigation. Indeed, the Relax IR is invalid, and the crash message gives the correct diagnostic. However, the above Relax IR passes the verify_well_formed validation, which misleads us into considering it valid. It would be better if the error were caught earlier (i.e., at the mod = Module statement)!
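
As an illustration of what such an earlier check could look like, here is a sketch of a hypothetical helper (check_call_tir_shapes is not an existing TVM API) that cross-checks each R.call_tir argument against the buffer shapes declared by the callee PrimFunc and could be run right after mod = Module. It assumes the in-line tuple argument form produced by the script above and only inspects the input arguments.

import tvm
from tvm import relax

def check_call_tir_shapes(mod):
    # Hypothetical helper, not a TVM API: walk every Relax function and
    # compare the static shape of each R.call_tir argument with the buffer
    # shape declared by the corresponding PrimFunc parameter.
    call_tir_op = tvm.ir.Op.get("relax.call_tir")
    for gvar, func in mod.functions.items():
        if not isinstance(func, relax.Function):
            continue
        for block in func.body.blocks:
            for binding in block.bindings:
                call = binding.value
                if not (isinstance(call, relax.Call) and call.op == call_tir_op):
                    continue
                callee = mod[call.args[0]]  # the PrimFunc being called
                for i, arg in enumerate(call.args[1].fields):
                    expected = callee.buffer_map[callee.params[i]].shape
                    given = arg.struct_info.shape.values
                    if [int(e) for e in expected] != [int(g) for g in given]:
                        raise ValueError(
                            f"{gvar.name_hint}: argument {i} of call to "
                            f"{call.args[0].name_hint} has shape "
                            f"{[int(g) for g in given]}, expected "
                            f"{[int(e) for e in expected]}"
                        )

mod = Module
check_call_tir_shapes(mod)  # would flag the relu call here, before relax.build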