PabloButron opened this issue 2 months ago (status: Open)
Have you tried other models, for example Gemma2?
I tried TinyLlama but got the same error. I don't think it is about the model, because these models work fine on devices with 6 GB of RAM and on macOS. I even re-quantized the same models, but whether the quantization is old or new, they work on Pro Max devices and not on an iPad with 3 GB or an iPhone with 4 GB. If I delete `"prefill_chunk_size": 64` from the config, the app crashes from low RAM. This setup was working before; it suddenly stopped, and now it crashes with the same log as before.
I also tried Qwen 2.5 models, and they crash too, even the 0.5B model in q4f16_1 (230 MB of weights) with `"prefill_chunk_size": 64` and `"context_window_size": 2048`.
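For anyone reproducing this, the two limits quoted above live in the model's chat config JSON. A minimal sketch for setting them programmatically, assuming the file is the `mlc-chat-config.json` in the compiled model directory and is a flat JSON object with these two keys. The helper name `set_memory_limits` and the example path are hypothetical.

```python
import json
from pathlib import Path


def set_memory_limits(config_path, prefill_chunk_size=64, context_window_size=2048):
    """Set the two memory-related fields in an MLC chat config file.

    Hypothetical helper: assumes a flat JSON object with the keys
    "prefill_chunk_size" and "context_window_size", as quoted in this thread.
    All other fields are preserved unchanged.
    """
    path = Path(config_path)
    config = json.loads(path.read_text())
    config["prefill_chunk_size"] = prefill_chunk_size
    config["context_window_size"] = context_window_size
    path.write_text(json.dumps(config, indent=2))
    return config


# Hypothetical path; point it at your own compiled model directory.
# set_memory_limits("dist/my-model-q4f16_1-MLC/mlc-chat-config.json")
```

Per the reports in this thread, these values only address memory pressure; they did not prevent the Metal compiler crash.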
I am not sure it's a RAM issue so much as an issue with how the model is compiled for older iOS devices. The error comes from the Metal compiler, which reports `Compiler failed with XPC_ERROR_CONNECTION_INTERRUPTED`, so I'm starting to think the problem lies in how the model was compiled in the first place.
I will add that the crash persists even with a 0.5B-parameter model. The Metal compilation error makes me think that the TVM-generated Metal code somehow can't be compiled correctly. This consistently happens on an A13 Bionic device (iPhone 11).
I compiled TVM and MLC with debug symbols and can see that the failure is an internal error in the Metal compiler.
Going on a hunch: could one of the authors of https://github.com/mlc-ai/relax/blame/79a69ae4a92c9d4f23e62f93ce5b0d90ed29e5ed/src/runtime/metal/metal_module.mm#L100 help us out? @echuraev @tqchen, could you take a look? 😄
I thought about it more, and it's interesting that the failing function is specifically a matrix multiplication, while the others seem to work. This got me thinking that the problem may be in the generated Metal code, related to architectural differences between GPU families. See the Metal Feature Set Tables:
[Screenshots: Metal Feature Set Tables, pages 1 and 4]
I initially checked recent changes to the Metal intrinsics and found that the `simd_shuffle_*` intrinsics were added in commit 22ec541 by @MasterJH5574, referencing section 6.9.2, "SIMD-Group Functions", of the Metal Shading Language Specification. Commenting out these intrinsics didn't make the model work, and that PR is over a year old.
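To make clear what those registrations do, here is a small Python model of `DispatchMetalShuffle` from intrin_rule_metal.cc (the file is reproduced later in this thread). Each TVM warp-shuffle builtin, called with `(mask, value, warp_id, width, warp_size)`, is rewritten to the matching Metal SIMD-group function, keeping only the value and the lane (or delta) argument. This is an illustration written for this discussion, not TVM code; `lower_warp_shuffle` is a hypothetical name.

```python
# Mapping taken from MetalWarpIntrinsic in intrin_rule_metal.cc.
WARP_TO_METAL = {
    "tir.tvm_warp_shuffle": "simd_shuffle",
    "tir.tvm_warp_shuffle_up": "simd_shuffle_up",
    "tir.tvm_warp_shuffle_down": "simd_shuffle_down",
}


def lower_warp_shuffle(op_name, args):
    """Model of DispatchMetalShuffle: returns (metal_function, metal_args)."""
    if len(args) != 5:
        raise ValueError("expected 5 arguments: mask, value, warp_id, width, warp_size")
    _mask, value, lane, _width, _warp_size = args
    # Metal's simd_shuffle* functions take only (value, lane) or (value, delta),
    # so the mask and the width arguments are dropped, as in the C++ rule.
    return WARP_TO_METAL[op_name], [value, lane]


print(lower_warp_shuffle("tir.tvm_warp_shuffle_down", ["mask", "x", 1, 32, 32]))
# ('simd_shuffle_down', ['x', 1])
```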
After reading through the code, I also came across the dlight layer, which is what actually determines the Metal code that gets generated, and found commit c0abab7 by @Hzfengsy, submitted in June '24, which enables SimdGroup ops for Metal.
In particular, commenting out the MetalMatmul invocation in dlight and recompiling the model makes it work! 😄
Given that the bug is actually in the TVM layer, and that it's honestly not clear whether it's a "bug" per se, or how the maintainers think about supporting older GPU architectures, I think it warrants a discussion.
From my perspective as a developer who wants to ship an app built with MLC to a broad audience using a small model, I would disable this matrix multiplication optimization in favor of compatibility. However, presumably this change was made for a good reason, so perhaps a flag would be a good option?
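A minimal sketch of the opt-out flag proposed above, written as a self-contained Python stand-in for the dlight dispatch quoted later in this thread (`MetalMatmul().apply(...)` wrapped in a bare `try/except`). The environment variable `TVM_METAL_DISABLE_SIMDGROUP_MATMUL` and both functions are hypothetical; they are not existing TVM or MLC LLM options.

```python
import os

# Hypothetical opt-out flag; not an existing TVM or MLC LLM option.
DISABLE_FLAG = "TVM_METAL_DISABLE_SIMDGROUP_MATMUL"


def simdgroup_matmul_enabled(target_kind, env=None):
    """Only try the SIMD-group matmul schedule on Metal, unless opted out."""
    env = os.environ if env is None else env
    if target_kind != "metal":
        return False
    return env.get(DISABLE_FLAG, "0").lower() not in ("1", "true", "yes")


def schedule_matmul(func, target_kind, metal_rule, generic_rule, env=None):
    """Stand-in for the dlight Matmul dispatch.

    metal_rule plays the role of MetalMatmul().apply; generic_rule is the
    schedule used for every other target. Both are placeholders here.
    """
    if simdgroup_matmul_enabled(target_kind, env):
        try:
            return metal_rule(func)
        except Exception:  # mirror the existing fallback when the rule fails
            pass
    return generic_rule(func)
```

Note that, per the follow-up below, skipping the dispatch alone was not enough; the warp-shuffle intrinsics in intrin_rule_metal.cc also had to be removed, so a real flag would likely need to cover both.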
Also, note that I actually needed to comment out not just the MatMul invocation but also the intrinsics in intrin_rule_metal.cc. So both changes are required to get it fully working. I initially thought I could leave the intrinsics registered, but that caused another internal error/crash (albeit not a compilation error).
I can confirm the bug on devices such as the iPad (9th generation), the iPhone 11, and the iPhone SE (2020), all with the Apple A13 Bionic. It also happens with the Apple A12 Bionic: I saw it on an iPhone XS Max too.
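These reports line up with Apple's GPU families. A minimal sketch that encodes the pattern, assuming our reading of the Metal Feature Set Tables: the A12 belongs to GPU family Apple5, the A13 to Apple6, and SIMD-group matrix functions (`simdgroup_matrix` and related, which the SIMD-group matmul path presumably relies on) first appear in Apple7 (A14). Double-check the family numbers against the current tables; the dictionaries and `supports_simdgroup_matrix` are illustrative names.

```python
# Devices reported in this thread, mapped to their SoC.
DEVICE_SOC = {
    "iPad (9th generation)": "A13 Bionic",
    "iPhone 11": "A13 Bionic",
    "iPhone SE (2020)": "A13 Bionic",
    "iPhone XS Max": "A12 Bionic",
}

# SoC -> Metal GPU family number (MTLGPUFamilyAppleN). Assumed from the
# Metal Feature Set Tables; A14 is included only as the first Apple7 chip.
SOC_GPU_FAMILY = {
    "A12 Bionic": 5,
    "A13 Bionic": 6,
    "A14 Bionic": 7,
}

# Assumption: SIMD-group matrix functions require the Apple7 family or later.
MIN_FAMILY_FOR_SIMDGROUP_MATRIX = 7


def supports_simdgroup_matrix(device):
    """True if the device's GPU family is new enough for simdgroup_matrix."""
    family = SOC_GPU_FAMILY[DEVICE_SOC[device]]
    return family >= MIN_FAMILY_FOR_SIMDGROUP_MATRIX


for name in DEVICE_SOC:
    print(f"{name}: simdgroup_matrix={supports_simdgroup_matrix(name)}")
```

Every device in the report comes out as unsupported, which is consistent with, though not proof of, the MetalMatmul explanation.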
I am trying to replicate the steps proposed by @dfilimon as a temporary fix, but I must be doing something wrong, because it keeps crashing with the same error.
Following @dfilimon's suggestion to comment out the MatMul invocation, I commented out the Metal branch in dlight:

```python
# elif target.kind.name == "metal":
#     try:
#         return MetalMatmul().apply(func, target, _)
#     except:  # pylint: disable=bare-except
#         pass
```
In intrin_rule_metal.cc, I commented out the warp-shuffle lowering rules and the `tir.metal.simd_shuffle*` op registrations:

```cpp
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file intrin_rule_metal.cc
 * \brief Metal intrinsic rules.
 */
#include <tvm/tir/op_attr_types.h>

#include "../intrin_rule.h"

namespace tvm {
namespace codegen {
namespace intrin {
using tir::FLowerIntrinsic;

struct MetalWarpIntrinsic {
  const Op operator()(DataType t, const Op& orig_op) const {
    if (orig_op.same_as(builtin::tvm_warp_shuffle())) {
      return Op::Get("tir.metal.simd_shuffle");
    } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) {
      return Op::Get("tir.metal.simd_shuffle_up");
    } else {
      ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down()));
      return Op::Get("tir.metal.simd_shuffle_down");
    }
  }
};

template <typename T>
static PrimExpr DispatchMetalShuffle(const PrimExpr& e) {
  const CallNode* call = e.as<CallNode>();
  ICHECK(call != nullptr);
  ICHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
  Array<PrimExpr> metal_args{{call->args[1], call->args[2]}};
  return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), metal_args);
}

TVM_REGISTER_OP("tir.clz").set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic",
                                                     DispatchPureExtern<Direct>);

TVM_REGISTER_OP("tir.floor")
    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);

TVM_REGISTER_OP("tir.ceil")
    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);

TVM_REGISTER_OP("tir.trunc")
    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);

TVM_REGISTER_OP("tir.fabs")
    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);

TVM_REGISTER_OP("tir.round")
    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);

TVM_REGISTER_OP("tir.nearbyint")
    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);

TVM_REGISTER_OP("tir.exp").set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic",
                                                     DispatchPureExtern<Direct>);

TVM_REGISTER_OP("tir.exp2")
    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);

TVM_REGISTER_OP("tir.exp10")
    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);

TVM_REGISTER_OP("tir.log").set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic",
                                                     DispatchPureExtern<Direct>);

TVM_REGISTER_OP("tir.log2")
    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);

TVM_REGISTER_OP("tir.log10")
    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);

TVM_REGISTER_OP("tir.tanh")
    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchNumericalStableTanh);

TVM_REGISTER_OP("tir.sqrt")
    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);

TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic",
                                                     DispatchPureExtern<Direct>);

TVM_REGISTER_OP("tir.popcount")
    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);

TVM_REGISTER_OP("tir.fmod")
    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);

TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic",
                                                     DispatchPureExtern<Direct>);

TVM_REGISTER_OP("tir.sinh")
    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);

TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic",
                                                     DispatchPureExtern<Direct>);

TVM_REGISTER_OP("tir.cosh")
    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);

TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchFastErf);

//TVM_REGISTER_OP("tir.tvm_warp_shuffle")
//    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchMetalShuffle<MetalWarpIntrinsic>);
//
//TVM_REGISTER_OP("tir.tvm_warp_shuffle_up")
//    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchMetalShuffle<MetalWarpIntrinsic>);
//
//TVM_REGISTER_OP("tir.tvm_warp_shuffle_down")
//    .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchMetalShuffle<MetalWarpIntrinsic>);
//
//// Register low-level builtin ops.
//TVM_REGISTER_OP("tir.metal.simd_shuffle")
//    .set_num_inputs(2)
//    .add_argument("var", "Expr", "The variable to sync.")
//    .add_argument("lane", "Expr", "The source thread id.")
//    .set_attr<TGlobalSymbol>("TGlobalSymbol", "simd_shuffle")
//    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
//
//TVM_REGISTER_OP("tir.metal.simd_shuffle_up")
//    .set_num_inputs(2)
//    .add_argument("var", "Expr", "The variable to sync.")
//    .add_argument("delta", "Expr", "The source lane id offset to be added.")
//    .set_attr<TGlobalSymbol>("TGlobalSymbol", "simd_shuffle_up")
//    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));
//
//TVM_REGISTER_OP("tir.metal.simd_shuffle_down")
//    .set_num_inputs(2)
//    .add_argument("var", "Expr", "The variable to sync.")
//    .add_argument("delta", "Expr", "The source lane id offset to be subtracted.")
//    .set_attr<TGlobalSymbol>("TGlobalSymbol", "simd_shuffle_down")
//    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

}  // namespace intrin
}  // namespace codegen
}  // namespace tvm
```
After editing both files, I ran this in the terminal:

```shell
MLC_JIT_POLICY=REDO python -m mlc_llm package
```

Everything compiles fine as usual, but it still crashes with the same error. What am I doing wrong?
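One possible explanation, not confirmed in this thread: the edits may not be reaching the toolchain that `mlc_llm package` actually runs. The Python change to dlight only takes effect if the edited checkout is the `tvm` package on the import path, and the change to intrin_rule_metal.cc only takes effect once libtvm is rebuilt from that source. A small diagnostic sketch: `module_origin` reports where a package is imported from, and `simd_ops_in` lists SIMD-group identifiers in generated Metal source, assuming you can dump that source to text (for example via TVM's `Module.get_source()` on the Metal module). Both helper names are hypothetical.

```python
import importlib.util
import re


def module_origin(name):
    """File a module would be imported from, or None if it cannot be found.

    For a dotted name, find_spec imports the parent packages first, so an
    import failure in a parent is also reported as None.
    """
    try:
        spec = importlib.util.find_spec(name)
    except (ImportError, ValueError):
        return None
    return spec.origin if spec is not None else None


# Metal Shading Language identifiers that the patches are meant to remove.
SIMD_OPS = re.compile(r"\b(simdgroup_\w+|simd_shuffle(?:_up|_down)?)\b")


def simd_ops_in(metal_source):
    """Sorted, de-duplicated SIMD-group identifiers found in Metal source text."""
    return sorted(set(SIMD_OPS.findall(metal_source)))


# Which TVM installation will the packaging step import? (None if TVM is absent.)
print(module_origin("tvm"))
```

If the printed path is not the checkout you edited, or if the dumped Metal source still contains `simdgroup_*` or `simd_shuffle*` calls, the patch has not taken effect.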
🐛 Bug: Crash on iPads and iPhones with less than 4 GB of RAM
A couple of weeks ago I could run the MLC Chat iOS app on my iPad with 3 GB of RAM, but now I can't: it crashes when I try to run stablelm-2-zephyr-1_6b-q4f16_1.
To Reproduce
Steps to reproduce the behavior:
```text
Compiler failed with XPC_ERROR_CONNECTION_INTERRUPTED
Compiler failed with XPC_ERROR_CONNECTION_INTERRUPTED
Compiler failed with XPC_ERROR_CONNECTION_INTERRUPTED
MTLCompiler: Compilation failed with XPC_ERROR_CONNECTION_INTERRUPTED on 3 try
libc++abi: terminating due to uncaught exception of type tvm::runtime::InternalError: [18:26:08] /Users/pablobutron/Developer/Ghub/MLC/3rdparty/tvm/src/runtime/metal/metal_module.mm:130: InternalError: Check failed: (state != nil) is false: cannot get state: for function stablelm_q4f16_1_8a58367b5e830e69bdbe1141fd05dfb8_fused_dequantize1_fused_NT_matmul5_add2_kernel_2Compiler encountered an internal error
Stack trace:
  [bt] (0) 1 MLC 0x00000001008d4d18 tvm::runtime::detail::LogFatal::Entry::Finalize() + 100
  [bt] (1) 2 MLC 0x00000001008d4cb4 tvm::runtime::detail::LogFatal::Entry::Finalize() + 0
  [bt] (2) 3 MLC 0x00000001008d39c0 std::__1::unique_ptr<std::__1::basic_string<char, std::__1::char_traits, std::__1::allocator>, std::__1::default_delete<std::__1::basic_string<char, std::__1::char_traits, std::__1::allocator>>>::~unique_ptr[abi:ne180100]() + 0
  [bt] (3) 4 MLC 0x0000000100baa86c tvm::runtime::MetalModuleNode::GetPipelineState(unsigned long, std::__1::basic_string<char, std::__1::char_traits, std::__1::allocator> const&) + 1568
  [bt] (4) 5 MLC 0x0000000100ba95d0 tvm::runtime::MetalWrappedFunc::Init(tvm::runtime::MetalModuleNode*, tvm::runtime::ObjectPtr, std::__1::basic_string<char, std::__1::char_traits, std::__1::allocator> const&, unsigned long, unsigned long, std::__1::vector<std::__1::basic_string<char, std::__1::char_traits, std::__1::allocator>, std::__1::allocator<std::__1::basic_string<char, std::__1::char_traits, std::__1::allocator>>> const&) + 236
  [bt] (5) 6 MLC 0x0000000100ba767c tvm::runtime::MetalModuleNode::GetFunction(tvm::runtime::String const&, tvm::runtime::ObjectPtr const&) + 676
  [bt] (6) 7 MLC 0x0000000100ab924c tvm::runtime::ModuleNode::GetFunction(tvm::runtime::String const&, bool) + 104
  [bt] (7) 8 MLC 0x0000000100ab9c5c tvm::runtime::ModuleNode::GetFuncFromEnv(tvm::runtime::String const&) + 268
  [bt] (8) 9 MLC 0x0000000100a65020 TVMBackendGetFuncFromEnv + 384
```
Expected behavior
The model should load and run, as it did a couple of weeks ago.
Environment
MLC LLM installed via (`conda`, source): Conda
TVM installed via (`pip`, source): pip
TVM build info (`python -c "import tvm; print('\n'.join(f'{k}: {v}' for k, v in tvm.support.libinfo().items()))"`, applicable if you compile models):
Additional context
The bug cannot be reproduced on devices with 6 GB or more of RAM.