mlc-ai / mlc-llm

Universal LLM Deployment Engine with ML Compilation
https://llm.mlc.ai/
Apache License 2.0

[Question] TVMError: Unknown conversation template #1796

Closed. tlopex closed this issue 9 months ago

tlopex commented 9 months ago

❓ General Questions

Because of #1755, I had skipped adding a specific chat template for the Baichuan model. However, when I added the template in the same way as #1701 (to cpp/conv_templates.cc and python/mlc_chat/interface/gen_config.py), it fails with TVMError: Unknown conversation template: baichuan-2. I know the message comes from Conversation::FromTemplate in cpp/conv_templates.cc, but I don't understand why it happens or how to make it work. Thanks.

Here is the traceback:

Traceback (most recent call last):
  File "/home/tlopex/.local/bin/mlc_chat", line 11, in <module>
    load_entry_point('mlc-chat', 'console_scripts', 'mlc_chat')()
  File "/home/tlopex/mlc-llm/python/mlc_chat/__main__.py", line 36, in main
    cli.main(sys.argv[2:])
  File "/home/tlopex/mlc-llm/python/mlc_chat/cli/chat.py", line 41, in main
    chat(
  File "/home/tlopex/mlc-llm/python/mlc_chat/interface/chat.py", line 133, in chat
    cm = ChatModule(model, device, chat_config=config, model_lib_path=model_lib_path)
  File "/home/tlopex/mlc-llm/python/mlc_chat/chat_module.py", line 780, in __init__
    self._reload(self.model_lib_path, self.model_path, user_chat_config_json_str)
  File "/home/tlopex/mlc-llm/python/mlc_chat/chat_module.py", line 1008, in _reload
    self._reload_func(lib, model_path, app_config_json, kv_cache_config.asjson())
  File "/home/tlopex/relax/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/home/tlopex/relax/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "/home/tlopex/mlc-llm/cpp/llm_chat.cc", line 1615, in mlc::llm::LLMChatModule::GetFunction(tvm::runtime::String const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
    chat_->Reload(args[0], args[1], args[2], args[3]);
  File "/home/tlopex/mlc-llm/cpp/llm_chat.cc", line 580, in mlc::llm::LLMChat::Reload(tvm::runtime::TVMArgValue, tvm::runtime::String, tvm::runtime::String, tvm::runtime::String)
    model_config = LoadJSONOverride(config_str, false);
  File "/home/tlopex/mlc-llm/cpp/llm_chat.cc", line 557, in mlc::llm::LLMChat::LoadJSONOverride(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, bool)
    LoadJSONOverride(config_json, partial_update);
  File "/home/tlopex/mlc-llm/cpp/llm_chat.cc", line 530, in mlc::llm::LLMChat::LoadJSONOverride(picojson::value const&, bool)
    this->conversation_ = Conversation::FromTemplate(conv_template);
  File "/home/tlopex/mlc-llm/cpp/conv_templates.cc", line 743, in mlc::llm::Conversation::FromTemplate(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)
    {"guanaco", Guanaco},
tvm._ffi.base.TVMError: Traceback (most recent call last):
  4: mlc::llm::LLMChatModule::GetFunction(tvm::runtime::String const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
        at /home/tlopex/mlc-llm/cpp/llm_chat.cc:1615
  3: mlc::llm::LLMChat::Reload(tvm::runtime::TVMArgValue, tvm::runtime::String, tvm::runtime::String, tvm::runtime::String)
        at /home/tlopex/mlc-llm/cpp/llm_chat.cc:580
  2: mlc::llm::LLMChat::LoadJSONOverride(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, bool)
        at /home/tlopex/mlc-llm/cpp/llm_chat.cc:557
  1: mlc::llm::LLMChat::LoadJSONOverride(picojson::value const&, bool)
        at /home/tlopex/mlc-llm/cpp/llm_chat.cc:530
  0: mlc::llm::Conversation::FromTemplate(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)
        at /home/tlopex/mlc-llm/cpp/conv_templates.cc:743
  File "/home/tlopex/mlc-llm/cpp/conv_templates.cc", line 743
TVMError: Unknown conversation template: baichuan-2

Here is my modified template:

Conversation Baichuan2() {
  Conversation conv;
  conv.name = "baichuan-2";
  conv.system = "";
  conv.roles = {"<reserved_106>", "<reserved_107>"};
  conv.messages = {};
  conv.offset = 0;
  conv.separator_style = SeparatorStyle::kSepRoleMsg;
  conv.seps = {""};
  conv.role_msg_sep = "";
  conv.role_empty_sep = "";
  // TODO(mlc-team): add eos to mlc-chat-config
  // and remove eos from stop token setting.
  conv.stop_tokens = {195};
  conv.stop_str = "<reserved_106>";
  conv.add_bos = false;
  return conv;
}

}  // namespace

using ConvFactory = Conversation (*)();

Conversation Conversation::FromTemplate(const std::string& name) {
  static std::unordered_map<std::string, ConvFactory> factory = {
      {"chatml", ChatML},
      {"llama_default", LlamaDefault},
      {"llama-2", Llama2},
      {"mistral_default", MistralDefault},
      {"open_hermes_mistral", OpenHermesMistral},
      {"neural_hermes_mistral", NeuralHermesMistral},
      {"codellama_completion", CodeLlamaCompletion},
      {"codellama_instruct", CodeLlamaInstruct},
      {"gpt2", GPT2},
      {"vicuna_v1.1", VicunaV11},
      {"conv_one_shot", ConvOneShot},
      {"redpajama_chat", RedPajamaChat},
      {"rwkv_world", RWKVWorld},
      {"rwkv", RWKV},
      {"gorilla", Gorilla},
      {"guanaco", Guanaco},
      {"dolly", Dolly},
      {"oasst", Oasst},
      {"stablelm", StableLM},
      {"stablecode_completion", StableCodeCompletion},
      {"stablecode_instruct", StableCodeInstruct},
      {"minigpt", MiniGPT},
      {"moss", MOSS},
      {"LM", VanillaLM},
      {"stablelm-3b", StableLM3B},
      {"gpt_bigcode", GPTBigCode},
      {"wizardlm_7b", WizardLM7B},
      {"wizard_coder_or_math", WizardCoderOrMATH},
      {"glm", GLM},
      {"phi-2", Phi2},
      {"qwen", ChatML},
      {"stablelm-2", StableLM2},
      {"baichuan-2", Baichuan2},
  };
  auto it = factory.find(name);
  if (it == factory.end()) {
    LOG(FATAL) << "Unknown conversation template: " << name;
  }
  return it->second();
}
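
For reference, the Python side of the change mentioned in the question usually amounts to adding the new name to the set of conversation templates that gen_config.py accepts; a minimal sketch, assuming the module keeps a set of recognized names (called CONV_TEMPLATES here, with only a few surrounding entries shown for illustration):

# python/mlc_chat/interface/gen_config.py (sketch, not the full list)
# Assumption: the module validates the requested conv_template against a set
# of known names before writing it into mlc-chat-config.json.
CONV_TEMPLATES = {
    "chatml",
    "llama-2",
    "phi-2",
    "qwen",
    "stablelm-2",
    "baichuan-2",  # must match conv.name in cpp/conv_templates.cc exactly
    # ... remaining templates omitted
}

Both sides have to agree on the exact string: gen_config.py writes conv_template into mlc-chat-config.json, and at load time the C++ runtime looks that string up in the factory map shown above.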
rickzx commented 9 months ago

Your code looks correct to me. Have you recompiled the MLC C++ library after making the change to cpp/conv_templates.cc?
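
The template registry lives in the compiled C++ runtime, so an edit to cpp/conv_templates.cc only takes effect once that library is rebuilt; if it was built from source with CMake, re-running the build in the existing build directory (for example cmake --build build --parallel) and then launching mlc_chat again should pick up the new entry.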

tlopex commented 9 months ago

@rickzx Thanks. I recompiled it and that solved the problem.