MARIO-Math-Reasoning / Super_MARIO


How to train an instruct model? #24

Open jt4n opened 3 weeks ago

jt4n commented 3 weeks ago

Hi, I’d like to ask for some advice on training the instruct model.

In your code, you used the vanilla template to train a base model (deepseek-math-7b-base), so there was no need to apply a role-playing chat template in data preprocessing.

If we want to train an instruct model, e.g. llama3-8b-instruct, we need to apply the llama3 template. So we need to encode the user_message and assistant_message turn by turn with encode_oneturn, and the "Observation" part should be placed in the user_message.
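For example, here is roughly the message layout I mean (a minimal sketch; split_observation_example and build_chat_messages are hypothetical stand-ins, and I assume the repo's split_observation splits a step at the "Observation:" marker):

from typing import Dict, List, Tuple

def split_observation_example(step_text: str) -> Tuple[str, str]:
    # Assumed behavior: everything before "Observation:" stays in the assistant
    # turn, while the observation itself becomes the content of the next user turn.
    marker = "Observation:"
    if marker in step_text:
        assistant_part, observation_part = step_text.split(marker, 1)
        return assistant_part, marker + " " + observation_part.strip()
    return step_text, ""

def build_chat_messages(question: str, steps: List[str]) -> List[Dict[str, str]]:
    # Interleave assistant steps with observation-as-user turns so a chat
    # template like llama3 can encode the trajectory turn by turn.
    messages: List[Dict[str, str]] = [{"role": "user", "content": question}]
    pending_user = ""
    for step in steps:
        if pending_user:
            messages.append({"role": "user", "content": pending_user})
        assistant_part, pending_user = split_observation_example(step)
        messages.append({"role": "assistant", "content": assistant_part.strip() + "\n"})
    return messages

With this layout every code-execution "Observation" is encoded inside a user message, and each assistant message ends right before it, which is what the modified preprocessing below tries to reproduce.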

I tried to modify preprocess_value_dataset to support this, but the trained model's responses repeat continuously.

import json
from typing import Any, Dict, List

# Note: IGNORE_INDEX, split_observation, Template, DataArguments and
# PreTrainedTokenizer are assumed to be imported from the same modules used by
# the original preprocess_value_dataset.

def preprocess_value_dataset_instruct(
    examples: Dict[str, List[Any]],
    tokenizer: "PreTrainedTokenizer",
    template: "Template",
    data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
    # build inputs with format `<bos> X ` and labels with format `X <eos> `
    model_inputs = {"input_ids": [], "attention_mask": [], "Q": [], "labels": []}

    for i in range(len(examples["prompt"])):

        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:  # prompt only one, response only one
            continue

        # encode the 1st turn, get the 1st assistant message
        first_user_message = examples["prompt"][i]
        multistep_response = json.loads(examples["response"][i][0]['content'])

        first_response = multistep_response[0]
        first_response_content = first_response['step'].strip() + "\n"
        first_assistant_content, next_user_content = split_observation(first_response_content)
        first_assistant_message = [{"role": 'assistant', 'content': first_assistant_content.strip() + "\n"}]
        message = first_user_message + first_assistant_message
        # print(f"first mesasge, <{message}>")

        first_prompt_ids, first_response_ids = template.encode_oneturn(tokenizer, message, examples["system"][i], examples["tools"][i])

        first_response_ids = first_response_ids[:-1]    # discard eos_token_id

        input_ids = first_prompt_ids + first_response_ids

        first_Q = float(first_response['Q'])
        if data_args.train_on_prompt:
            print("train_on_prompt")
            source_mask = input_ids[:]  # copy; aliasing input_ids would append eos twice below
            Q = [IGNORE_INDEX] * (len(input_ids) - 1)
            Q += [first_Q]
        else:
            source_mask = [IGNORE_INDEX] * len(first_prompt_ids) + first_response_ids
            Q = [IGNORE_INDEX] * (len(input_ids) - 1)
            Q += [first_Q]

        # need to add eos_token_id at the end of the first turn
        source_mask += [tokenizer.eos_token_id]
        input_ids += [tokenizer.eos_token_id]
        Q += [IGNORE_INDEX]

        labels = source_mask
        response_state = multistep_response[-1]['Q']  # last Q

        # if the conversation is more than 1 round
        if len(multistep_response) > 1 and next_user_content:
            for sub_response in multistep_response[1:]:
                if len(sub_response['step']) == 0:
                    print(sub_response['step'])  # debug: warn about empty steps

                sub_user_content = next_user_content
                sub_response_content = sub_response['step'].strip() + "\n"
                sub_assistant_content, next_user_content = split_observation(sub_response_content)
                sub_message = [{"role": 'user', 'content': sub_user_content}] + [{"role": 'assistant', 'content': sub_assistant_content.strip() + "\n"}]
                sub_Q = float(sub_response['Q'])
                sub_prompt_ids, sub_response_ids = template.encode_oneturn(tokenizer, sub_message, examples["system"][i], examples["tools"][i])

                # remove the <|begin_of_text|> in sub turn message
                sub_prompt_ids = sub_prompt_ids[1:]

                sub_response_ids = sub_response_ids[:-1]  # discard the eos_token_id
                # to make sure the sentence ends with \n instead of <eos>
                # our value model predicts the v based on '\n'

                # print(f"sub_message, <{sub_message}>")

                input_ids += (sub_prompt_ids + sub_response_ids)
                Q += [IGNORE_INDEX] * (len(sub_prompt_ids) + len(sub_response_ids) - 1) + [sub_Q]
                labels += (sub_prompt_ids + sub_response_ids)

                # need to add eos_token_id at the end of sub turn
                input_ids += [tokenizer.eos_token_id]
                Q += [IGNORE_INDEX]
                labels += [tokenizer.eos_token_id]

                if len(input_ids) > data_args.cutoff_len:
                    break

        if template.efficient_eos:  # only the vanilla template takes this branch
            input_ids += [tokenizer.eos_token_id]
            Q += [IGNORE_INDEX]
            labels += [tokenizer.eos_token_id]

        if len(input_ids) > data_args.cutoff_len:
            input_ids = input_ids[:data_args.cutoff_len]
            Q = Q[:data_args.cutoff_len]
            labels = labels[:data_args.cutoff_len]

        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["Q"].append(Q)

        if response_state == -1:
            model_inputs["labels"].append([IGNORE_INDEX] * len(labels))
        elif response_state == 1:
            model_inputs["labels"].append(labels)
        else:
            assert False, response_state

    return model_inputs
Chen-GX commented 2 weeks ago

Maybe you should check the training sequences produced by your preprocess_value_dataset_instruct.
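For example, something like this quick check (a rough sketch; it assumes a processed example still exposes input_ids, labels and Q exactly as built above, and that IGNORE_INDEX is -100):

IGNORE_INDEX = -100  # assumed to match the constant used during preprocessing

def inspect_example(example, tokenizer):
    input_ids = example["input_ids"]
    labels = example["labels"]
    q_values = example["Q"]

    # 1) Full decoded sequence: the chat template markers, the eos tokens
    #    inserted between sub-turns and the Observation placement should all
    #    be visible here.
    print(tokenizer.decode(input_ids))

    # 2) Only the supervised tokens (labels != IGNORE_INDEX); for a negative
    #    trajectory (response_state == -1) this should be empty.
    supervised = [t for t, l in zip(input_ids, labels) if l != IGNORE_INDEX]
    print("supervised:", tokenizer.decode(supervised))

    # 3) Tokens that carry a Q target; per the comments in the function above,
    #    each should sit on the trailing "\n" of a step, before the appended eos.
    for tok, q in zip(input_ids, q_values):
        if q != IGNORE_INDEX:
            print(repr(tokenizer.decode([tok])), "->", q)

If the decoded text does not look like a valid llama3 conversation, or the special tokens between turns are duplicated or missing, that would be a likely cause of the repeating generations.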

jt4n commented 1 week ago

We tried but could not fix the problem, so we used the same data and training script to train Llama3-8B-Base instead. However, we found that the SFT model outputs all Q values as -1.0.