
Axolotl Prompt Construction Notes #11

Open hamelsmu opened 4 months ago

hamelsmu commented 4 months ago

Axolotl Notes For Prompt Construction

The function parse_instruction_fields below generates a tuple (instruction, input, response). The key to supporting a new input format is making sure you can parse your data into these parts and handle them.

Note that the instruction and input are part of the user_prompt, which is effectively the "inputs" for the purposes of train_on_inputs: false. That setting works by assigning the corresponding labels a label id that gets ignored:

tokenized_prompt["labels"] = [IGNORE_INDEX] * user_prompt_len

prompt_tokenizers.InstructionPromptTokenizingStrategy

class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for instruction-based prompts.
    """

    def parse_instruction_fields(
        self, prompt
    ) -> Union[Tuple[str, str, str], Tuple[str, str, str, str]]:
        raise NotImplementedError

    def tokenize_prompt(self, prompt):
        (
            instruction,
            input,  # pylint: disable=redefined-builtin
            response,
        ) = self.parse_instruction_fields(prompt)
        user_prompt = next(
            iter(
                self.prompter.build_prompt(
                    instruction,
                    input,
                )
            )
        )
        tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
        if not self.train_on_inputs:
            user_prompt_len = len(tokenized_prompt["input_ids"])
            # TODO this could be sped up using numpy array slicing
            tokenized_prompt["labels"] = [IGNORE_INDEX] * user_prompt_len
        tokenized_res_prompt = self._tokenize(
            response, strip_bos_token=True, add_eos_token=True
        )
        tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
        tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
        tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]

In Axolotl, many prompt tokenizing strategies override the parse_instruction_fields method like this:

prompt_tokenizers.AlpacaPromptTokenizingStrategy

class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for Alpaca prompts.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        return (
            prompt["instruction"],
            prompt["input"] if "input" in prompt else "",
            prompt["output"],
        )
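For example, a (hypothetical) Alpaca-format record maps straight through:

row = {
    "instruction": "Summarize the following text.",
    "input": "Axolotl is a library for fine-tuning language models.",
    "output": "Axolotl fine-tunes language models.",
}
# parse_instruction_fields(row) returns:
# ("Summarize the following text.",
#  "Axolotl is a library for fine-tuning language models.",
#  "Axolotl fine-tunes language models.")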

The strategy above is instantiated in prompt_strategies.alpaca_chat:

def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    prompt_style = PromptStyle.CHAT.value
    if ds_cfg and "conversation" in ds_cfg:
        prompt_style = ds_cfg["conversation"]

    return AlpacaPromptTokenizingStrategy(
        AlpacaPrompter(prompt_style),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
hamelsmu commented 4 months ago

Ok, but what is a Prompter? That is the thing constructing the user_prompt; we can see from the code above that the AlpacaPrompter.build_prompt() method is being called. The source code for AlpacaPrompter looks like this:

class AlpacaPrompter(Prompter):
    """
    Base class for alpaca prompters
    """

    system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request."
    system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
    system_format: str = "{system}"
    turn_format: str
    turn_no_input_format: str
    prompt_style: Optional[PromptStyle] = None

    def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
        self.prompt_style = prompt_style if prompt_style else PromptStyle.INSTRUCT.value
        self.match_prompt_style()

    def match_prompt_style(self):
        # pylint: disable=duplicate-code
        if self.prompt_style == PromptStyle.INSTRUCT.value:
            self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
            self.turn_no_input_format = (
                "### Instruction:\n{instruction}\n\n### Response:\n"
            )
            self.system_format = "{system}\n\n"
        if self.prompt_style == PromptStyle.CHAT.value:
            self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
            self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
            self.system_format = "SYSTEM: {system}\n"
        if self.prompt_style == PromptStyle.CHATML.value:
            self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
            self.turn_no_input_format = (
                "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
            )
            self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"

    def _build_result(self, instruction, input_text, output):
        # returns the full prompt from instruction and optional input
        # if a label (=response, =output) is provided, it's also appended.
        if input_text:
            res = (
                self.system_format.format(system=self.system_prompt)
                if self.system_prompt
                else ""
            ) + self.turn_format.format(instruction=instruction, input=input_text)
        else:
            res = (
                self.system_format.format(system=self.system_no_input_prompt)
                if self.system_no_input_prompt
                else ""
            ) + self.turn_no_input_format.format(instruction=instruction)
        if output:
            res = f"{res}{output}"

        return res

    def build_prompt(
        self,
        instruction: str,
        input: Union[None, str] = None,  # pylint: disable=redefined-builtin
        output: Union[None, str] = None,
    ) -> Generator[str, None, None]:
        yield self._build_result(instruction, input, output)
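As a quick sanity check, here is what build_prompt yields for the chat style (the expected string below is assembled by hand from the CHAT formats shown above and split across comment lines for readability):

prompter = AlpacaPrompter(PromptStyle.CHAT.value)
user_prompt = next(iter(prompter.build_prompt(
    "Summarize the following text.",
    "Axolotl is a library for fine-tuning language models.",
)))
# user_prompt ==
#   "SYSTEM: Below is an instruction that describes a task, paired with an input that "
#   "provides further context. Write a response that appropriately completes the request.\n"
#   "USER: Summarize the following text.\nAxolotl is a library for fine-tuning language models.\nASSISTANT:"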

I'm thinking I can skip most of this complexity because I'm just assembling strings, as long as build_prompt yields the right thing.
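As a sketch of that idea (names here are hypothetical, not part of Axolotl), a minimal prompter only needs build_prompt to yield the assembled string:

class MyFreeformPrompter(Prompter):
    """Minimal prompter sketch: just assemble and yield the user prompt string."""

    def build_prompt(
        self,
        instruction: str,
        input=None,  # pylint: disable=redefined-builtin
        output=None,
    ):
        res = f"{instruction}\n{input}\n" if input else f"{instruction}\n"
        if output:  # output is not passed when tokenizing only the user prompt
            res = f"{res}{output}"
        yield res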

hamelsmu commented 4 months ago

Ok, but how are these fields being hydrated, exactly?

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        return (
            prompt["instruction"],
            prompt["input"] if "input" in prompt else "",
            prompt["output"],
        )

These seem to come from the data itself; see the README.

For freeform stuff we really don't care about instruction, just input and output.

hamelsmu commented 4 months ago

But with the alpaca format, it's not clear how to do multi-turn conversations. This is where the sharegpt format seems to come in (see the README).

In the axolotl config for the Honeycomb model I fine-tuned, I have the following:

datasets:
  - path: _synth_data/alpaca_synth_queries_healed.jsonl
    type: sharegpt
    conversation: alpaca  # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py

So multi-turn conversations are probably a very different codepath 😢...
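For reference, a sharegpt-type record looks roughly like this (the conversations key is confirmed by the code below; the from/value field names follow the common ShareGPT layout and are assumed here):

row = {
    "conversations": [
        {"from": "human", "value": "How many requests errored in the last hour?"},
        {"from": "gpt", "value": "..."},
        {"from": "human", "value": "Break that down by service."},
        {"from": "gpt", "value": "..."},
    ]
}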

hamelsmu commented 4 months ago

For multi-turn, the magic happens in prompt_tokenizers.ShareGPTPromptTokenizingStrategy.tokenize_prompt.

You will notice some key differences compared to the tokenize_prompt method above.

  1. There is a for-loop that iterates through the conversation.
  2. There are user, assistant, and null roles.
    • The user role is an input role for the purpose of ignoring inputs (that is why the labels get set to IGNORE_TOKEN_ID for the user role).
    • For the assistant role, the "role component" of the prompt is ignored in the labels; the idea is that the assistant content is basically an output (see the sketch after this list).
    • If the role is an empty string or null, it is treated as a user input (and is ignored).
  3. We have to pay very close attention to the way the tokenizer is called, particularly the add_eos_token and strip_bos_token arguments.
    • user role: add_eos_token=False and strip_bos_token=True, because this is neither the beginning nor the end.
    • assistant role: add_eos_token=True and strip_bos_token=True, because this is usually the end. However, there is an exception for chatml in the code that I think we can ignore here.
    • null role: from the comments this is the first turn, so add_eos_token=False and strip_bos_token=False, since it is the beginning.
  4. At the end of the for-loop, the current turn is appended onto the previous turns using parse_tokenized_to_result:

        # prompt_tokenizers.ShareGPTPromptTokenizingStrategy.tokenize_prompt
        result, current_len = parse_tokenized_to_result(
            result,
            current_len,
            res,
            labels,
            pad_token_id=self.tokenizer.pad_token_id,
        )
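To make point 2 concrete, here is a toy sketch (made-up token ids, not Axolotl code) of how an assistant turn ends up masked:

IGNORE_TOKEN_ID = -100

role_ids = [50, 51]            # hypothetical tokens for "ASSISTANT:"
content_ids = [60, 61, 62, 2]  # hypothetical tokens for the reply, ending in eos

input_ids = role_ids + content_ids
labels = [IGNORE_TOKEN_ID] * len(role_ids) + content_ids
# only the assistant's content (and eos) contribute to the loss

The full strategy for reference: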
class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for ShareGPT prompts.
    """

    def get_conversation_thread(self, prompt):
        return prompt["conversations"]

    def tokenize_prompt(self, prompt):
        # Initial values. We will append to these as we go through the conversation.
        result, current_len = tokenize_prompt_default()
        conversation: Conversation = (
            self.prompter._conversation.copy()  # pylint: disable=protected-access
        )

        # support for custom roles from the dataset, only useful for vicuna style prompts/roles
        role_remap = []
        if (
            conversation.name == "vicuna_v1.1"
            and "roles" in prompt
            and len(prompt["roles"]) >= 2
        ):
            role_remap = [
                {"from": conversation.roles[0], "to": prompt["roles"][0]},
                {"from": conversation.roles[1], "to": prompt["roles"][1]},
            ]

        try:
            for _, part in enumerate(
                self.prompter.build_prompt(self.get_conversation_thread(prompt))
            ):
                if not isinstance(part, tuple):
                    LOG.warning(f"expected tuple, got {part}")
                    continue

                user, assistant = conversation.roles
                role, content = part

                # Uses "in" because role contains extra characters
                if user in role:
                    role = (
                        role.replace(role_remap[0]["from"], role_remap[0]["to"])
                        if role_remap
                        else role
                    )
                    turn = role + content
                    # this is still the user query, we should
                    if not content.strip():
                        LOG.warning(f"user turn has empty text: {prompt}")
                    res = self._tokenize(
                        turn,
                        add_eos_token=False,
                        strip_bos_token=True,
                    )
                    if self.train_on_inputs:
                        labels = copy.deepcopy(res["input_ids"])
                    else:
                        # everything from this is masked out from the labels
                        labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
                elif assistant in role:
                    role = (
                        role.replace(role_remap[1]["from"], role_remap[1]["to"])
                        if role_remap
                        else role
                    )
                    turn = role + content
                    # this should be the assistant response, should end with an eos token
                    if not content.strip():
                        LOG.warning(f"assistant turn has empty text: {prompt}")
                    add_eos_token = not (
                        conversation.name == "chatml"
                        and conversation.sep == self.tokenizer.eos_token
                    )
                    res = self._tokenize(
                        turn,
                        add_eos_token=add_eos_token,
                        strip_bos_token=True,
                    )
                    role_res = self._tokenize(
                        role.rstrip(),
                        add_eos_token=False,
                        strip_bos_token=True,
                    )
                    labels = copy.deepcopy(res["input_ids"])
                    if not self.train_on_inputs:
                        # mask out role tokens from the labels
                        len_role = len(role_res["input_ids"])
                        labels[:len_role] = [IGNORE_TOKEN_ID] * min(
                            len_role, len(labels)
                        )
                elif role == "":
                    turn = content
                    # this is only ever the first part, should include the bos token and the user query
                    res = self._tokenize(
                        turn, add_eos_token=False, strip_bos_token=False
                    )
                    if self.train_on_inputs:
                        labels = copy.deepcopy(res["input_ids"])
                    else:
                        # everything from this is masked out from the labels
                        labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
                else:
                    LOG.warning(f"unhandled role: {role}")
                    continue

                # pylint: disable=duplicate-code
                result, current_len = parse_tokenized_to_result(
                    result,
                    current_len,
                    res,
                    labels,
                    pad_token_id=self.tokenizer.pad_token_id,
                )
            return result
        except (KeyError, AssertionError, IndexError) as err:
            raise InvalidDataException(str(err)) from err

The parse_tokenized_to_result function that concatenates the turns looks like this:

prompt_tokenizers.parse_tokenized_to_result

def parse_tokenized_to_result(
    result: Dict[str, List[int]],
    current_len: int,
    res: Dict[str, List[int]],
    labels: List[int],
    pad_token_id: Union[int, None] = None,
) -> Tuple[Dict[str, List[int]], int]:
    """
    Parses the tokenized prompt and append the tokenized input_ids, attention_mask and labels to the result
    """

    input_ids = res["input_ids"]
    input_len = len(input_ids)
    result["input_ids"][current_len : current_len + input_len] = input_ids
    result["attention_mask"][current_len : current_len + input_len] = [
        1 if x != pad_token_id else 0 for x in input_ids
    ]
    result["labels"][current_len : current_len + input_len] = labels
    current_len += input_len

    return result, current_len
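To make the concatenation concrete, here is a toy invocation (made-up token ids; the empty result dict stands in for what tokenize_prompt_default starts with) for a masked user turn followed by an assistant turn:

result = {"input_ids": [], "attention_mask": [], "labels": []}
current_len = 0

# user turn: labels fully masked
result, current_len = parse_tokenized_to_result(
    result, current_len, {"input_ids": [5, 6, 7]}, [-100, -100, -100], pad_token_id=0
)
# assistant turn: labels kept (role tokens would additionally be masked in practice)
result, current_len = parse_tokenized_to_result(
    result, current_len, {"input_ids": [8, 9, 2]}, [8, 9, 2], pad_token_id=0
)
# result["input_ids"] == [5, 6, 7, 8, 9, 2]
# result["labels"]    == [-100, -100, -100, 8, 9, 2]
# current_len == 6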
hamelsmu commented 4 months ago

But what is the entrypoint to loading datasets in axolotl? This starts in cli.__init__.load_datasets:

def load_datasets(
    *,
    cfg: DictDefault,
    cli_args: TrainerCliArgs,
) -> TrainDatasetMeta:
    tokenizer = load_tokenizer(cfg)

    train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
        cfg, tokenizer
    )

...

The magic happens in these two: utils.data.prepare_dataset and utils.models.load_tokenizer.

utils.data.prepare_dataset calls load_prepare_datasets, which calls load_tokenized_prepared_datasets, which in turn calls utils.data.get_dataset_wrapper:

def prepare_dataset(cfg, tokenizer):
   ...
            if cfg.test_datasets:
                train_dataset, _, prompters = load_prepare_datasets(
                    tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
                )
    ....

## this calls `load_prepare_datasets`

def load_prepare_datasets(
    tokenizer: PreTrainedTokenizerBase,
    cfg,
    default_dataset_prepared_path,
    split="train",
) -> Tuple[Dataset, Dataset, List[Prompter]]:
    dataset, prompters = load_tokenized_prepared_datasets(
        tokenizer, cfg, default_dataset_prepared_path, split=split
    )

## the above calls huggingface `datasets.load.load_dataset` and then later `utils.data.get_dataset_wrapper`

def load_tokenized_prepared_datasets(
    tokenizer,
    cfg,
    default_dataset_prepared_path,
    split="train",
):
 ...
    dataset_wrapper, dataset_prompter = get_dataset_wrapper(
        config_dataset=config_dataset,
        tokenizer=tokenizer,
        cfg=cfg,
        dataset=ds,
        d_base_type=d_base_type,  # this is the type from the axolotl config, e.g. `sharegpt` for Honeycomb
        d_prompt_style=d_prompt_style,
    )
    datasets.append(dataset_wrapper)
    prompters.append(dataset_prompter)
hamelsmu commented 4 months ago

Next, utils.data.get_dataset_wrapper has this code, where it calls TokenizedPromptDataset:

#utils.data.get_dataset_wrapper 
    elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
        dataset_prompter = UnsupportedPrompter()
        dataset_wrapper = TokenizedPromptDataset(
            ds_strategy,
            dataset,
            **ds_kwargs,
        )

The ds_strategy is what applies the prompt construction. The call to load(config_dataset.type, ...) resolves to axolotl.prompt_strategies.sharegpt.SimpleShareGPTPromptTokenizingStrategy, which inherits from the ShareGPTPromptTokenizingStrategy we covered in an earlier comment.

This strategy is then passed into the TokenizedPromptDataset.

datasets.TokenizedPromptDataset

class TokenizedPromptDataset(Dataset):
    """
    Dataset that returns tokenized prompts from a stream of text files.
        Args:
            prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
            dataset (dataset.Dataset): Dataset with text files.
            process_count (int): Number of processes to use for tokenizing.
            keep_in_memory (bool): Whether to keep the tokenized dataset in memory.
    """

    def __init__(  # pylint: disable=super-init-not-called
        self,
        prompt_tokenizer: PromptTokenizingStrategy,
        dataset: Dataset,
        process_count: Optional[int] = None,
        keep_in_memory: Optional[bool] = False,
        **kwargs,
    ):
        self.prompt_tokenizer = prompt_tokenizer
        self.process_count = process_count
        self.keep_in_memory = keep_in_memory
        super().__init__(
            self.process(dataset).data,
            **kwargs,
        )

    def process(self, dataset):
        features = dataset.features.keys()
        num_proc = min(64, self.process_count if self.process_count else os.cpu_count())

        map_kwargs = {}
        if self.prompt_tokenizer.supports_batched:
            map_kwargs["batched"] = True
            map_kwargs["batch_size"] = 100
        return dataset.map(
            self.prompt_tokenizer.tokenize_prompt,
            num_proc=num_proc,
            remove_columns=features,
            keep_in_memory=self.keep_in_memory,
            desc="Tokenizing Prompts",
            **map_kwargs,
        )

The thing to pay attention to here is self.prompt_tokenizer.tokenize_prompt, which we covered above. So now we know how tokenize_prompt gets applied to each row (via dataset.map) when process is eventually called.
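Stitching the pieces from earlier comments together, the whole path amounts to roughly this sketch (tokenizer, cfg, and raw_dataset are assumed to already exist; this is not how Axolotl wires it up internally):

strategy = AlpacaPromptTokenizingStrategy(
    AlpacaPrompter(PromptStyle.CHAT.value),
    tokenizer,
    cfg.train_on_inputs,
    cfg.sequence_len,
)
wrapped = TokenizedPromptDataset(strategy, raw_dataset)
# each row of `wrapped` now has input_ids, attention_mask, and labels ready for the trainer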

hamelsmu commented 4 months ago

The other important thing that happens in cli.__init__.load_datasets is loading the tokenizer, which is done in utils.models.load_tokenizer. Special tokens from the cfg are applied, along with padding exceptions for different model classes:

    # Mistral's official FA implementation requires left padding
    if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
        tokenizer.padding_side = "left"

    # Qwen base only has single token, so we need to set the special tokens
    if cfg.is_qwen_derived_model:
        token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
        for attr_name in token_ids:
            ...
    additional_special_tokens = None
    if cfg.special_tokens:
        special_tokens = cfg.special_tokens.to_dict()
        additional_special_tokens = special_tokens.pop(
            "additional_special_tokens", None
        )
hamelsmu commented 4 months ago

I'm using what I learned from this journey to review this PR.