turingmotors / heron

Apache License 2.0
165 stars 25 forks source link

Add Template Prompt #41

Closed Onely7 closed 7 months ago

Onely7 commented 7 months ago

【Add Template Prompt】

Instruction Tuning 済みのモデルのテンプレートプロンプトをいい感じに指定できるようにする

Instruction tuning済みのモデル(Instructやchatがついているモデル)は固有のプロンプトテンプレートを持っている。 これらのプロンプトテンプレートは、SFT (+ DPO) の際に使用されている。 このため、このプロンプトテンプレートを用いなかったり、意図しない他のテンプレートを使用してしまうことで、パフォーマンスが低下する可能性がある。 そこで、heron を使用した学習の際にも、使用している Instruction Tuning 済み LLM に適したプロンプトテンプレートを使用して学習することが望ましいと考えられる。

そこで、プロンプトテンプレートを選択して heron を使用して学習できるようにコードを変更することを試みる。

プロンプトテンプレートの例

llama2の場合

f'''[INST] <<SYS>>
{system_prompt}
<</SYS>>
{prompt}[/INST]

'''

mistralの場合

question = "<s>[INST] What is your favourite condiment? [/INST]"
answer = "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!</s> "
next_question = "[INST] Do you have mayonnaise recipes? [/INST]"

それぞれの公式のテンプレートは 🤗 hugging face でモデルがアップロードされた Repository の tokenizer_config.json の chat_template に入っている。 このテンプレートの使い方については 🤗 Templates for Chat Models を参照していただきたい。

また、tokenizer.apply_chat_template() にメッセージリストを渡すことで、tokenizer モデルに基づくテンプレートをもとにプロンプトを作成することができる。 メッセージは、役割”role” (system、user、assistantなど) とテキストのペアになる。

HuggingFace Transformers の チャットモデルテンプレート を試す

# 入力
chat = [
{"role": "user", "content": "Who is the cutest in Madoka Magica?"},
]
tokenizer.use_default_system_prompt = False
tokenizer.apply_chat_template(chat, tokenize=False)

# 出力
<s> [INST] Who is the cutest in Madoka Magica? [/INST]

(デフォルトのシステムメッセージを有効化したい場合は、tokenizer.use_default_system_prompt = True を指定してください)

実装には、^ のものを用いれば簡単だと思うが、instruction tuning の際には、一般的にシステムの応答のみのlossを計算するというマスク設計を実施しなければならないため、今回これを直接使うのは難しい…。

したがって、heron では、(応答のみのlossを計算するように)自分でプロンプトテンプレートを作成する必要があった。

heron のプロンプトテンプレート実装の詳細

現状は個別にDatasetを用意して選ぶようにしている。
現状の個別に実装しているものはこちらです。

これらをdatasetのconfigから指定して選択できるようにしたい。

heron/datasets/train_instruction_template.py 実装の重要な部分のみ抽出

# heron/datasets/train_instruction_template.py
def llama2_instruction(agent, tokenizer, is_system_message):
    if is_system_message:
        if agent == "gpt":
            agent_prompt = ""
            next_agent_prompt = f"{tokenizer.eos_token}"
        elif agent == "human":
            system_prompt = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don\'t know the answer to a question, please don\'t share false information.\n"
            agent_prompt = f"[INST] <<SYS>>\n{system_prompt}<</SYS>>\n\n"
            next_agent_prompt = " [/INST] "
    else:
        if agent == "gpt":
            agent_prompt = ""
            next_agent_prompt = f"{tokenizer.eos_token}"
        elif agent == "human":
            agent_prompt = "[INST] "
            next_agent_prompt = " [/INST] "
    return agent_prompt, next_agent_prompt

# ...

