nod-ai / SHARK-ModelDev

Unified compiler/runtime for interfacing with PyTorch Dynamo.
Apache License 2.0

Fix prompt format mismatch with huggingface #807

Closed bhbruce closed 3 months ago

bhbruce commented 3 months ago
  1. System prompt: remove the leading `<s>`; token [1] (the BOS token) is generated by default by the tokenizer (see the check after this list).

  2. End of system prompt:

    Before:          ->  After:
    \n\n                 \n\n"""
    """
    => The original code's closing `"""` on its own line implies three `\n` after `<</SYS>>`; the HF template emits exactly two, so the quotes move onto the same line as `\n\n`.
  3. Fix `append_user_prompt` & `append_bot_prompt` to match the behavior of `tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)`. The correct format for Llama 2 is (a helper sketch follows this list):

    <s>[INST] <<SYS>>
    {{ system_prompt }}
    <</SYS>>

    {{ user_msg_1 }} [/INST] {{ model_answer_1 }} </s><s>[INST] {{ user_msg_2 }} [/INST]
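
Re point 1: the Llama 2 tokenizer prepends the BOS token (id 1, i.e. `<s>`) on its own, so a literal `<s>` in the prompt string would be doubled. A quick check (an illustration, not code from this PR):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
    ids = tok("[INST] Hello [/INST]").input_ids
    print(ids[0])  # 1 -- the BOS token <s>, added automatically by the tokenizer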
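
Re point 3: a minimal sketch of helpers producing that format (`append_user_prompt`/`append_bot_prompt` are the PR's names; `make_system_prompt` and all function bodies here are my illustration of the format above, not the actual diff):

    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"  # exactly two trailing \n (fix 2)

    def make_system_prompt(system_msg: str) -> str:
        # No leading <s>: the tokenizer adds BOS (token id 1) itself (fix 1).
        return f"{B_INST} {B_SYS}{system_msg}{E_SYS}"

    def append_user_prompt(history: str, user_msg: str) -> str:
        # After a finished assistant turn the history ends with "</s><s>",
        # so a new [INST] block is opened; the first user message simply
        # continues the [INST] already opened by the system block.
        if history.endswith("<s>"):
            return f"{history}{B_INST} {user_msg} {E_INST}"
        return f"{history}{user_msg} {E_INST}"

    def append_bot_prompt(history: str, model_answer: str) -> str:
        # A completed answer is sealed with </s>; the next turn reopens with <s>.
        return f"{history} {model_answer} </s><s>"

    # Usage:
    p = make_system_prompt("Be concise.")
    p = append_user_prompt(p, "Who's the president of the USA?")
    p = append_bot_prompt(p, "The president of the United States is currently Joe Biden.")
    p = append_user_prompt(p, "How are you doing?")
    print(p)  # matches the multi-run output below, minus the tokenizer's leading <s>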

bhbruce commented 3 months ago

Example code to get the prompt via transformers:

import torch
from transformers import pipeline
pipe = pipeline("text-generation", model="meta-llama/Llama-2-7b-chat-hf",
                torch_dtype=torch.float32,
                device_map="auto",
                token="xxxxxx")
# We use the tokenizer's chat template to format each message - see https://huggingface.co/docs/transformers/main/en/chat_templating
messages = [
    {
        "role": "system",
        "content": "Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.",
    },
    {"role": "user", "content": "Who's the president of the USA?"}]

prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(f"Prompt:\n{prompt}")

print("\n\n==== Multi-run ====")

messages = [
    {
        "role": "system",
        "content": "Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.",
    },
    {"role": "user", "content": "Who's the president of the USA?"},
    {"role": "assistant", "content": "The president of the United States is currently Joe Biden."},
    {"role": "user", "content": "How are you doing?"},
    ]

prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(f"Prompt:\n{prompt}")

Output

Prompt:
<s>[INST] <<SYS>>
Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>

Who's the president of the USA? [/INST]

==== Multi-run ====
Prompt:
<s>[INST] <<SYS>>
Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>

Who's the president of the USA? [/INST] The president of the United States is currently Joe Biden. </s><s>[INST] How are you doing? [/INST]
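
For the actual token IDs rather than the formatted string, the same call can be made with `tokenize=True` (a sketch, not run as part of the script above):

    ids = pipe.tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)
    print(ids[0])  # 1 -- the BOS token id, consistent with fix 1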