mlc-ai / mlc-llm

Universal LLM Deployment Engine with ML Compilation
https://llm.mlc.ai/
Apache License 2.0

mlc_llm serve fails on concurrent users - Llama3 70B parameter hosting #2462

Open swamysrivathsan opened 1 month ago

swamysrivathsan commented 1 month ago

🐛 Bug

I'm serving the Llama3 70B model on a g5.12xlarge AWS EC2 instance. Below are the versions of mlc_llm installed:

```
mlc-ai-nightly-cu122   0.15.dev389
mlc-llm-nightly-cu122  0.1.dev1320
```

I'm using `mlc_llm serve` to host the model. When load testing the endpoint with Locust, I see the `mlc_llm serve` process crash once there are more than 5 concurrent users.

To Reproduce

Steps to reproduce the behavior:

  1. Convert the model weights with `mlc_llm convert_weight` using `q4f16_1` quantization
  2. Generate the configuration with `mlc_llm gen_config`, with model type `llama`, `--conv-template llama-3`, and `--tensor-parallel-shards 4`
  3. Compile the model libraries with `mlc_llm compile` using `q4f16_1` quantization
  4. Serve the model with `mlc_llm serve`

Traceback of the error:

```
INFO:     - "POST /v1/chat/completions HTTP/1.1" 200 OK
INFO:     - "POST /v1/chat/completions HTTP/1.1" 200 OK
INFO:     - "POST /v1/chat/completions HTTP/1.1" 200 OK
terminate called after throwing an instance of 'tvm::runtime::InternalError'
  what():  [11:48:54] /workspace/tvm/src/runtime/relax_vm/paged_kv_cache.cc:891: Check failed: global_block_pool_[block_idx].external_ref_cnt == 0 (1 vs. 0) : The sequence is currently referenced by other sequence and thus cannot be removed.
Stack trace:
  0: _ZN3tvm7runtime6deta
  1: tvm::runtime::relax_vm::PagedAttentionKVCacheObj::RemoveSequence(long)
  2: tvm::runtime::relax_vm::PagedAttentionKVCacheObj::PopN(long, int)
  3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::runtime::relax_vm::KVState, long, int)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::runtime::relax_vm::KVState, tvm::runtime::relax_vm::KVStateObj, void, long, int, void>(void (tvm::runtime::relax_vm::KVStateObj::*)(long, int))::{lambda(tvm::runtime::relax_vm::KVState, long, int)#1}>(tvm::runtime::Registry::set_body_method<tvm::runtime::relax_vm::KVState, tvm::runtime::relax_vm::KVStateObj, void, long, int, void>(void (tvm::runtime::relax_vm::KVStateObj::*)(long, int))::{lambda(tvm::runtime::relax_vm::KVState, long, int)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  4: tvm::runtime::DiscoWorker::Impl::CallPacked(tvm::runtime::DiscoWorker*, long, tvm::runtime::PackedFunc, tvm::runtime::TVMArgs const&)
  5: tvm::runtime::DiscoWorker::Impl::MainLoop(tvm::runtime::DiscoWorker*)
  6: 0x00007f4ca631a252
  7: 0x00007f4ca94c4ac2
  8: 0x00007f4ca9556a3f
  9: 0xffffffffffffffff
```

```
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/usr/local/lib/python3.11/dist-packages/mlc_llm/cli/worker.py", line 51, in <module>
    main()
  File "/usr/local/lib/python3.11/dist-packages/mlc_llm/cli/worker.py", line 46, in main
    worker_func(worker_id, num_workers, reader, writer)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 277, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
  File "/usr/local/lib/python3.11/dist-packages/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  7: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (int, int, long, long)>::AssignTypedLambda<void (*)(int, int, long, long)>(void (*)(int, int, long, long), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  6: tvm::runtime::WorkerProcess(int, int, long, long)
  5: tvm::runtime::DiscoWorker::Impl::MainLoop(tvm::runtime::DiscoWorker*)
  4: tvm::runtime::DiscoWorker::Impl::CallPacked(tvm::runtime::DiscoWorker*, long, tvm::runtime::PackedFunc, tvm::runtime::TVMArgs const&)
  3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::runtime::relax_vm::KVState, long, int)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::runtime::relax_vm::KVState, tvm::runtime::relax_vm::KVStateObj, void, long, int, void>(void (tvm::runtime::relax_vm::KVStateObj::*)(long, int))::{lambda(tvm::runtime::relax_vm::KVState, long, int)#1}>(tvm::runtime::Registry::set_body_method<tvm::runtime::relax_vm::KVState, tvm::runtime::relax_vm::KVStateObj, void, long, int, void>(void (tvm::runtime::relax_vm::KVStateObj::*)(long, int))::{lambda(tvm::runtime::relax_vm::KVState, long, int)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  2: tvm::runtime::relax_vm::PagedAttentionKVCacheObj::PopN(long, int)
  1: tvm::runtime::relax_vm::PagedAttentionKVCacheObj::RemoveSequence(long)
  0: _ZN3tvm7runtime6deta
  File "/workspace/tvm/src/runtime/relax_vm/paged_kv_cache.cc", line 891
TVMError: Check failed: global_block_pool_[block_idx].external_ref_cnt == 0 (1 vs. 0) : The sequence is currently referenced by other sequence and thus cannot be removed.
```

The other tensor-parallel worker processes print the same traceback, interleaved in the raw log.

Expected behavior

Environment

Additional context

MasterJH5574 commented 1 month ago

Thank you @swamysrivathsan for reporting this issue. While we will look into it, would you mind sharing the requests you sent that can reproduce the error? This would be helpful for us to triage this issue.

MasterJH5574 commented 1 month ago

Hi @swamysrivathsan, we fixed this recently. Could you please follow https://llm.mlc.ai/docs/install/mlc_llm.html#option-1-prebuilt-package to upgrade your pip package and try again?
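To confirm the upgrade took effect, something like the following can report the installed nightly wheel versions (a minimal sketch; the package names are the ones mentioned in this thread, and `None` means the wheel is not installed):

```python
from importlib import metadata

# Package names as reported earlier in this thread (CUDA 12.2 nightlies).
PACKAGES = ("mlc-ai-nightly-cu122", "mlc-llm-nightly-cu122")

def mlc_versions():
    """Return {package: installed version or None} for the mlc nightly wheels."""
    versions = {}
    for pkg in PACKAGES:
        try:
            versions[pkg] = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            versions[pkg] = None
    return versions

print(mlc_versions())
```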

swamysrivathsan commented 1 month ago

Hello @MasterJH5574 - I tried but I'm still facing the issue when the number of concurrent requests to the model is increased to 5.

Below is the version of mlc package I have.

```
mlc-ai-nightly-cu122   0.15.dev404
mlc-llm-nightly-cu122  0.1.dev1352
```

I use a tensor parallel size of 4 to host the model. Below are the commands I used to convert the weights, generate the config, compile the model, and serve it.

```shell
# 1. Convert the weights
python -m mlc_llm convert_weight source/Meta-Llama-3-70B-Instruct/ \
    --quantization q4f16_1 --model-type llama --device cuda \
    --source source/Meta-Llama-3-70B-Instruct/ \
    --output mlc-model/Meta-Llama-3-70B-Instruct-q4f16_1-MLC/

# 2. Generate the configuration
python -m mlc_llm gen_config source/Meta-Llama-3-70B-Instruct/ \
    --quantization q4f16_1 --model-type llama --conv-template llama-3 \
    --tensor-parallel-shards 4 \
    --output mlc-model/Meta-Llama-3-70B-Instruct-q4f16_1-MLC/

# 3. Compile the model library
python -m mlc_llm compile mlc-model/Meta-Llama-3-70B-Instruct-q4f16_1-MLC/ \
    --quantization q4f16_1 --model-type llama \
    --output mlc-model/Meta-Llama-3-70B-Instruct-q4f16_1-MLC/Meta-Llama-3-70B-Instruct-q4f16_1-MLC.so

# 4. Serve the model
python -m mlc_llm serve Meta-Llama-3-70B-Instruct-q4f16_1-MLC \
    --model-lib Meta-Llama-3-70B-Instruct-q4f16_1-MLC/Meta-Llama-3-70B-Instruct-q4f16_1-MLC.so \
    --mode server --host 0.0.0.0 --port 8000
```

Below is the sample request sent to the model (a Locust task):

```python
import json
import random

from locust import task

# `prompts_examples` is defined elsewhere in the locustfile; the task
# below is a method of a Locust HttpUser subclass.

@task
def gpt(self):
    prompt = random.choice(prompts_examples)
    modelId = "Meta-Llama-3-70B-Instruct-q4f16_1-MLC"
    payload = {
        "model": modelId,
        "messages": [{"role": "user", "content": prompt}],
        "stream": False,
    }
    res = self.client.post("/v1/chat/completions", json=payload)
    response = json.loads(res.text)
    print(response["choices"][0]["message"]["content"])
```

`self.client` is an `HttpSession` object from Locust.
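For triage without Locust, the same load pattern can be reproduced with a plain thread pool. This is a sketch under assumptions: `post` stands in for any HTTP client callable that takes a path and a payload and returns the response body, and the helper names (`build_chat_payload`, `fire_concurrent`) are hypothetical, not part of mlc_llm:

```python
import json
from concurrent.futures import ThreadPoolExecutor

def build_chat_payload(model_id, prompt, stream=False):
    """Build the same OpenAI-style chat payload as the Locust task above."""
    return {
        "model": model_id,
        "messages": [{"role": "user", "content": prompt}],
        "stream": stream,
    }

def fire_concurrent(post, prompts, model_id, workers=8):
    """Send one /v1/chat/completions request per prompt, with up to
    `workers` requests in flight at once (the crash reported here
    appeared above ~5 concurrent users)."""
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = [
            pool.submit(post, "/v1/chat/completions",
                        build_chat_payload(model_id, p))
            for p in prompts
        ]
        return [f.result() for f in futures]
```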

Below is the traceback of the error. Also, the error handling is not graceful: the crash brings down the whole serve process, and the restart plus weight loading takes quite some time.

```
INFO:     10.137.116.28:40965 - "POST /v1/chat/completions HTTP/1.1" 200 OK
terminate called after throwing an instance of 'tvm::runtime::InternalError'
  what():  [07:07:38] /workspace/tvm/src/runtime/disco/nccl/nccl.cc:95: ncclErrror: unhandled cuda error (run with NCCL_DEBUG=INFO for details)
Stack trace:
  0: _ZN3tvm7runtime6deta
  1: tvm::runtime::nccl::AllReduce(tvm::runtime::NDArray, tvm::runtime::ReduceKind, tvm::runtime::NDArray)
  2: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::runtime::NDArray, int, tvm::runtime::NDArray)>::AssignTypedLambda<tvm::runtime::nccl::__mk_TVM3::{lambda(tvm::runtime::NDArray, int, tvm::runtime::NDArray)#1}>(tvm::runtime::nccl::__mk_TVM3::{lambda(tvm::runtime::NDArray, int, tvm::runtime::NDArray)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  3: tvm::runtime::AllReduce(tvm::runtime::NDArray, tvm::runtime::ReduceKind, tvm::runtime::NDArray)
  4: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::runtime::NDArray, tvm::runtime::ShapeTuple, tvm::runtime::NDArray)>::AssignTypedLambda<tvm::runtime::__mk_TVM2::{lambda(tvm::runtime::NDArray, tvm::runtime::ShapeTuple, tvm::runtime::NDArray)#1}>(tvm::runtime::__mk_TVM2::{lambda(tvm::runtime::NDArray, tvm::runtime::ShapeTuple, tvm::runtime::NDArray)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  5: tvm::runtime::relax_vm::VirtualMachineImpl::RunInstrCall(tvm::runtime::relax_vm::VMFrame*, tvm::runtime::relax_vm::Instruction)
  6: tvm::runtime::relax_vm::VirtualMachineImpl::RunLoop()
  7: tvm::runtime::relax_vm::VirtualMachineImpl::InvokeBytecode(long, std::vector<tvm::runtime::TVMRetValue, std::allocator<tvm::runtime::TVMRetValue> > const&)
  8: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::relax_vm::VirtualMachineImpl::GetClosureInternal(tvm::runtime::String const&, bool)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  9: tvm::runtime::relax_vm::VirtualMachineImpl::InvokeClosurePacked(tvm::runtime::ObjectRef const&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
 10: tvm::runtime::relax_vm::CUDAGraphExtensionNode::RunOrCapture(tvm::runtime::relax_vm::VirtualMachine*, tvm::runtime::ObjectRef const&, tvm::runtime::ObjectRef, long, tvm::runtime::Optional)
 11: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::relax_vm::__mk_TVM0::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
 12: tvm::runtime::relax_vm::VirtualMachineImpl::RunInstrCall(tvm::runtime::relax_vm::VMFrame*, tvm::runtime::relax_vm::Instruction)
 13: tvm::runtime::relax_vm::VirtualMachineImpl::RunLoop()
 14: tvm::runtime::relax_vm::VirtualMachineImpl::InvokeBytecode(long, std::vector<tvm::runtime::TVMRetValue, std::allocator<tvm::runtime::TVMRetValue> > const&)
 15: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::relax_vm::VirtualMachineImpl::GetClosureInternal(tvm::runtime::String const&, bool)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
 16: tvm::runtime::relax_vm::VirtualMachineImpl::InvokeClosurePacked(tvm::runtime::ObjectRef const&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
 17: tvm::runtime::DiscoWorker::Impl::CallPacked(tvm::runtime::DiscoWorker*, long, tvm::runtime::PackedFunc, tvm::runtime::TVMArgs const&)
 18: tvm::runtime::DiscoWorker::Impl::MainLoop(tvm::runtime::DiscoWorker*)
 19: 0x00007f3c894ba252
 20: 0x00007f3c8c664ac2
 21: 0x00007f3c8c6f6a3f
 22: 0xffffffffffffffff
```
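As a stopgap for the ungraceful crash described above, the serve process can be wrapped in a small supervisor that restarts it when it dies. This is a sketch only (the `supervise` helper is hypothetical, and it does not avoid the weight-reload cost on each restart):

```python
import subprocess
import time

def supervise(cmd, max_restarts=3, delay=5.0, run=None):
    """Run `cmd`, restarting it up to `max_restarts` times if it exits
    nonzero (e.g. after a worker abort like the one above). `run` is
    injectable for testing and defaults to subprocess.call."""
    run = run or (lambda c: subprocess.call(c))
    restarts = 0
    while True:
        code = run(cmd)
        if code == 0 or restarts >= max_restarts:
            return code
        restarts += 1
        time.sleep(delay)

# Example (command taken from the serve step earlier in this thread):
# supervise(["python", "-m", "mlc_llm", "serve",
#            "Meta-Llama-3-70B-Instruct-q4f16_1-MLC",
#            "--mode", "server", "--host", "0.0.0.0", "--port", "8000"])
```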

tqchen commented 4 weeks ago

There was a related issue about an unhandled cuda error in https://github.com/pytorch/pytorch/issues/11756, which suggests it might be related to the driver/hardware; not sure what the cause is in this case. Maybe you can turn on the NCCL debug info environment variable as suggested to see if there is more information here.
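One way to follow that suggestion is to relaunch the server with NCCL's debug logging switched on. A sketch, assuming the serve command from earlier in the thread; `NCCL_DEBUG` and `NCCL_DEBUG_SUBSYS` are standard NCCL environment variables, while the helper names here are hypothetical:

```python
import os
import subprocess

def nccl_debug_env(base=None):
    """Return a copy of the environment with NCCL debug logging enabled."""
    env = dict(os.environ if base is None else base)
    env["NCCL_DEBUG"] = "INFO"        # NCCL prints each call and any error
    env["NCCL_DEBUG_SUBSYS"] = "ALL"  # include all NCCL subsystems
    return env

def serve_with_nccl_debug(cmd):
    """Launch the serve command; NCCL debug output goes to stderr."""
    return subprocess.Popen(cmd, env=nccl_debug_env())

# Example (hypothetical invocation, mirroring the serve step above):
# serve_with_nccl_debug(["python", "-m", "mlc_llm", "serve",
#                        "Meta-Llama-3-70B-Instruct-q4f16_1-MLC",
#                        "--mode", "server"])
```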