def add_train_instruction_template(agent, tokenizer, instruction_template_type, is_system_message):
    if instruction_template_type == "llama2":
        agent_prompt, next_agent_prompt = llama2_instruction(agent, tokenizer, is_system_message)
        return agent_prompt, next_agent_prompt
    elif instruction_template_type in ("mistral", "mixtral"):
        agent_prompt, next_agent_prompt = mistral_instruction(agent, tokenizer)
        return agent_prompt, next_agent_prompt
    elif instruction_template_type == "command-r":
        agent_prompt, next_agent_prompt = commandr_instruction(agent, tokenizer)
        return agent_prompt, next_agent_prompt
    elif instruction_template_type == "tinyllama":
        agent_prompt, next_agent_prompt = tinyllama_instruction(agent, tokenizer, is_system_message)
        return agent_prompt, next_agent_prompt
    elif instruction_template_type == "none":
        agent_prompt, next_agent_prompt = none_instruction(agent, tokenizer)
        return agent_prompt, next_agent_prompt
    else:
        agent_prompt, next_agent_prompt = base_instruction(agent, tokenizer)
        return agent_prompt, next_agent_prompt

このheron/datasets/train_instruction_template.pyheron/datasets/llava_instruct_datasets.py のようなデータセットファイルから呼び出して、^ の add_train_instruction_template 関数を使用することで、訓練データにシステムの応答のみの loss を計算するように、プロンプトテンプレートを付与することができる。

# heron/datasets/llava_instruct_datasets.py の add_train_instruction_template 関数呼び出し箇所

def _get_item_train(self, index):
    row = self.loaded_dataset[index]

    # ...

    # create prompt by instruction_template_type
    agent_prompt, next_agent_prompt = add_train_instruction_template(
        agent, 
        self.processor.tokenizer, 
        self.instruction_template_type, 
        self.is_system_message,
        )
    # ...

また、評価データに関しても、プロンプトテンプレートを付与したいので、heron/datasets/inference_instruction_template.py を作成して、heron/datasets/llava_instruct_datasets.py のようなデータセットファイルから呼び出して、^ の add_inference_instruction_template 関数を使用する。(但し、評価データに関しては、訓練そのもの(loss計算)を実施しないので、訓練データと比較するとシンプルなコードになる)

heron/datasets/inference_instruction_template.py 実装の重要な部分のみ抽出

# heron/datasets/inference_instruction_template.py

def llama2_instruction(content, is_system_message):
    if is_system_message:
        system_prompt = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don\'t know the answer to a question, please don\'t share false information.\n"
        prompt = f"[INST] <<SYS>>\n{system_prompt}<</SYS>>\n\n{content} [/INST] "
        return prompt
    else:
        prompt = f"[INST] {content} [/INST] "
        return prompt

# ...

def add_inference_instruction_template(content, tokenizer, instruction_template_type, is_system_message):
    if instruction_template_type == "llama2":
        prompt = llama2_instruction(content, is_system_message)
        return prompt
    elif instruction_template_type in ("mistral", "mixtral"):
        prompt = mistral_instruction(content)
        return prompt
    elif instruction_template_type == "command-r":
        prompt = commandr_instruction(content)
        return prompt
    elif instruction_template_type == "tinyllama":
        prompt = tinyllama_instruction(content, tokenizer, is_system_message)
        return prompt
    elif instruction_template_type == "none":
        prompt = none_instruction(content)
        return prompt
    else:
        prompt = base_instruction(content)
        return prompt
# heron/datasets/llava_instruct_datasets.py の add_inference_instruction_template 関数呼び出し箇所

def _get_item_inference(self, index):
    row = self.loaded_dataset[index]

    # ...

    # create prompt by instruction_template_type
    prompt = add_inference_instruction_template(
            row['conversations'][language],
            self.processor.tokenizer,
            self.instruction_template_type,
            self.is_system_message,            
            )
    # ...

heron学習実行時のテンプレート指定方法

configs/datasets 以下の yaml ファイルに各テンプレート方式

を記述することで、プロンプトテンプレートを指定できるようにした。

[!WARNING] 🚨 但し、heron/datasets/llava_instruct_datasets.py 以外のファイルでは、(未実装のため)プロンプトテンプレートの適用がされないので注意が必要です

例)configs/datasets/llava_en_instruct.yaml の場合

# Before
dataset_type: llava_instruct
dataset_root: ./
jsonl_path:
language: "en"
n_train: 157000
n_val: 712

# After
dataset_type: llava_instruct
dataset_root: ./
jsonl_path:
language: "en"
n_train: 157000
n_val: 712
instruction_template_type: llama2
system_message: false

実装方針についてはより良いものがあればぜひ教えて下さい! よろしくお願いします!