mlc-ai / mlc-llm

Universal LLM Deployment Engine with ML Compilation
https://llm.mlc.ai/
Apache License 2.0

[Bug] consecutive generate with mistral is failing #1274

Closed · varshith15 closed this issue 11 months ago

varshith15 commented 11 months ago

šŸ› Bug

Running cm.generate(prompt=text) in a loop with the Mistral model gives the following error:

InternalError                             Traceback (most recent call last)
Cell In[7], line 3
      1 # text = "[INST] what does lmao mean? [/INST]"
      2 cm.reset_chat()
----> 3 cm.generate(prompt=final_text)

File /usr/local/lib/python3.8/dist-packages/mlc_chat/chat_module.py:856, in ChatModule.generate(self, prompt, generation_config, progress_callback)
    854 for _ in range(num_return_sequences):
    855     self.reset_chat()
--> 856     self._prefill(prompt, generation_config=generation_config)
    858     if not progress_callback:
    859         while not self._stopped():

File /usr/local/lib/python3.8/dist-packages/mlc_chat/chat_module.py:1073, in ChatModule._prefill(self, input, decode_next_token, place_in_prompt, generation_config)
   1070 else:
   1071     input_str = input
-> 1073 self._prefill_func(
   1074     input_str, decode_next_token, place_in_prompt.value, generation_config_str
   1075 )

File tvm/_ffi/_cython/./packed_func.pxi:332, in tvm._ffi._cy3.core.PackedFuncBase.__call__()

File tvm/_ffi/_cython/./packed_func.pxi:277, in tvm._ffi._cy3.core.FuncCall()

File tvm/_ffi/_cython/./base.pxi:182, in tvm._ffi._cy3.core.CHECK_CALL()

File /usr/local/lib/python3.8/dist-packages/tvm/_ffi/base.py:481, in raise_last_ffi_error()
    475 # The exception PyObject may contain a large amount of state,
    476 # including all stack frames that may be inspected in a later
    477 # PDB post-mortem.  Therefore, we must make sure to remove the
    478 # underlying PyObject* from the C++ side after we retrieve it.
    479 _LIB.TVMDropLastPythonError()
--> 481 raise py_err

File /workspace/mlc-llm/cpp/llm_chat.cc:1487, in mlc::llm::LLMChatModule::GetFunction(tvm::runtime::String const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#5}::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const()

File /workspace/mlc-llm/cpp/llm_chat.cc:836, in mlc::llm::LLMChat::PrefillStep(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, bool, bool, mlc::llm::PlaceInPrompt, tvm::runtime::String)()

File /workspace/mlc-llm/cpp/llm_chat.cc:1206, in mlc::llm::LLMChat::ForwardTokens(std::vector<int, std::allocator<int> >, long)()

InternalError: Traceback (most recent call last):
  12: mlc::llm::LLMChatModule::GetFunction(tvm::runtime::String const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#5}::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
        at /workspace/mlc-llm/cpp/llm_chat.cc:1487
  11: mlc::llm::LLMChat::PrefillStep(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, bool, bool, mlc::llm::PlaceInPrompt, tvm::runtime::String)
        at /workspace/mlc-llm/cpp/llm_chat.cc:836
  10: mlc::llm::LLMChat::ForwardTokens(std::vector<int, std::allocator<int> >, long)
        at /workspace/mlc-llm/cpp/llm_chat.cc:1206
  9: tvm::runtime::relax_vm::VirtualMachineImpl::InvokeClosurePacked(tvm::runtime::ObjectRef const&, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  8: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::relax_vm::VirtualMachineImpl::GetClosureInternal(tvm::runtime::String const&, bool)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  7: tvm::runtime::relax_vm::VirtualMachineImpl::InvokeBytecode(long, std::vector<tvm::runtime::TVMRetValue, std::allocator<tvm::runtime::TVMRetValue> > const&)
  6: tvm::runtime::relax_vm::VirtualMachineImpl::RunLoop()
  5: tvm::runtime::relax_vm::VirtualMachineImpl::RunInstrCall(tvm::runtime::relax_vm::VMFrame*, tvm::runtime::relax_vm::Instruction)
  4: _ZN3tvm7runtime13PackedFun
  3: tvm::runtime::TypedPackedFunc<tvm::runtime::relax_vm::AttentionKVCache (tvm::runtime::relax_vm::AttentionKVCache, tvm::runtime::NDArray, long)>::AssignTypedLambda<tvm::runtime::relax_vm::AttentionKVCache (*)(tvm::runtime::relax_vm::AttentionKVCache, tvm::runtime::NDArray, long)>(tvm::runtime::relax_vm::AttentionKVCache (*)(tvm::runtime::relax_vm::AttentionKVCache, tvm::runtime::NDArray, long), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  2: tvm::runtime::relax_vm::AttentionKVCacheWindowOverride(tvm::runtime::relax_vm::AttentionKVCache, tvm::runtime::NDArray, long)
  1: tvm::runtime::relax_vm::AttentionKVCacheObj::WindowOverride(tvm::runtime::NDArray, long)
  0: _ZN3tvm7runtime6deta
  File "/workspace/tvm/src/runtime/relax_vm/lm_support.cc", line 164
InternalError: Check failed: this->fill_count == max_cache_size (2771 vs. 4096) : 
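
For context, the check that fails is in the sliding-window (rolling) KV cache used for Mistral: the "window override" path overwrites the oldest cache slots and appears to assume the cache is completely full, i.e. fill_count == max_cache_size (the sliding_window of 4096 in the config below), which stale state left over from a chat reset can violate. Below is a minimal, hypothetical Python sketch of that invariant; it is not the TVM/mlc-llm implementation, just an illustration of why inconsistent bookkeeping trips the check.

# Hypothetical, simplified model of a sliding-window KV cache (names made up;
# not the actual TVM/mlc-llm code), showing the invariant behind the check.
class SlidingWindowCache:
    def __init__(self, max_cache_size):
        self.max_cache_size = max_cache_size  # e.g. sliding_window = 4096
        self.entries = []       # stand-in for cached key/value rows
        self.fill_count = 0     # number of slots holding valid data
        self.window_pos = 0     # next slot to write; wraps once the cache is full

    def append(self, token_kv):
        if self.window_pos < len(self.entries):
            # Override path: overwrite the oldest slot. Only valid when the
            # cache is completely full, which is the invariant the check enforces.
            assert self.fill_count == self.max_cache_size, (
                f"{self.fill_count} vs. {self.max_cache_size}")
            self.entries[self.window_pos] = token_kv
        else:
            # Fill path: still growing the cache linearly.
            self.entries.append(token_kv)
            self.fill_count += 1
        self.window_pos = (self.window_pos + 1) % self.max_cache_size

    def reset(self):
        # A reset must restore every piece of bookkeeping consistently;
        # if the cached data and the counters get out of sync, the next
        # prefill can take the override path with a partially filled cache
        # and fail the assertion, much like the "2771 vs. 4096" error above.
        self.entries.clear()
        self.fill_count = 0
        self.window_pos = 0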

To Reproduce

Steps to reproduce the behavior:

from mlc_chat import ChatModule

text = "[INST] what does lmao mean? [/INST]"  # prompt from the traceback above

cm = ChatModule(model="mlc-llm/dist/mistral/params/", model_lib_path="mlc-llm/dist/mistral/mistral-cuda.so")
for _ in range(5):
    cm.reset_chat()
    cm.generate(prompt=text)

mlc-chat-config.json used:

{
    "model_lib": "mistral-q4f16_1",
    "local_id": "mistral-q4f16_1",
    "conv_template": "LM",
    "conv_config": {
        "roles": ["", ""],
        "stop_str": "</s>"
    },
    "temperature": 0.7,
    "repetition_penalty": 1.0,
    "top_p": 0.95,
    "mean_gen_len": 4096,
    "max_gen_len": 16384,
    "num_shards": 1,
    "use_presharded_weights": false,
    "shift_fill_factor": 0.3,
    "tokenizer_files": [
        "tokenizer.model",
        "added_tokens.json",
        "tokenizer.json"
    ],
    "model_category": "mistral",
    "model_name": "mistral",
    "vocab_size": 32001,
    "sliding_window": 4096,
    "sliding_window_chunk_size": 4096
}
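
As a side note, the same settings can also be overridden from Python at construction time instead of editing the JSON. A minimal sketch, assuming this mlc_chat build exposes the ChatConfig and ConvConfig dataclasses (their field names mirror the mlc-chat-config.json keys; adjust if your version differs):

from mlc_chat import ChatModule, ChatConfig, ConvConfig

# Assumption: ChatConfig/ConvConfig fields mirror the JSON keys above.
conv_config = ConvConfig(roles=["", ""], stop_str="</s>")
chat_config = ChatConfig(
    conv_template="LM",
    conv_config=conv_config,
    temperature=0.7,
    top_p=0.95,
)

cm = ChatModule(
    model="mlc-llm/dist/mistral/params/",
    model_lib_path="mlc-llm/dist/mistral/mistral-cuda.so",
    chat_config=chat_config,
)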

Environment

CharlieFRuan commented 11 months ago

Hi @varshith15, thank you for reporting the issue! I was able to reproduce it on my end. Confirmed that https://github.com/apache/tvm/pull/16132 fixes this; there is currently an issue with how we reset the chat.

varshith15 commented 11 months ago

@CharlieFRuan the issue persists even without the chat reset in the loop; maybe it is because I use the LM conv_template?

CharlieFRuan commented 11 months ago

Oh yes, I can reproduce that as well; it is also fixed by that PR, and the LM conv_template should be fine. I didn't notice that we call reset_chat() in generate() as well.

The PR should be merged soon! You could build from source in the meantime if you are in a hurry.

varshith15 commented 11 months ago

Ah, got it. Yeah, reset_chat() wasn't part of generate() before. Thanks!

CharlieFRuan commented 11 months ago

Hi @varshith15, the fix should be included in https://github.com/mlc-ai/relax now; let me know if there are other issues.

varshith15 commented 11 months ago

@CharlieFRuan it works fine now, thanks!