unslothai / unsloth

Finetune Llama 3.1, Mistral, Phi & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0
15.75k stars 1.07k forks

How to enable training on completions only along with packing = True #109

Open akjindal53244 opened 8 months ago

akjindal53244 commented 8 months ago

Hi unsloth team,

I am wondering how to enable packing = True when I need to train only on the output tokens of an `<input, output>` text pair, e.g. `<question, answer>`. This is a common use case for instruction fine-tuning, where the loss is computed only on the output tokens while the input/instruction tokens are masked out. In this case, supporting packing greatly improves fine-tuning speed.

P.S.: This functionality is supported in axolotl :)
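
For reference, the non-packed setup this refers to looks roughly like the sketch below (assuming TRL's SFTTrainer; `model`, `tokenizer`, and `dataset` come from the usual fine-tuning setup, and the `\nAssistant:\n` response template is just illustrative):

from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

# Mask everything before the response template so loss is computed only on answers.
collator = DataCollatorForCompletionOnlyLM(
    response_template = '\nAssistant:\n',
    tokenizer = tokenizer,
)

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = 'text',
    max_seq_length = 2048,
    data_collator = collator,
    packing = False,  # packing = True is exactly what this issue asks for
)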

danielhanchen commented 8 months ago

@akjindal53244 Ye it seems like Hugging Face's SFTTrainer only supports DataCollatorForCompletionOnlyLM when packing = False. I'll ask Younes from HF about this :) A solution would be to write a custom data collator which can solve this issue.

eabdullin commented 3 months ago

@akjindal53244 You could create your own DataCollator. Here is one I created, but for a different purpose: I needed to ignore code outputs. In your case, you would have to detect the end of each sequence and the response template when packing is enabled.

from typing import Any, Dict, List, Union

from trl import DataCollatorForCompletionOnlyLM

# We don't want to train on the code outputs, so let's ignore them.
class DataCollatorForCompletionAndIgnoredCodeOutputs(DataCollatorForCompletionOnlyLM):
    def __init__(self, output_start_template: str, output_end_template: str, **kwargs):
        super().__init__(**kwargs)
        self.output_start_template_token_ids = self.tokenizer.encode(output_start_template, add_special_tokens=False)
        self.output_end_template_token_ids = self.tokenizer.encode(output_end_template, add_special_tokens=False)

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        batch = super().torch_call(examples)
        for i in range(len(examples)):
            start_ix = None
            for idx in range(len(batch["labels"][i])):
                # Remember where a code output block starts.
                if batch["labels"][i][idx : idx + len(self.output_start_template_token_ids)].tolist() == self.output_start_template_token_ids:
                    start_ix = idx
                # When the matching end template is found, mask everything in between.
                if start_ix is not None and start_ix != idx and batch["labels"][i][idx : idx + len(self.output_end_template_token_ids)].tolist() == self.output_end_template_token_ids:
                    batch["labels"][i, start_ix + len(self.output_start_template_token_ids) : idx] = self.ignore_index
                    start_ix = None
        return batch

data_collator = DataCollatorForCompletionAndIgnoredCodeOutputs(
    output_start_template = '```output',
    output_end_template = '```',
    response_template = '\nAssistant:\n',
    tokenizer = tokenizer,
)
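
Adapting the same idea to the original question (completion-only loss with packing = True), one rough sketch, not a tested implementation, is to scan each packed sequence for every occurrence of the response template and mask everything outside the completion spans. The class name, the `\nAssistant:\n` template, and the assumption that each completion ends at an EOS token are all illustrative; whether SFTTrainer actually routes batches through a custom collator when packing = True may also depend on the TRL version.

from typing import Any, Dict, List, Union

import torch
from transformers import DataCollatorForLanguageModeling

class DataCollatorForPackedCompletionOnlyLM(DataCollatorForLanguageModeling):
    # Keeps loss only on completion tokens, even when several <question, answer>
    # pairs are packed into one sequence. Each completion is assumed to start
    # right after `response_template` and to end at the next EOS token.
    def __init__(self, response_template: str, ignore_index: int = -100, **kwargs):
        super().__init__(mlm = False, **kwargs)
        self.response_token_ids = self.tokenizer.encode(response_template, add_special_tokens=False)
        self.ignore_index = ignore_index

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        batch = super().torch_call(examples)
        n = len(self.response_token_ids)
        for i in range(len(examples)):
            labels = batch["labels"][i]
            keep = torch.zeros_like(labels, dtype = torch.bool)
            in_response, idx = False, 0
            while idx < len(labels):
                if labels[idx : idx + n].tolist() == self.response_token_ids:
                    in_response = True   # a completion starts right after the template
                    idx += n
                    continue
                if in_response:
                    keep[idx] = True
                    if labels[idx] == self.tokenizer.eos_token_id:
                        in_response = False  # completion ends at EOS
                idx += 1
            labels[~keep] = self.ignore_index  # mask everything outside completions
        return batch

packed_collator = DataCollatorForPackedCompletionOnlyLM(
    response_template = '\nAssistant:\n',
    tokenizer = tokenizer,
)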
danielhanchen commented 3 months ago

@eabdullin Would you be able to contribute this into Unsloth? :) Super appreciate it :)