lm-sys / FastChat

An open platform for training, serving, and evaluating large language models. Release repo for Vicuna and Chatbot Arena.

Finetune FastChat with Zephyr format #2918

Open nghidinhit opened 9 months ago

nghidinhit commented 9 months ago

How can I perform fine-tuning on FastChat using the Zephyr format? I've noticed that within the preprocess function, there is hardcoded logic intended for fine-tuning with the Vicuna template.

[screenshot of the hardcoded template logic in `preprocess`]
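
For reference, this is roughly what a rendered Zephyr-style prompt looks like; a minimal sketch, assuming FastChat registers a `zephyr` conversation template and that `get_conversation_template` can be imported from `fastchat.model.model_adapter` (exact separators may differ by version):

    from fastchat.model.model_adapter import get_conversation_template

    conv = get_conversation_template("zephyr")
    conv.system_message = "You are a helpful assistant."
    conv.append_message(conv.roles[0], "Hello!")
    conv.append_message(conv.roles[1], "Hi! How can I help?")
    print(conv.get_prompt())
    # Expected shape (roughly):
    # <|system|>
    # You are a helpful assistant.</s>
    # <|user|>
    # Hello!</s>
    # <|assistant|>
    # Hi! How can I help?</s>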

jwong8314 commented 9 months ago

You can just change `"vicuna"` to `"zephyr"`. The bigger issue is the rest of `preprocess`, where you need to rewrite the hardcoded mask generation.

I wrote something like the following to support dolphin's format:

    # Inside preprocess(): `conversation`, `target`, `sep`, `conv`, and
    # `tokenizer` come from the surrounding FastChat training code.
    if dolphin_format:
        turns = conversation.split(sep)

        cur_len = 1  # skip the leading <s> token
        target[:cur_len] = IGNORE_TOKEN_ID
        for i, turn in enumerate(turns):
            if turn == "":
                break
            turn_len = len(tokenizer(turn).input_ids)

            if turn.split("\n")[0] != conv.roles[1]:
                # Not an assistant turn: mask the whole turn plus its separator tokens.
                target[cur_len : cur_len + turn_len + 2] = IGNORE_TOKEN_ID
            else:
                # Assistant turn: mask only the role header and the trailing
                # separator tokens, so loss is computed on the reply itself.
                size_of_role = len(tokenizer(conv.roles[1]).input_ids)
                target[cur_len : cur_len + size_of_role] = IGNORE_TOKEN_ID
                target[cur_len + turn_len : cur_len + turn_len + 2] = IGNORE_TOKEN_ID

            cur_len += turn_len + 2
        target[cur_len:] = IGNORE_TOKEN_ID

If I have some time later, I'll push a PR that should make finetuning agnostic to formatting. In the meantime, feel free to just switch to dolphin's format.
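
For illustration only, a format-agnostic masking could tokenize each rendered message on its own and mask every non-assistant span by token count; `build_labels` and the role rendering below are hypothetical, not FastChat's implementation:

    import torch

    IGNORE_TOKEN_ID = -100  # label value ignored by the Hugging Face LM loss

    def build_labels(messages, tokenizer, assistant_role="<|assistant|>"):
        """Hypothetical helper: concatenate per-message token ids and keep the
        loss only on assistant replies, regardless of the chat format used."""
        input_ids, labels = [], []
        for role, text in messages:
            # Render each message the way the chat template would (assumed to be
            # "role\ntext</s>\n" here) and tokenize it on its own, so the token
            # span of every message is known exactly.
            piece = tokenizer(f"{role}\n{text}</s>\n", add_special_tokens=False).input_ids
            input_ids.extend(piece)
            if role == assistant_role:
                labels.extend(piece)  # train on the assistant reply
            else:
                labels.extend([IGNORE_TOKEN_ID] * len(piece))  # mask everything else
        return torch.tensor(input_ids), torch.tensor(labels)

One caveat: with BPE tokenizers, tokenizing messages separately can produce slightly different ids than tokenizing the full prompt at once, so a decode-and-inspect sanity check on the resulting labels is still worthwhile.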

nghidinhit commented 9 months ago

This is my custom code:

    # Custom preprocess() for the Zephyr template; it relies on the surrounding
    # train.py imports (torch, transformers, Dict, IGNORE_TOKEN_ID, rank0_print,
    # get_conversation_template).
    def preprocess(
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
    ) -> Dict:
        conv = get_conversation_template("zephyr")
        roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

        # Apply prompt templates
        conversations = []
        for i, source in enumerate(sources):
            conv.system_message = source[0]["value"].strip()
            source = source[1:]

            conv.messages = []
            for j, sentence in enumerate(source):
                role = roles[sentence["from"]]
                assert role == conv.roles[j % 2], f"{i}"
                conv.append_message(role, sentence["value"])
            conversations.append(conv.get_prompt())

        # Tokenize conversations
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

        targets = input_ids.clone()

        # Mask targets. Only compute loss on the assistant outputs.
        sep = conv.roles[1] + "\n"
        for conversation, target in zip(conversations, targets):
            total_len = int(target.ne(tokenizer.pad_token_id).sum())

            turns = conversation.split("<|user|>\n")
            cur_len = 1  # for the <s> special token
            target[:cur_len] = IGNORE_TOKEN_ID

            for i, turn in enumerate(turns):
                if turn == "":
                    break

                if i == 0:  # system message
                    parts = [turn, ""]
                else:
                    turn = f"<|user|>\n{turn}"
                    parts = turn.split(sep)  # user text and assistant text
                    if len(parts) != 2:
                        break
                    parts[0] += sep
                turn_len = len(tokenizer(turn).input_ids) - 1  # exclude the <s> token
                instruction_len = len(tokenizer(parts[0]).input_ids) - 1
                target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID

                cur_len += turn_len

            target[cur_len:] = IGNORE_TOKEN_ID

            if False:  # Inspect and check the correctness of masking
                z = target.clone()
                z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
                rank0_print(tokenizer.decode(z))
                exit()

            if cur_len < tokenizer.model_max_length:
                if cur_len != total_len:
                    target[:] = IGNORE_TOKEN_ID
                    rank0_print(
                        f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                        f" #turn = {len(turns) - 1}. (ignored)"
                    )

        return dict(
            input_ids=input_ids,
            labels=targets,
            attention_mask=input_ids.ne(tokenizer.pad_token_id),
        )
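
A quick way to smoke-test the masking above (hypothetical sample data; assumes `tokenizer`, `IGNORE_TOKEN_ID`, and this `preprocess` are in scope) is to decode the labels with the ignored positions replaced by the unknown token, mirroring the `if False:` inspection branch:

    sample = [[
        {"from": "system", "value": "You are a helpful assistant."},
        {"from": "human", "value": "What is the capital of France?"},
        {"from": "gpt", "value": "The capital of France is Paris."},
    ]]
    batch = preprocess(sample, tokenizer)

    labels = batch["labels"][0].clone()
    labels[labels == IGNORE_TOKEN_ID] = tokenizer.unk_token_id
    print(tokenizer.decode(labels))
    # Only the assistant reply should remain readable; everything that was
    # masked decodes as the unk token.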