THUDM / ReST-MCTS

ReST-MCTS*: LLM Self-Training via Process Reward Guided Tree Search (NeurIPS 2024)


📃 [ReST-MCTS*] [GitHub] [Website]

We develop a reinforced self-training approach, called ReST-MCTS*, that integrates process reward guidance with tree search (MCTS*) to collect higher-quality reasoning traces, as well as per-step values for training the policy and reward models. ReST-MCTS* circumvents the per-step manual annotation typically used to train process rewards through tree-search-based reinforcement learning: given the oracle final answers, ReST-MCTS* infers the correct process rewards by estimating the probability that each step helps lead to the correct answer. These inferred rewards serve two purposes: they act as value targets for further refining the process reward model, and they facilitate the selection of high-quality traces for policy model self-training.
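The core idea, scoring a step by how often continuations from it reach the known final answer, can be illustrated with a plain Monte Carlo estimate. This is a simplified sketch, not the paper's exact value-update rule; the function names and the assumption that each step comes with a list of sampled rollout answers are hypothetical.

```python
from typing import List, Sequence


def estimate_step_value(rollout_answers: Sequence[str], oracle_answer: str) -> float:
    """Fraction of rollouts continuing from a step that end in the oracle answer.

    `rollout_answers` holds the final answers of completions sampled after the
    step (a hypothetical input format used only for this sketch).
    """
    if not rollout_answers:
        return 0.0
    target = oracle_answer.strip()
    hits = sum(1 for answer in rollout_answers if answer.strip() == target)
    return hits / len(rollout_answers)


def label_trace(per_step_rollouts: Sequence[Sequence[str]], oracle_answer: str) -> List[float]:
    """Estimate a process reward for every step of one reasoning trace."""
    return [estimate_step_value(r, oracle_answer) for r in per_step_rollouts]


# Usage: a two-step trace, with four rollouts sampled after each step.
values = label_trace([["129", "129", "127", "129"], ["129"] * 4], "129")
print(values)  # [0.75, 1.0]
```

Per-step estimates of this kind can then serve both as regression targets for the value model and as a filter for which traces to keep.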


Key Differences

Getting Started

Prepare Env

Because Mistral (or Llama) and SciGLM depend on different versions of transformers, create a separate environment for each with miniconda and install the corresponding requirements:

For Mistral (or Llama):

pip install -r requirements_mistral.txt

For SciGLM:

pip install -r requirements_sciglm.txt

Note that some models on Hugging Face, such as the GLM series, may require specific versions of transformers.

Model Implementation

MCTS* Search

To run MCTS* search, you need a policy model and a process reward model (value model). Set them by providing the model paths in models/model.py, substituting INFERENCE_MODEL_DIR, VALUE_BASE_MODEL_DIR, and VALUE_MODEL_STATE_DICT.
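For example, the three constants might be filled in as below. The paths are hypothetical placeholders, and the small check is our own addition (not part of models/model.py) that reports which configured paths do not exist before a long search run starts.

```python
from pathlib import Path
from typing import List

# Hypothetical placeholder paths: replace them with your own local paths.
INFERENCE_MODEL_DIR = "/path/to/Meta-Llama-3-8B-Instruct"  # policy model
VALUE_BASE_MODEL_DIR = "/path/to/Mistral-7B"  # base model of the value model
VALUE_MODEL_STATE_DICT = "/path/to/value_model_state_dict"  # trained value-model weights


def missing_paths(paths: List[str]) -> List[str]:
    """Return the configured paths that do not exist on this machine."""
    return [p for p in paths if not Path(p).exists()]


print(missing_paths([INFERENCE_MODEL_DIR, VALUE_BASE_MODEL_DIR, VALUE_MODEL_STATE_DICT]))
```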

Policy Model

INFERENCE_MODEL_DIR is the local path to the policy model, which can be Llama3-8B-Instruct, Mistral-7B: MetaMATH, or SciGLM-6B.

Process Reward Model

VALUE_BASE_MODEL_DIR is the local path to the value model. Because of the differing transformers dependencies, Mistral-7B is used as the backbone of the value model when the policy model is Llama3-8B-Instruct or Mistral-7B: MetaMATH. When the policy model is SciGLM, ChatGLM3-6B is used as the backbone of the value model.

To gather value training data for science, we take the questions of a lean science dataset $D_{sci}$ from [SciInstruct] to construct $D_{V_0}$. This dataset contains 11,554 questions, each paired with a correct step-by-step solution. (See Section 4.1 of the paper for more details.)

You can download [$D_{V_0}$], put it in PRM/data, and train Mistral-7B on it as the initial process reward model to obtain VALUE_MODEL_STATE_DICT. Training scripts are provided in PRM/train_VM_chatglm.py and PRM/train_VM_mistral.py.

Model Setting

Currently, models/model.py only implements Llama, GLM, and Mistral as policy models, and GLM and Mistral as value models. To use other models, refer to our implementation and modify the relevant code accordingly. Once you have implemented the policy and value models, set LOCAL_INFERENCE_IDX and LOCAL_VALUE_IDX in models/model.py to the corresponding model indices.
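One way to organize a new model is to give every policy and value model the same small interface and select them by index. The sketch below is hypothetical: the Protocol classes, the method names propose_step and score, the toy models, and the registries are our own illustration of index-based selection, not the actual structure of models/model.py. Only the two index names are taken from the text.

```python
from typing import Callable, Dict, Protocol, Tuple


class PolicyModel(Protocol):
    def propose_step(self, question: str, partial_solution: str) -> str: ...


class ValueModel(Protocol):
    def score(self, question: str, partial_solution: str) -> float: ...


class ToyPolicy:
    """Stand-in policy: appends a fixed step (replace with real generation)."""

    def propose_step(self, question: str, partial_solution: str) -> str:
        return partial_solution + "Step.\n"


class ToyValue:
    """Stand-in value model: returns a constant score in [0, 1]."""

    def __init__(self, value: float = 0.5) -> None:
        self.value = value

    def score(self, question: str, partial_solution: str) -> float:
        return self.value


# Hypothetical registries mapping a model index to a constructor.
POLICY_MODELS: Dict[int, Callable[[], PolicyModel]] = {0: ToyPolicy}
VALUE_MODELS: Dict[int, Callable[[], ValueModel]] = {0: ToyValue}

LOCAL_INFERENCE_IDX = 0
LOCAL_VALUE_IDX = 0


def load_models(inference_idx: int, value_idx: int) -> Tuple[PolicyModel, ValueModel]:
    """Instantiate the policy and value models selected by index."""
    return POLICY_MODELS[inference_idx](), VALUE_MODELS[value_idx]()


policy, value = load_models(LOCAL_INFERENCE_IDX, LOCAL_VALUE_IDX)
```

Keeping the interface this small means the search code never needs to know which backbone sits behind an index.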

Data Preparation

Before running search for evaluation or generation, make sure your target question dataset is in the correct format. The data file should be a JSON file whose items have the following format:

{
  "content": "Calculate the sum of the first 10 prime numbers.",
  "answer": "129"
}

The content field is required and holds the question. The answer field is optional and is used for evaluation.
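Before a long run, a quick format check can catch malformed items. This sketch assumes the file holds a top-level JSON list of such items and that answers are strings, as in the example; the function name load_questions is hypothetical and is not part of the repository.

```python
import json
from typing import Any, Dict, List


def load_questions(path: str) -> List[Dict[str, Any]]:
    """Load a question file and check the required and optional fields."""
    with open(path, encoding="utf-8") as f:
        items = json.load(f)
    if not isinstance(items, list):
        raise ValueError("expected a top-level JSON list of items")
    for i, item in enumerate(items):
        # "content" is required: it is the question text.
        if not isinstance(item, dict) or not isinstance(item.get("content"), str):
            raise ValueError(f"item {i}: missing string field 'content'")
        # "answer" is optional, but if present it should be a string.
        if "answer" in item and not isinstance(item["answer"], str):
            raise ValueError(f"item {i}: field 'answer' must be a string")
    return items
```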

Run MCTS* Search

MCTS* search is implemented in the MCTS directory, and MCTS/task.py provides a search interface. To run MCTS* search on a single question, you can use the following script:

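from MCTS.task import MCTS_Task

question = "Calculate the sum of the first 10 prime numbers."
# 'llama' is the propose method (policy model); 'local' is the value method (process reward model)
task = MCTS_Task(question, 'llama', 'local', lang='en')
output = task.run()  # runs the tree search and returns a dict of results
print(output['solution'])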
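For intuition about what a process-reward-guided tree search does, here is a generic UCT-style sketch. It is not the repository's MCTS* implementation: propose stands in for the policy model, value for the process reward model, and the expansion rule (score every new child with the value model, then back up the best child score) is our own simplifying assumption. The iteration_limit and branch arguments only mirror the evaluate.py flags of the same names.

```python
import math
from dataclasses import dataclass, field
from typing import Callable, List, Optional, Tuple


@dataclass
class Node:
    steps: List[str]
    parent: Optional["Node"] = None
    children: List["Node"] = field(default_factory=list)
    visits: int = 0
    value_sum: float = 0.0

    def mean(self) -> float:
        return self.value_sum / self.visits if self.visits else 0.0


def uct(child: Node, parent_visits: int, c: float) -> float:
    # Mean value (exploitation) plus a bonus for rarely visited children (exploration).
    return child.mean() + c * math.sqrt(math.log(parent_visits) / child.visits)


def backpropagate(node: Optional[Node], value: float) -> None:
    # Add one visit and the value to every node on the path back to the root.
    while node is not None:
        node.visits += 1
        node.value_sum += value
        node = node.parent


def search(
    propose: Callable[[List[str], int], List[str]],
    value: Callable[[List[str]], float],
    is_terminal: Callable[[List[str]], bool],
    iteration_limit: int = 50,
    branch: int = 3,
    c: float = 1.4,
) -> Tuple[Optional[List[str]], float]:
    root = Node(steps=[])
    best_steps: Optional[List[str]] = None
    best_value = -math.inf

    def record(steps: List[str], v: float) -> None:
        # Remember the highest-scoring complete solution seen so far.
        nonlocal best_steps, best_value
        if is_terminal(steps) and v > best_value:
            best_steps, best_value = steps, v

    for _ in range(iteration_limit):
        # Selection: follow the highest-UCT child down to a leaf.
        node = root
        while node.children:
            parent = node
            node = max(parent.children, key=lambda ch: uct(ch, parent.visits, c))
        if is_terminal(node.steps):
            # A complete solution: score it and back up the score.
            v = value(node.steps)
            record(node.steps, v)
            backpropagate(node, v)
            continue
        # Expansion: the policy proposes `branch` next steps (assumed non-empty),
        # and the value model scores each new partial solution.
        for step in propose(node.steps, branch):
            child = Node(steps=node.steps + [step], parent=node)
            child.visits, child.value_sum = 1, value(child.steps)
            record(child.steps, child.value_sum)
            node.children.append(child)
        # Back up the best child score through the expanded node to the root.
        backpropagate(node, max(ch.mean() for ch in node.children))
    return best_steps, best_value


# Toy problem: three steps, each a digit; the "process reward" is the fraction
# of steps that match a hidden target, so partial progress is rewarded.
TARGET = ["2", "2", "2"]
result = search(
    propose=lambda steps, k: [str(d) for d in range(k)],
    value=lambda steps: sum(s == t for s, t in zip(steps, TARGET)) / len(TARGET),
    is_terminal=lambda steps: len(steps) == len(TARGET),
)
print(result)  # (['2', '2', '2'], 1.0)
```

Scoring partial solutions, not only finished ones, is what lets the search favor promising prefixes early: in the toy run, the prefix "2" is expanded on the second iteration because the value function already rates it above "0" and "1".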

To evaluate MCTS* on benchmarks, use evaluate.py with the parameter --mode set to "mcts". You should also specify the benchmark name and the exact file (subset) to evaluate. A simple example is provided below:

python evaluate.py \
  --task_name "scibench" \
  --file "thermo" \
  --propose_method "gpt" \
  --value_method "local" \
  --mode "mcts" \
  --evaluate "scibench" \
  --iteration_limit 50 \
  --use_reflection "simple" \
  --branch 3
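As a rough illustration of what an accuracy figure means here, the sketch below scores predicted final answers against the optional answer field with a normalized exact match and skips questions that have no answer. It is a hypothetical helper, much simpler than benchmark-specific checking, and it is not the answer-matching logic of evaluate.py.

```python
from typing import Optional, Sequence


def normalize(answer: str) -> str:
    """Trim whitespace and a trailing period, drop inner spaces, lower-case."""
    return answer.strip().rstrip(".").replace(" ", "").lower()


def accuracy(predictions: Sequence[str], answers: Sequence[Optional[str]]) -> float:
    """Exact-match accuracy over questions that have a reference answer."""
    scored = [(p, a) for p, a in zip(predictions, answers) if a is not None]
    if not scored:
        return 0.0
    correct = sum(normalize(p) == normalize(a) for p, a in scored)
    return correct / len(scored)


print(accuracy(["129", "42 ", "7"], ["129", "42.", None]))  # 1.0
```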

You can also refer to the MCTS/args.md for more details on the search parameters.

Data & Model (take Llama3-8B-Instruct as an example)

Given question set $D_G$, we use Llama3-8B-Instruct guided by MCTS* to generate synthetic data for policy model and value model. (See Algorithm 1 of the paper for more details.)

Download policy data (positive samples) for training the 1st policy model (Llama3-8B-Instruct): [Hugging Face]

Download PRM data (positive and negative samples) for training 1st reward model (Mistral-7B: MetaMATH): [Hugging Face]

Download the trained policy model: [Hugging Face]
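The split between policy data (positive samples) and PRM data (positive and negative samples) can be sketched as a filter that compares each search trace's final answer with the oracle answer. The Trace record and split_traces function are hypothetical; in ReST-MCTS* the PRM data additionally carries the inferred per-step values, which this sketch omits.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class Trace:
    question: str
    steps: List[str]
    final_answer: str


def split_traces(traces: List[Trace], oracle: Dict[str, str]) -> Tuple[List[Trace], List[Trace]]:
    """Separate traces whose final answer matches the oracle (positive) from the rest."""
    positives: List[Trace] = []
    negatives: List[Trace] = []
    for trace in traces:
        correct = trace.final_answer.strip() == oracle[trace.question].strip()
        (positives if correct else negatives).append(trace)
    return positives, negatives


traces = [
    Trace("q1", ["2+3+5+7=17", "17+11+13+17+19+23+29=129"], "129"),
    Trace("q1", ["2+3+5+7=17", "17+11+13+17+19+23=100"], "100"),  # skips the prime 29
]
positives, negatives = split_traces(traces, {"q1": "129"})
policy_data = positives  # positive samples only
prm_data = positives + negatives  # positive and negative samples
```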

Self-training

For our methods:

Regarding Llama3-8B-Instruct and Mistral-7B: MetaMATH, we use the default [MAmmoTH] repository to train and evaluate the policy model.

Regarding SciGLM-6B, we use the default [SciGLM] repository to train and evaluate the policy model.

We also implement self-rewarding as our baseline in ./self_train/self_train_dpo.py.
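For reference, DPO minimizes $-\log \sigma\big(\beta[(\log \pi_\theta(y_w \mid x) - \log \pi_{\mathrm{ref}}(y_w \mid x)) - (\log \pi_\theta(y_l \mid x) - \log \pi_{\mathrm{ref}}(y_l \mid x))]\big)$ for a preferred response $y_w$ and a rejected response $y_l$. The sketch below computes this per-pair loss from summed sequence log-probabilities. It shows the standard DPO objective on the assumption that the baseline optimizes it; we have not checked the exact loss in self_train_dpo.py, and the function name and beta default are illustrative.

```python
import math


def dpo_loss(policy_chosen_logp: float, policy_rejected_logp: float,
             ref_chosen_logp: float, ref_rejected_logp: float, beta: float = 0.1) -> float:
    """Per-pair DPO loss: -log(sigmoid(beta * log-ratio margin))."""
    margin = beta * ((policy_chosen_logp - ref_chosen_logp)
                     - (policy_rejected_logp - ref_rejected_logp))
    # -log(sigmoid(m)) = softplus(-m), computed stably for large |m|.
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))


print(dpo_loss(0.0, 0.0, 0.0, 0.0))  # about 0.6931, i.e. log(2)
```

When the policy has not moved away from the reference model the margin is zero and the loss is log 2; raising the chosen response's relative log-probability lowers it.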

Leaderboard

Self-training Results:

Accuracy of Different Verifiers:

Accuracy of Different Searches (we also provide the plot code in figures/plot_math_self_training.py):

Citation

If you find our work helpful, please kindly cite our paper:

@article{zhang2024rest,
  title={ReST-MCTS*: LLM Self-Training via Process Reward Guided Tree Search},
  author={Zhang, Dan and Zhoubian, Sining and Hu, Ziniu and Yue, Yisong and Dong, Yuxiao and Tang, Jie},
  journal={arXiv preprint arXiv:2406.03816},
  year={2024}
}