apache / tvm

Open deep learning compiler stack for CPU, GPU and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bug][Relay] cannot squeeze axis with dimension not equal to 1 at Keras frontend #14869

Open · jikechao opened this issue 1 year ago

jikechao commented 1 year ago

For the LSTM model below, when batch_size != 1 (i.e., the size of the first input dimension), compilation crashes unexpectedly with Check failed: *axis_ptr == 1 (2 vs. 1) : cannot squeeze axis with dimension not equal to 1. Note that when batch_size == 1, TVM runs well.


Question:

For the LSTM model, why does batch_size != 1 lead to a crash?

Expected behavior

Compiles successfully.

Actual behavior

Traceback (most recent call last):
  File "test.py", line 18, in <module>
    model = relay.build_module.create_executor("vm", mod, tvm.cpu(0), 'llvm', params).evaluate()
  File "/workplace/software/tvm/tvm/python/tvm/relay/backend/interpreter.py", line 171, in evaluate
    return self._make_executor()
  File "/workplace/software/tvm/tvm/python/tvm/relay/backend/vm.py", line 219, in _make_executor
    self.executable = compile(self.mod, self.target)
  File "/workplace/software/tvm/tvm/python/tvm/relay/backend/vm.py", line 67, in compile
    compiler.lower(mod, target, target_host)
  File "/workplace/software/tvm/tvm/python/tvm/relay/backend/vm.py", line 126, in lower
    self._lower(mod, raw_targets)
  File "/workplace/software/tvm/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  15: TVMFuncCall
  14: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::vm::VMCompiler::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  13: tvm::relay::vm::VMCompiler::Lower(tvm::IRModule, tvm::runtime::Array<tvm::Target, void> const&)
  12: tvm::relay::vm::VMCompiler::LowerImpl(tvm::IRModule)
  11: tvm::relay::vm::VMCompiler::OptimizeModuleImpl(tvm::IRModule)
  10: tvm::transform::Pass::operator()(tvm::IRModule) const
  9: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  8: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  7: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  6: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  5: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  4: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::$_2>(tvm::relay::transform::InferType()::$_2)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  2: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  1: tvm::relay::TypeSolver::Solve()
  0: _ZN3tvm7runtime6detail
  19: TVMFuncCall
  18: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::vm::VMCompiler::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  17: tvm::relay::vm::VMCompiler::Lower(tvm::IRModule, tvm::runtime::Array<tvm::Target, void> const&)
  16: tvm::relay::vm::VMCompiler::LowerImpl(tvm::IRModule)
  15: tvm::relay::vm::VMCompiler::OptimizeModuleImpl(tvm::IRModule)
  14: tvm::transform::Pass::operator()(tvm::IRModule) const
  13: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  12: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  11: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  10: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  9: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  8: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  7: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::$_2>(tvm::relay::transform::InferType()::$_2)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  6: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  5: tvm::relay::TypeSolver::Solve()
  4: tvm::TypedEnvFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::operator()(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&) const
  3: _ZN3tvm7runtime13Pac
  2: tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  1: tvm::relay::SqueezeRel(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)
  0: _ZN3tvm7runtime6detail
  File "/workplace/software/tvm/tvm/src/relay/analysis/type_solver.cc", line 643
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (false) is false: [07:40:29] /workplace/software/tvm/tvm/src/relay/op/tensor/transform.cc:2340: 
---------------------------------------------------------------

  Check failed: *axis_ptr == 1 (2 vs. 1) : cannot squeeze axis with dimension not equal to 1

Steps to reproduce

import tvm
import tvm.relay as relay
from tensorflow import keras
from tensorflow.keras import layers, models

# batch_size is 2 here; with a batch size of 1 the same script compiles fine
input_shape = (2, 3, 4)
x = layers.Input(shape=input_shape[1:], dtype='float32')

layer = keras.layers.LSTM(units=2)
layer.set_weights(layer.get_weights())

y = layer(x)
model = models.Model(x, y)
model.summary()
# give the frontend the full static shape, including the batch dimension
mod, params = relay.frontend.from_keras(model, {'input_1': input_shape})
print(mod)
with tvm.transform.PassContext(opt_level=3):
    model = relay.build_module.create_executor("vm", mod, tvm.cpu(0), 'llvm', params).evaluate()
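
For reference, the failing check can be reproduced at the Relay level without Keras at all. A minimal sketch (not part of the original report; it simply isolates the check that fires above):

import tvm
from tvm import relay

# the batch axis has static extent 2, so squeezing it violates SqueezeRel
x = relay.var("x", shape=(2, 1, 4), dtype="float32")
y = relay.squeeze(x, axis=[0])
mod = tvm.IRModule.from_expr(relay.Function([x], y))
# type inference raises: cannot squeeze axis with dimension not equal to 1
mod = relay.transform.InferType()(mod)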

Triage

cc @shingjan

jikechao commented 1 year ago

This is a bug in the LSTM layer conversion. I'm not very familiar with RNN networks. For the LSTM model, why does batch_size != 1 lead to a crash in TVM?

@echuraev Could you give me some suggestions? Thank you in advance.

echuraev commented 1 year ago

@jikechao I suppose the problem is the same as in #14868. In LSTM we have the same logic here as for SimpleRNN.
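
The distinction matters because squeeze carries a hard shape constraint while reshape does not. A small illustrative sketch (assumed shapes, not taken from the converter):

from tvm import relay

x = relay.var("x", shape=(2, 1, 4), dtype="float32")

# squeeze asserts that every squeezed axis has static extent exactly 1:
ok = relay.squeeze(x, axis=[1])   # fine: axis 1 has extent 1
# relay.squeeze(x, axis=[0])      # fails type inference: axis 0 has extent 2

# reshape can drop the size-1 axis without constraining the batch axis:
same = relay.reshape(x, newshape=(-1, 4))  # valid for any batch size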

jikechao commented 1 year ago

@echuraev Thanks for your explanation. I'm planning to fix this bug.

echuraev commented 5 months ago

@jikechao I quickly took a look at this issue, and it looks like one possible solution might be to use reshape instead of squeeze in the LSTM layer implementation. I did the same trick in this PR: https://github.com/apache/tvm/pull/16526
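
To illustrate the idea (a hedged sketch only, not the actual patch; the shapes, the squeezed axis, and the names out and units are assumptions about the converter in python/tvm/relay/frontend/keras.py):

from tvm import relay

units = 2  # matches the example model above
# assumed: the converter's LSTM output with batch_size = 2 on the leading axis
out = relay.var("out", shape=(2, units), dtype="float32")

# before (assumed): dropping what the converter takes to be a unit batch axis
# only type-checks when batch_size == 1
# out = relay.squeeze(out, axis=[0])

# after: reshape expresses the same intent without asserting batch_size == 1
out = relay.reshape(out, newshape=(-1, units))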

jikechao commented 4 months ago

> @jikechao I quickly took a look at this issue, and it looks like one possible solution might be to use reshape instead of squeeze in the LSTM layer implementation. I did the same trick in this PR: #16526

@echuraev Thank you! I submitted a PR to fix it. Could you help me review it?