Open simonw opened 9 months ago
This plugin urgently needs a better solution for handling chat templates, to better support models like Mixtral.
Currently it only supports one, for Llama 2, which is hard-coded like this:
https://github.com/simonw/llm-llama-cpp/blob/dc53ef9b00423f55d11060ce51ba415f31795a08/llm_llama_cpp.py#L220-L235
I think templating is the right way to go here. Rather than invent something new, I'd like to reuse this Hugging Face mechanism, which as far as I can tell was introduced back in September:
https://huggingface.co/docs/transformers/chat_templating
Templates use Jinja and end up looking something like this:
Here's an example template for Mixtral Instruct: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/tokenizer_config.json#L42
"chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
Decoded with https://observablehq.com/@simonw/display-content-from-a-json-string
{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}
Pretty printed by ChatGPT:
{{ bos_token }}
{% for message in messages %}
    {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
        {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
    {% endif %}
    {% if message['role'] == 'user' %}
        {{ '[INST] ' + message['content'] + ' [/INST]' }}
    {% elif message['role'] == 'assistant' %}
        {{ message['content'] + eos_token }}
    {% else %}
        {{ raise_exception('Only user and assistant roles are supported!') }}
    {% endif %}
{% endfor %}
I think bos_token and eos_token are defined in that JSON too:
"bos_token": "<s>",
"clean_up_tokenization_spaces": false,
"eos_token": "</s>",
Here's the code that renders that: https://github.com/huggingface/transformers/blob/238d2e3c44366aba9dc5c770c95475765a6725cb/src/transformers/tokenization_utils_base.py#L1760-L1779
@lru_cache
def _compile_jinja_template(self, chat_template):
    try:
        import jinja2
        from jinja2.exceptions import TemplateError
        from jinja2.sandbox import ImmutableSandboxedEnvironment
    except ImportError:
        raise ImportError("apply_chat_template requires jinja2 to be installed.")

    if version.parse(jinja2.__version__) <= version.parse("3.0.0"):
        raise ImportError(
            "apply_chat_template requires jinja2>=3.0.0 to be installed. Your version is " f"{jinja2.__version__}."
        )

    def raise_exception(message):
        raise TemplateError(message)

    jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
    jinja_env.globals["raise_exception"] = raise_exception
    return jinja_env.from_string(chat_template)
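Presumably ImmutableSandboxedEnvironment was chosen because chat templates ship inside model repos on the Hub: the sandbox stops a downloaded template from calling arbitrary Python or mutating state when it's rendered.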
Which is called from here:
# Compilation function uses a cache to avoid recompiling the same template
compiled_template = self._compile_jinja_template(chat_template)

rendered = compiled_template.render(
    messages=conversation, add_generation_prompt=add_generation_prompt, **self.special_tokens_map
)
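So the full round trip, via the public apply_chat_template() API, presumably looks something like this (a sketch, untested):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
messages = [
    {"role": "user", "content": "2+2"},
    {"role": "assistant", "content": "4"},
    {"role": "user", "content": "+2"},
]
# tokenize=False returns the rendered template string rather than token ids
print(tok.apply_chat_template(messages, tokenize=False))
# Should print something like:
# <s>[INST] 2+2 [/INST]4</s>[INST] +2 [/INST]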
Not yet sure where self.special_tokens_map is populated.
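For what it's worth, special_tokens_map appears to be a property on the tokenizer base class that collects the named special tokens as plain strings, which would be how bos_token and eos_token become visible inside the template. Reusing the tok from the sketch above:

# special_tokens_map gathers the named special tokens as strings
print(tok.special_tokens_map)
# e.g. {'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}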
The add_generation_prompt option controls whether the equivalent of Assistant: is added at the end of the prompt - true for completion models; instruction-tuned models tend not to need it.
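The Mixtral template above never references add_generation_prompt, so to illustrate the flag here's a sketch using a hypothetical ChatML-style template (not from the Mixtral repo) where it does matter:

from jinja2.sandbox import ImmutableSandboxedEnvironment

# Hypothetical ChatML-style template, for illustration only
chatml = (
    "{% for message in messages %}"
    "<|im_start|>{{ message['role'] }}\n{{ message['content'] }}<|im_end|>\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)
template = ImmutableSandboxedEnvironment().from_string(chatml)
print(template.render(
    messages=[{"role": "user", "content": "2+2"}],
    add_generation_prompt=True,
))
# <|im_start|>user
# 2+2<|im_end|>
# <|im_start|>assistant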
Useful note about Mistral: https://discord.com/channels/1144547040454508606/1156609509674975262/1184860885748035594
from transformers import AutoTokenizer
from typing import List, Dict


def build_prompt(
    messages: List[Dict[str, str]],
    tokenizer: AutoTokenizer,
):
    prompt = ""
    for i, msg in enumerate(messages):
        is_user = {"user": True, "assistant": False}[msg["role"]]
        assert (i % 2 == 0) == is_user
        content = msg["content"]
        assert content == content.strip()
        if is_user:
            prompt += f"[INST] {content} [/INST]"
        else:
            prompt += f" {content}</s>"
    tokens_ids = tokenizer.encode(prompt)
    token_str = tokenizer.convert_ids_to_tokens(tokens_ids)
    return tokens_ids, token_str
tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
messages = [
    {"role": "user", "content": "2+2"},
    {"role": "assistant", "content": "4!"},
    {"role": "user", "content": "+2"},
    {"role": "assistant", "content": "6!"},
    {"role": "user", "content": "+4"},
]
tokens_ids, token_str = build_prompt(messages, tok)
print(tokens_ids)
# [1, 733, 16289, 28793, 28705, 28750, 28806, 28750, 733, 28748, 16289, 28793, 28705, 28781, 28808, 2, 733, 16289, 28793, 648, 28750, 733, 28748, 16289, 28793, 28705, 28784, 28808, 2, 733, 16289, 28793, 648, 28781, 733, 28748, 16289, 28793]
print(token_str)
# ['<s>', '▁[', 'INST', ']', '▁', '2', '+', '2', '▁[', '/', 'INST', ']', '▁', '4', '!', '</s>', '▁[', 'INST', ']', '▁+', '2', '▁[', '/', 'INST', ']', '▁', '6', '!', '</s>', '▁[', 'INST', ']', '▁+', '4', '▁[', '/', 'INST', ']']
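A quick way to see whether this hand-built prompt and the tokenizer's own chat template agree, continuing from the snippet above (whitespace handling is exactly where the two approaches can diverge):

# Compare against the ids produced by the tokenizer's chat template -
# any difference comes down to the whitespace subtleties discussed above
template_ids = tok.apply_chat_template(messages)
print(template_ids == tokens_ids)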