mlc-ai / mlc-llm

Universal LLM Deployment Engine with ML Compilation
https://llm.mlc.ai/
Apache License 2.0

[Bug] mlc_llm.serve server mode Error when multiple(>=4) concurrent requests #2386

Closed (ita9naiwa closed this issue 4 months ago)

ita9naiwa commented 4 months ago

🐛 Bug

When mlc_llm serve is running in server mode and receives more than 4 concurrent queries, it fails with the following error:

Exception in thread Thread-1:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/opt/conda/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 263, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 252, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
  File "/home/bc-user/.local/lib/python3.10/site-packages/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "/workspace/mlc-llm/cpp/serve/threaded_engine.cc", line 168, in mlc::llm::serve::ThreadedEngineImpl::RunBackgroundLoop()
  File "/workspace/mlc-llm/cpp/serve/engine.cc", line 358, in mlc::llm::serve::EngineImpl::Step()
  File "/workspace/mlc-llm/cpp/serve/engine_actions/new_request_prefill.cc", line 116, in mlc::llm::serve::NewRequestPrefillActionObj::Step(mlc::llm::serve::EngineState)
  File "/workspace/mlc-llm/cpp/serve/model.cc", line 230, in mlc::llm::serve::ModelImpl::BatchPrefill(tvm::runtime::ObjectRef const&, std::vector<long, std::allocator<long> > const&, std::vector<int, std::allocator<int> > const&)
tvm._ffi.base.TVMError: Traceback (most recent call last):
  7: mlc::llm::serve::ThreadedEngineImpl::RunBackgroundLoop()
        at /workspace/mlc-llm/cpp/serve/threaded_engine.cc:168
  6: mlc::llm::serve::EngineImpl::Step()
        at /workspace/mlc-llm/cpp/serve/engine.cc:358
  5: mlc::llm::serve::NewRequestPrefillActionObj::Step(mlc::llm::serve::EngineState)
        at /workspace/mlc-llm/cpp/serve/engine_actions/new_request_prefill.cc:116
  4: mlc::llm::serve::ModelImpl::BatchPrefill(tvm::runtime::ObjectRef const&, std::vector<long, std::allocator<long> > const&, std::vector<int, std::allocator<int> > const&)
        at /workspace/mlc-llm/cpp/serve/model.cc:230
  3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (tvm::runtime::relax_vm::KVState, tvm::runtime::ShapeTuple const&, tvm::runtime::ShapeTuple const&)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::runtime::relax_vm::KVState, tvm::runtime::relax_vm::KVStateObj, void, tvm::runtime::ShapeTuple const&, tvm::runtime::ShapeTuple const&, void>(void (tvm::runtime::relax_vm::KVStateObj::*)(tvm::runtime::ShapeTuple const&, tvm::runtime::ShapeTuple const&))::{lambda(tvm::runtime::relax_vm::KVState, tvm::runtime::ShapeTuple const&, tvm::runtime::ShapeTuple const&)#1}>(tvm::runtime::Registry::set_body_method<tvm::runtime::relax_vm::KVState, tvm::runtime::relax_vm::KVStateObj, void, tvm::runtime::ShapeTuple const&, tvm::runtime::ShapeTuple const&, void>(void (tvm::runtime::relax_vm::KVStateObj::*)(tvm::runtime::ShapeTuple const&, tvm::runtime::ShapeTuple const&))::{lambda(tvm::runtime::relax_vm::KVState, tvm::runtime::ShapeTuple const&, tvm::runtime::ShapeTuple const&)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  2: tvm::runtime::relax_vm::PagedAttentionKVCacheObj::BeginForward(tvm::runtime::ShapeTuple const&, tvm::runtime::ShapeTuple const&)
  1: tvm::runtime::relax_vm::PagedAttentionKVCacheObj::ReserveAppendLengthInSeq(tvm::runtime::relax_vm::Sequence*, long)
  0: _ZN3tvm7runtime6deta
  File "/workspace/tvm/src/runtime/relax_vm/paged_kv_cache.cc", line 1448
TVMError: Check failed: block.external_ref_cnt == 0 (1 vs. 0) : The block is 1-time referenced by other blocks, thus cannot accept new KV values.

To Reproduce

Steps to reproduce the behavior:

  1. Run mlc_llm serve:
    CUDA_VISIBLE_DEVICES=1 mlc_llm serve --mode server \
        --model-lib llama-7b/llama-7b-cuda.so \
        llama-7b
  2. Send more than 4 concurrent queries, each like the following:
    curl -X POST   -H "Content-Type: application/json"   -d '{
        "model": "llama-7b",
        "prompt": "asdadsd",
        "top_p": 0.9,
        "top_k": 1,
        "max_tokens": 128,
        "temperature": 1.0
    }'   http://127.0.0.1:8000/v1/completions
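
The curl command above sends a single request, so triggering the error means running several of them at the same time. Below is a minimal sketch of one way to do that in Python with only the standard library. The URL and request body are copied from the curl command; the names `post_completion` and `fire_concurrent` are illustrative helpers, not part of mlc_llm.

```python
import json
import urllib.request
from concurrent.futures import ThreadPoolExecutor

# Endpoint and request body copied from the curl command in this report.
URL = "http://127.0.0.1:8000/v1/completions"
PAYLOAD = {
    "model": "llama-7b",
    "prompt": "asdadsd",
    "top_p": 0.9,
    "top_k": 1,
    "max_tokens": 128,
    "temperature": 1.0,
}


def post_completion(url=URL, payload=PAYLOAD, timeout=120):
    """Send one completion request and return the HTTP status code."""
    request = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return response.status


def fire_concurrent(n, post=post_completion):
    """Issue n requests at once, one worker thread per request.

    Returns one entry per request: the value returned by `post` on success,
    or the exception object if that request failed.
    """
    def attempt(_):
        try:
            return post()
        except Exception as exc:  # keep going so every request is reported
            return exc

    # max_workers=n lets all n requests be in flight simultaneously.
    with ThreadPoolExecutor(max_workers=n) as pool:
        return list(pool.map(attempt, range(n)))
```

With the server from step 1 running, `fire_concurrent(5)` would put five requests in flight together (5 is just one value above 4). The `post` parameter is there so the concurrency logic can be exercised without a live server.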

Expected behavior

Environment

Additional context

I compiled llama-2-7B via tvm using the following commands:

mlc_llm convert_weight  --output llama-7b --quantization q0f16 \
$LLAMA_7B_DIR
mlc_llm compile ./llama-7b/mlc-chat-config.json --device cuda \
--quantization q0f16 \
--output llama-7b/llama-7b-cuda.so

and ran mlc_llm via:

CUDA_VISIBLE_DEVICES=1 mlc_llm serve --mode server \
    --model-lib llama-7b/llama-7b-cuda.so \
    llama-7b

ita9naiwa commented 4 months ago

It works well when there are fewer than 4 concurrent requests.

MasterJH5574 commented 4 months ago

Hi @ita9naiwa, thanks for reporting! Could you try updating your local checkout to the latest commit? Commit https://github.com/apache/tvm/commit/18a2a250f8c7f16f5f5be6753861ba5db8fb89fa can address this issue if your tvm is behind this commit.

ita9naiwa commented 4 months ago

Sure! I'll try

ita9naiwa commented 4 months ago

@MasterJH5574 Hi, I tested with the latest tvm and it works well.

Thanks!