Ok, but what is a `Prompter`? That seems to be a thing that is constructing the `user_prompt`. We can see from the above code that the `AlpacaPrompter.build_prompt()` method is being called. The source code for `AlpacaPrompter` looks like this:
class AlpacaPrompter(Prompter):
"""
Base class for alpaca prompters
"""
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request."
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
system_format: str = "{system}"
turn_format: str
turn_no_input_format: str
prompt_style: Optional[PromptStyle] = None
def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
self.prompt_style = prompt_style if prompt_style else PromptStyle.INSTRUCT.value
self.match_prompt_style()
def match_prompt_style(self):
# pylint: disable=duplicate-code
if self.prompt_style == PromptStyle.INSTRUCT.value:
self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
self.turn_no_input_format = (
"### Instruction:\n{instruction}\n\n### Response:\n"
)
self.system_format = "{system}\n\n"
if self.prompt_style == PromptStyle.CHAT.value:
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
self.system_format = "SYSTEM: {system}\n"
if self.prompt_style == PromptStyle.CHATML.value:
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
self.turn_no_input_format = (
"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
)
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
def _build_result(self, instruction, input_text, output):
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
if input_text:
res = (
self.system_format.format(system=self.system_prompt)
if self.system_prompt
else ""
) + self.turn_format.format(instruction=instruction, input=input_text)
else:
res = (
self.system_format.format(system=self.system_no_input_prompt)
if self.system_no_input_prompt
else ""
) + self.turn_no_input_format.format(instruction=instruction)
if output:
res = f"{res}{output}"
return res
def build_prompt(
self,
instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
) -> Generator[str, None, None]:
yield self._build_result(instruction, input, output)
I'm thinking I can skip most of this complexity because I'm just assembling strings, as long as `build_prompt` returns the right thing.
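For example, for the default `PromptStyle.INSTRUCT` style, the assembled string is just the system prompt followed by the instruction/input turn. A minimal sketch outside of axolotl, with made-up instruction/input/output values:

```python
# Minimal sketch of the INSTRUCT-style assembly, using the formats from the class above.
# The instruction/input/output values are made up.
system_prompt = (
    "Below is an instruction that describes a task, paired with an input that provides "
    "further context. Write a response that appropriately completes the request."
)
system_format = "{system}\n\n"
turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"

res = system_format.format(system=system_prompt) + turn_format.format(
    instruction="Summarize the text.",
    input="Axolotl is a library for fine-tuning language models.",
)
res += "Axolotl is a fine-tuning library."  # _build_result appends the output when provided
print(res)
```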
Ok, but how are these fields being hydrated, exactly?
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
return (
prompt["instruction"],
prompt["input"] if "input" in prompt else "",
prompt["output"],
)
This seems to be coming from the data itself; see the README:

`alpaca`: instruction; input (optional)

{"instruction": "...", "input": "...", "output": "..."}

For freeform stuff, we really don't care about the instruction, just the input and output.
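For example, a row without an `input` key hydrates to an empty input, which is what routes `_build_result` to the no-input formats. A quick sketch with a made-up row:

```python
# A made-up row with no "input" key: parse_instruction_fields returns "" for input,
# so _build_result falls back to system_no_input_prompt / turn_no_input_format.
row = {"instruction": "Name a salamander species.", "output": "The axolotl."}
parsed = (
    row["instruction"],
    row["input"] if "input" in row else "",
    row["output"],
)
print(parsed)  # ('Name a salamander species.', '', 'The axolotl.')
```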
But with the alpaca format, it's not clear how to do multi-turn conversations. This is where the sharegpt format seems to come in. From the README:

`sharegpt`: conversations where `from` is `human`/`gpt`. (optional: first row with role `system` to override default system prompt)

{"conversations": [{"from": "...", "value": "..."}]}
In my axolotl config for the Honeycomb model I fine-tuned, I have the following:
datasets:
- path: _synth_data/alpaca_synth_queries_healed.jsonl
type: sharegpt
conversation: alpaca # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
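The `conversation: alpaca` line points at FastChat's conversation templates. To see what a given template actually renders to, you can poke at it directly. A sketch, assuming fastchat is installed and that "alpaca" is one of the registered template names in the linked file:

```python
# Sketch: inspect the FastChat conversation template that the config refers to.
# Assumes fastchat is installed; check fastchat/conversation.py for the exact template names.
from fastchat.conversation import get_conv_template

conv = get_conv_template("alpaca")
conv.append_message(conv.roles[0], "What is an axolotl?")             # user turn
conv.append_message(conv.roles[1], "A salamander native to Mexico.")  # assistant turn
print(conv.roles)         # the (user, assistant) role strings for this template
print(conv.get_prompt())  # the fully rendered prompt string
```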
So multi-turn conversations are probably a very different codepath 😢....

For multi-turn, the magic happens in `prompt_tokenizers.ShareGPTPromptTokenizingStrategy.tokenize_prompt`. You will notice some key differences compared to the `tokenize_prompt` method above:

- There are `user`, `assistant`, and null roles.
- The `user` role is an input role for the purpose of ignoring inputs (that is why the labels get set to `IGNORE_TOKEN_ID` for the `user` role).
- For the `assistant` role, we can see that the "role component" of the prompt is ignored in the labels. The idea is that the assistant content is basically an output.
- The null role only ever appears for the first part of the conversation, which is also treated as an input (and is ignored).
- Pay attention to the `add_eos_token` and `strip_bos_token` arguments:
  - The `user` turn is tokenized with `add_eos_token=False` and `strip_bos_token=True` because it is not the beginning or the end.
  - The `assistant` turn is tokenized with `add_eos_token=True` and `strip_bos_token=True` because it is usually the end. However, there is an exception with `chatml` in the code that I think we can ignore here.
  - The null/first turn is tokenized with `add_eos_token=False`, `strip_bos_token=False` since it is the beginning.
Each turn is appended to the running result via `parse_tokenized_to_result`; here is the call site, followed by the full `tokenize_prompt` for reference:
# prompt_tokenizers.ShareGPTPromptTokenizingStrategy.tokenize_prompt
result, current_len = parse_tokenized_to_result(
result,
current_len,
res,
labels,
pad_token_id=self.tokenizer.pad_token_id,
)
class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
"""
Tokenizing strategy for ShareGPT prompts.
"""
def get_conversation_thread(self, prompt):
return prompt["conversations"]
def tokenize_prompt(self, prompt):
# Initial values. We will append to these as we go through the conversation.
result, current_len = tokenize_prompt_default()
conversation: Conversation = (
self.prompter._conversation.copy() # pylint: disable=protected-access
)
# support for custom roles from the dataset, only useful for vicuna style prompts/roles
role_remap = []
if (
conversation.name == "vicuna_v1.1"
and "roles" in prompt
and len(prompt["roles"]) >= 2
):
role_remap = [
{"from": conversation.roles[0], "to": prompt["roles"][0]},
{"from": conversation.roles[1], "to": prompt["roles"][1]},
]
try:
for _, part in enumerate(
self.prompter.build_prompt(self.get_conversation_thread(prompt))
):
if not isinstance(part, tuple):
LOG.warning(f"expected tuple, got {part}")
continue
user, assistant = conversation.roles
role, content = part
# Uses "in" because role contains extra characters
if user in role:
role = (
role.replace(role_remap[0]["from"], role_remap[0]["to"])
if role_remap
else role
)
turn = role + content
# this is still the user query, we should
if not content.strip():
LOG.warning(f"user turn has empty text: {prompt}")
res = self._tokenize(
turn,
add_eos_token=False,
strip_bos_token=True,
)
if self.train_on_inputs:
labels = copy.deepcopy(res["input_ids"])
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif assistant in role:
role = (
role.replace(role_remap[1]["from"], role_remap[1]["to"])
if role_remap
else role
)
turn = role + content
# this should be the assistant response, should end with an eos token
if not content.strip():
LOG.warning(f"assistant turn has empty text: {prompt}")
add_eos_token = not (
conversation.name == "chatml"
and conversation.sep == self.tokenizer.eos_token
)
res = self._tokenize(
turn,
add_eos_token=add_eos_token,
strip_bos_token=True,
)
role_res = self._tokenize(
role.rstrip(),
add_eos_token=False,
strip_bos_token=True,
)
labels = copy.deepcopy(res["input_ids"])
if not self.train_on_inputs:
# mask out role tokens from the labels
len_role = len(role_res["input_ids"])
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
len_role, len(labels)
)
elif role == "":
turn = content
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
turn, add_eos_token=False, strip_bos_token=False
)
if self.train_on_inputs:
labels = copy.deepcopy(res["input_ids"])
else:
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
LOG.warning(f"unhandled role: {role}")
continue
# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(
result,
current_len,
res,
labels,
pad_token_id=self.tokenizer.pad_token_id,
)
return result
except (KeyError, AssertionError, IndexError) as err:
raise InvalidDataException(str(err)) from err
The `parse_tokenized_to_result` function that concatenates the conversation looks like this:
prompt_tokenizers.parse_tokenized_to_result
def parse_tokenized_to_result(
result: Dict[str, List[int]],
current_len: int,
res: Dict[str, List[int]],
labels: List[int],
pad_token_id: Union[int, None] = None,
) -> Tuple[Dict[str, List[int]], int]:
"""
Parses the tokenized prompt and append the tokenized input_ids, attention_mask and labels to the result
"""
input_ids = res["input_ids"]
input_len = len(input_ids)
result["input_ids"][current_len : current_len + input_len] = input_ids
result["attention_mask"][current_len : current_len + input_len] = [
1 if x != pad_token_id else 0 for x in input_ids
]
result["labels"][current_len : current_len + input_len] = labels
current_len += input_len
return result, current_len
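To see the effect of those flags and the masking end to end, here is a toy walk-through with a fake whitespace "tokenizer". This only illustrates the scheme; none of it is axolotl's actual code, and real tokenization, BOS/EOS ids, and role strings depend on the model and conversation template:

```python
# Toy illustration of the per-turn flags and label masking described above.
IGNORE_TOKEN_ID = -100
BOS, EOS = 1, 2
vocab = {}

def fake_tokenize(text, add_eos_token, strip_bos_token):
    ids = [vocab.setdefault(tok, len(vocab) + 3) for tok in text.split()]
    if not strip_bos_token:
        ids = [BOS] + ids
    if add_eos_token:
        ids = ids + [EOS]
    return ids

result = {"input_ids": [], "labels": []}
turns = [
    ("", "SYSTEM: Be helpful. USER: What is an axolotl?"),  # first part: keeps BOS, no EOS
    ("ASSISTANT:", " A salamander native to Mexico."),      # assistant turn: gets EOS
    ("USER:", " Can it regrow limbs?"),                     # later user turn: no BOS, no EOS
    ("ASSISTANT:", " Yes, it can regenerate limbs."),
]

for role, content in turns:
    if role.startswith("ASSISTANT"):
        ids = fake_tokenize(role + content, add_eos_token=True, strip_bos_token=True)
        role_len = len(fake_tokenize(role, add_eos_token=False, strip_bos_token=True))
        # train on the response (and EOS) only; the role tokens are masked out
        labels = [IGNORE_TOKEN_ID] * role_len + ids[role_len:]
    else:
        # user turns and the first (null-role) part are inputs: fully masked
        ids = fake_tokenize(role + content, add_eos_token=False, strip_bos_token=(role != ""))
        labels = [IGNORE_TOKEN_ID] * len(ids)
    # concatenate, as parse_tokenized_to_result does
    result["input_ids"] += ids
    result["labels"] += labels

print(result["input_ids"])
print(result["labels"])
```

With `train_on_inputs: true`, the user and first turns would keep their real token ids as labels instead of being masked.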
But what is the entrypoint to loading datasets in axolotl? This starts in `cli.__init__.load_datasets`:
def load_datasets(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
) -> TrainDatasetMeta:
tokenizer = load_tokenizer(cfg)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
cfg, tokenizer
)
...
The magic happens in these two: `utils.data.prepare_dataset` and `utils.models.load_tokenizer`.

`utils.data.prepare_dataset` calls `load_prepare_datasets`, which calls `load_tokenized_prepared_datasets`, which in turn calls `utils.data.get_dataset_wrapper`:
def prepare_dataset(cfg, tokenizer):
...
if cfg.test_datasets:
train_dataset, _, prompters = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH, split="train"
)
....
## this calls `load_prepare_datasets`
def load_prepare_datasets(
tokenizer: PreTrainedTokenizerBase,
cfg,
default_dataset_prepared_path,
split="train",
) -> Tuple[Dataset, Dataset, List[Prompter]]:
dataset, prompters = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path, split=split
)
## the above calls huggingface `datasets.load.load_dataset` and then later `utils.data.get_dataset_wrapper`
def load_tokenized_prepared_datasets(
tokenizer,
cfg,
default_dataset_prepared_path,
split="train",
):
...
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
config_dataset=config_dataset,
tokenizer=tokenizer,
cfg=cfg,
dataset=ds,
d_base_type=d_base_type, # this the type from the axolotl config like `sharegpt` for Honeycomb
d_prompt_style=d_prompt_style,
)
datasets.append(dataset_wrapper)
prompters.append(dataset_prompter)
Next, `utils.data.get_dataset_wrapper` has this code, where it calls `TokenizedPromptDataset`:
#utils.data.get_dataset_wrapper
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(
ds_strategy,
dataset,
**ds_kwargs,
)
The `ds_strategy` is the thing that applies the prompt construction. The call to `load(config_dataset.type, ...)` resolves to `axolotl.prompt_strategies.sharegpt.SimpleShareGPTPromptTokenizingStrategy`, which inherits from `ShareGPTPromptTokenizingStrategy`, which we covered in an earlier comment. We can see that this strategy is then read into the `TokenizedPromptDataset`:
datasets.TokenizedPromptDataset
class TokenizedPromptDataset(Dataset):
"""
Dataset that returns tokenized prompts from a stream of text files.
Args:
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
dataset (dataset.Dataset): Dataset with text files.
process_count (int): Number of processes to use for tokenizing.
keep_in_memory (bool): Whether to keep the tokenized dataset in memory.
"""
def __init__( # pylint: disable=super-init-not-called
self,
prompt_tokenizer: PromptTokenizingStrategy,
dataset: Dataset,
process_count: Optional[int] = None,
keep_in_memory: Optional[bool] = False,
**kwargs,
):
self.prompt_tokenizer = prompt_tokenizer
self.process_count = process_count
self.keep_in_memory = keep_in_memory
super().__init__(
self.process(dataset).data,
**kwargs,
)
def process(self, dataset):
features = dataset.features.keys()
num_proc = min(64, self.process_count if self.process_count else os.cpu_count())
map_kwargs = {}
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
map_kwargs["batch_size"] = 100
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc,
remove_columns=features,
keep_in_memory=self.keep_in_memory,
desc="Tokenizing Prompts",
**map_kwargs,
)
The thing we want to pay attention to here is `self.prompt_tokenizer.tokenize_prompt`, which we covered above. So now we know how `tokenize_prompt` gets applied when `process` is eventually called.
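Stripped of the strategy plumbing, `process` is the standard `datasets.Dataset.map` pattern. A standalone sketch with a dummy `tokenize_prompt` standing in for the real strategy:

```python
# Sketch of the Dataset.map pattern used in process() above, with a dummy
# tokenize_prompt in place of a real PromptTokenizingStrategy.
from datasets import Dataset

def tokenize_prompt(row):
    # a real strategy would return proper input_ids / attention_mask / labels
    n = len(row["conversations"])
    return {"input_ids": list(range(n)), "attention_mask": [1] * n, "labels": list(range(n))}

ds = Dataset.from_list(
    [{"conversations": [{"from": "human", "value": "hi"}, {"from": "gpt", "value": "hello"}]}]
)
tokenized = ds.map(
    tokenize_prompt,
    remove_columns=ds.column_names,  # drop the raw columns, keep only the tokenized fields
    desc="Tokenizing Prompts",
)
print(tokenized[0])
```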
The other important thing that gets called in `cli.__init__.load_datasets` is loading the tokenizer, which is done in `utils.models.load_tokenizer`. There, special tokens from the cfg are loaded, along with exceptions for padding for different model classes:
# Mistral's official FA implementation requires left padding
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
tokenizer.padding_side = "left"
# Qwen base only has single token, so we need to set the special tokens
if cfg.is_qwen_derived_model:
token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
for attr_name in token_ids:
...
additional_special_tokens = None
if cfg.special_tokens:
special_tokens = cfg.special_tokens.to_dict()
additional_special_tokens = special_tokens.pop(
"additional_special_tokens", None
)
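For reference, outside of axolotl the same kind of tokenizer setup looks roughly like this. A sketch with plain transformers APIs; the model name and special tokens are just examples, not what axolotl's `load_tokenizer` hard-codes:

```python
# Sketch (not axolotl's load_tokenizer) of setting special tokens and padding side
# with plain transformers APIs. Model name and tokens here are examples only.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

# roughly what cfg.special_tokens drives:
tokenizer.add_special_tokens({"pad_token": "</s>"})
tokenizer.add_special_tokens({"additional_special_tokens": ["<|im_start|>", "<|im_end|>"]})

# e.g. the Mistral + flash attention (no sample packing) case from the snippet above
tokenizer.padding_side = "left"
```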
I'm using what I learned from this journey to review this PR.
Axolotl Notes For Prompt Construction
- The `parse_instruction_fields` function generates a tuple `(instruction, input, response)`. The key to making a new input format is to make sure you can parse your input into these parts and deal with them.
- Note that the `instruction` and `input` are part of the `user_prompt`, which is effectively the "inputs" for the purposes of `train_on_inputs: false`. That works by setting the appropriate labels to a label id that is ignored, i.e. `tokenized_prompt["labels"] = [IGNORE_INDEX] * user_prompt_len` (see `prompt_tokenizers.InstructionPromptTokenizingStrategy`).
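In other words, for the instruct-style strategies the masking is just a prefix of ignore ids. A tiny sketch with made-up token ids:

```python
# Sketch of the train_on_inputs: false masking for an instruct-style prompt (made-up ids).
IGNORE_INDEX = -100

user_prompt_ids = [101, 102, 103, 104]  # tokenized system prompt + instruction + input
response_ids = [201, 202, 2]            # tokenized response, ending in an EOS id

input_ids = user_prompt_ids + response_ids
labels = [IGNORE_INDEX] * len(user_prompt_ids) + response_ids  # loss only on the response
```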
In Axolotl, many PromptStrategies override the `parse_instruction_fields` method like this (`prompt_tokenizers.AlpacaPromptTokenizingStrategy`):
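This is the same override that was quoted earlier in this thread:

```python
# prompt_tokenizers.AlpacaPromptTokenizingStrategy
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
    return (
        prompt["instruction"],
        prompt["input"] if "input" in prompt else "",
        prompt["output"],
    )
```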
The above is called from `prompt_strategies.alpaca_chat`.