unslothai / unsloth

Finetune Llama 3.2, Mistral, Phi, Qwen 2.5 & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0
18.78k stars 1.32k forks

Documentation for train_on_responses_only? #823

Open rwl4 opened 4 months ago

rwl4 commented 4 months ago

Could you write up some documentation on how to properly use the new train_on_responses_only functionality? It doesn't seem to work out of the box with either the chat templates or any of the manual formatting examples (e.g. Alpaca).

danielhanchen commented 4 months ago

Oh yep, great idea! https://github.com/unslothai/unsloth/wiki#train-on-completions--responses-only-do-not-train-on-inputs shows roughly how to call it. In the Ollama notebook https://colab.research.google.com/drive/1WZDi7APtQ9VsvOrQSSC5DDtxq159j8iZ?usp=sharing, though, you first need to call apply_chat_template, which is what makes it work, i.e.:

chat_template = """Below are some instructions that describe some tasks. Write responses that appropriately complete each request.

### Instruction:
{INPUT}

### Response:
{OUTPUT}"""
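To make the template concrete, here is a minimal sketch that fills the {INPUT} and {OUTPUT} placeholders by plain string substitution, just to show what one formatted training row looks like. This is not how apply_chat_template is implemented; the render helper and the example row are hypothetical.

```python
# Minimal sketch: fill the template's {INPUT}/{OUTPUT} placeholders by hand.
# `render` and the example row are hypothetical, for illustration only.
chat_template = """Below are some instructions that describe some tasks. Write responses that appropriately complete each request.

### Instruction:
{INPUT}

### Response:
{OUTPUT}"""


def render(template: str, user_text: str, assistant_text: str) -> str:
    # str.replace (not str.format) so any other literal braces are left untouched
    return template.replace("{INPUT}", user_text).replace("{OUTPUT}", assistant_text)


example = render(chat_template, "What is 2 + 2?", "4")
print(example)
```

Everything up to and including "### Response:\n" is the prompt side; with responses-only training, only the text after it (here "4") is meant to contribute to the loss.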

from unsloth import apply_chat_template
dataset = apply_chat_template(
    dataset,
    tokenizer = tokenizer,
    chat_template = chat_template,
    # default_system_message = "You are a helpful assistant", << [OPTIONAL]
)

Then build the trainer and wrap it:

from trl import SFTTrainer
from transformers import TrainingArguments, DataCollatorForSeq2Seq
from unsloth.chat_templates import train_on_responses_only

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),
    # ... other SFTTrainer arguments ...
    args = TrainingArguments(
        # ... training arguments ...
    ),
)
# Mask the instruction part so the loss is computed on the responses only
trainer = train_on_responses_only(trainer)

More generally, the function accepts the marker strings that introduce the instruction and the response:

def train_on_responses_only(
    trainer,
    instruction_part = None,  # e.g. "Instruction:\n"
    response_part    = None,  # e.g. "Response:\n"
):
    ...  # implementation lives in unsloth.chat_templates
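Conceptually, masking of this kind works on token ids: tokens before a response marker, and tokens from the next instruction marker onward, get the label -100 so the loss ignores them. The sketch below is my own illustration of that idea, not Unsloth's actual code; find_subsequence and mask_non_responses are hypothetical helpers, and the token ids are made up in place of real tokenizer output.

```python
from typing import List, Sequence

IGNORE_INDEX = -100  # PyTorch's CrossEntropyLoss skips this label by default


def find_subsequence(seq: Sequence[int], pattern: Sequence[int], start: int = 0) -> int:
    """Index of the first occurrence of `pattern` in `seq` at or after `start`, else -1."""
    m = len(pattern)
    for i in range(start, len(seq) - m + 1):
        if list(seq[i:i + m]) == list(pattern):
            return i
    return -1


def mask_non_responses(
    input_ids: List[int],
    instruction_ids: List[int],
    response_ids: List[int],
) -> List[int]:
    """Build labels that keep only the tokens after each response marker,
    up to (not including) the next instruction marker."""
    if not instruction_ids or not response_ids:
        raise ValueError("markers must be non-empty token sequences")
    labels = [IGNORE_INDEX] * len(input_ids)
    pos = 0
    while True:
        r = find_subsequence(input_ids, response_ids, pos)
        if r == -1:
            break
        begin = r + len(response_ids)  # first token after the response marker
        nxt = find_subsequence(input_ids, instruction_ids, begin)
        end = len(input_ids) if nxt == -1 else nxt
        labels[begin:end] = input_ids[begin:end]  # these tokens contribute to the loss
        pos = end
    return labels


# Made-up ids: [1, 2] stands for the instruction marker, [3, 4] for the response marker.
ids = [9, 1, 2, 10, 11, 3, 4, 20, 21, 1, 2, 12, 3, 4, 22]
labels = mask_non_responses(ids, instruction_ids=[1, 2], response_ids=[3, 4])
print(labels)
```

One general pitfall with literal id matching: a marker string tokenized on its own can produce different ids than the same text inside a full example (tokenizers may merge characters across the boundary), in which case the marker is never found.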
William-Wildridge commented 2 months ago

How does

instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n", 
response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n",

work when there can be more than one kind of instruction header, e.g. a function-call result sent with "from: ipython"?

Would something like this be necessary?

instruction_part = "<|start_header_id|>user|ipython<|end_header_id|>\n\n",
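A hedged sketch of one way to think about this (an assumption for illustration, not how Unsloth implements train_on_responses_only): if the markers are matched as literal text, "user|ipython" would only match that exact string, so it would not act as an either/or. And a matcher that ends a response only at the next user header would count an ipython turn as part of the preceding assistant response. An alternative is to treat every header as a turn boundary and keep only turns whose role is assistant. The regex, the trainable_spans helper and the simplified sample conversation below are hypothetical; the resulting character spans would still need mapping to token positions (e.g. via a fast tokenizer's return_offsets_mapping=True) before building labels.

```python
import re
from typing import List, Sequence, Tuple

# Matches a Llama 3 style header such as "<|start_header_id|>user<|end_header_id|>\n\n"
HEADER = re.compile(r"<\|start_header_id\|>(\w+)<\|end_header_id\|>\n\n")


def trainable_spans(text: str, train_roles: Sequence[str] = ("assistant",)) -> List[Tuple[int, int]]:
    """Character spans (start, end) of the bodies of turns whose role is in train_roles.
    Every header, whatever its role, ends the previous turn."""
    headers = list(HEADER.finditer(text))
    spans = []
    for i, m in enumerate(headers):
        if m.group(1) not in train_roles:
            continue  # user, ipython, system, ... turns stay masked
        begin = m.end()
        end = headers[i + 1].start() if i + 1 < len(headers) else len(text)
        spans.append((begin, end))
    return spans


# Simplified, made-up conversation (real Llama 3.1 tool calls use extra special tokens).
conv = (
    "<|begin_of_text|>"
    "<|start_header_id|>user<|end_header_id|>\n\nWhat's the weather?<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\ncall_weather_tool()<|eot_id|>"
    "<|start_header_id|>ipython<|end_header_id|>\n\n{\"temp\": 20}<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\nIt is 20 degrees.<|eot_id|>"
)
for begin, end in trainable_spans(conv):
    print(repr(conv[begin:end]))
```

Here the ipython turn is excluded because its own header closes the assistant turn before it, without needing a separate instruction marker per role.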