**[Open]** — nghidinhit opened this issue 9 months ago
You can just change the template name "vicuna" to "zephyr" in the call to `get_conversation_template`. The bigger issue is the rest of `preprocess`, where you need to rewrite the hardcoded loss-mask generation.
I wrote something like the following to support dolphin's format:
if dolphin_format:
    # Split the rendered conversation into turns on the separator string.
    turns = conversation.split(sep)
    cur_len = 1  # start at 1 to skip the BOS (<s>) token
    target[:cur_len] = IGNORE_TOKEN_ID  # mask the BOS token
    for i, turn in enumerate(turns):
        if turn == "":
            break
        turn_len = len(tokenizer(turn).input_ids)
        # The first line of each turn is the role header; compare it to the
        # assistant role (conv.roles[1]) to decide what to mask.
        if turn.split("\n")[0] != conv.roles[1]:
            # Non-assistant turn: mask all of it.
            # NOTE(review): the "+ 2" presumably covers the separator tokens
            # between turns — confirm against the actual tokenizer output.
            target[cur_len: cur_len + turn_len + 2] = IGNORE_TOKEN_ID
        elif turn.split("\n")[0] == conv.roles[1]:
            # Assistant turn: mask only the role header and the trailing
            # separator, leaving the assistant's reply as the training target.
            size_of_role = len(tokenizer(conv.roles[1]).input_ids)
            target[cur_len: cur_len + size_of_role] = IGNORE_TOKEN_ID
            target[cur_len + turn_len: cur_len + turn_len + 2] = IGNORE_TOKEN_ID
        # Advance past this turn plus the assumed 2 separator tokens.
        cur_len += turn_len + 2
    # Everything after the last processed turn (e.g. padding) is ignored.
    target[cur_len:] = IGNORE_TOKEN_ID
If I have some time later, I'll push a PR that should make finetuning agnostic to formatting. In the meantime, feel free to just switch to dolphin's format.
This is my custom code:

```python
def preprocess(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Render conversations with the "zephyr" template, tokenize them, and
    build ``labels`` where everything except the assistant replies is set to
    IGNORE_TOKEN_ID, so the loss is computed only on assistant outputs.

    Returns a dict with ``input_ids``, ``labels``, and ``attention_mask``.
    """
    conv = get_conversation_template("zephyr")
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        # The first message of each source is treated as the system prompt.
        conv.system_message = source[0]["value"].strip()
        source = source[1:]
        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            # Turns must strictly alternate human / assistant.
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations (padded/truncated to model_max_length).
    input_ids = tokenizer(
        conversations,
        return_tensors="pt",
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.roles[1] + "\n"  # assistant header, e.g. "<|assistant|>\n"
    for conversation, target in zip(conversations, targets):
        # Count of non-padding tokens, used for the sanity check below.
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        turns = conversation.split("<|user|>\n")
        cur_len = 1  # for <s> special character
        target[:cur_len] = IGNORE_TOKEN_ID
        for i, turn in enumerate(turns):
            if turn == "":
                break
            if i == 0:  # system message
                # No assistant part; treat the whole turn as "instruction".
                parts = [turn, ""]
            else:
                # Re-attach the "<|user|>\n" marker removed by split().
                turn = f"<|user|>\n{turn}"
                parts = turn.split(sep)  # user text and assistant text
            if len(parts) != 2:
                break
            # Include the assistant header in the masked instruction span.
            parts[0] += sep

            turn_len = len(tokenizer(turn).input_ids) - 1  # remove the <s> token
            instruction_len = len(tokenizer(parts[0]).input_ids) - 1

            # Ignore the user/system instruction; keep the assistant reply.
            target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len
        target[cur_len:] = IGNORE_TOKEN_ID

        if False:  # Inspect and check the correctness of masking
            z = target.clone()
            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
            rank0_print(tokenizer.decode(z))
            exit()

        if cur_len < tokenizer.model_max_length:
            # If the reconstructed length disagrees with the tokenized length,
            # the mask offsets are unreliable — drop the whole sample from the
            # loss rather than train on misaligned labels.
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                rank0_print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" #turn = {len(turns) - 1}. (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )
```
How can I perform fine-tuning on FastChat using the Zephyr format? I've noticed that within the preprocess function, there is hardcoded logic intended for fine-tuning with the Vicuna template.