
ICML 2024 - Official Repository for EXO: Towards Efficient Exact Optimization of Language Model Alignment
https://arxiv.org/abs/2402.00856
MIT License

Towards Efficient Exact Optimization of Language Model Alignment

Overview

This is the official PyTorch implementation of the EXO algorithm for efficient exact optimization of language model (LM) alignment with human preferences, as described in our ICML 2024 paper Towards Efficient Exact Optimization of Language Model Alignment.

EXO essentially minimizes the reverse KL between the empirical distributions defined by the policy and the reward, whereas DPO corresponds to minimizing the forward KL. The figure above illustrates the distinct behaviors of policies obtained by minimizing (a) the reverse KL and (b) the forward KL.
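Schematically, given K completions y_1, ..., y_K of a prompt x, both losses compare two K-way distributions: one induced by the policy-to-SFT log-ratios and one induced by the reward. The notation below is a simplified sketch with a single temperature β; see the paper for the exact definitions.

$$
f_{\pi_\theta}(i) \propto \left(\frac{\pi_\theta(y_i \mid x)}{\pi_{\mathrm{sft}}(y_i \mid x)}\right)^{1/\beta},
\qquad
f_r(i) \propto \exp\big(r(x, y_i)/\beta\big),
$$

$$
\mathcal{L}_{\mathrm{EXO}} \approx \mathbb{E}\,\mathrm{KL}\big(f_{\pi_\theta}\,\|\,f_r\big)\ \text{(reverse KL)},
\qquad
\mathcal{L}_{\mathrm{DPO}} \approx \mathbb{E}\,\mathrm{KL}\big(f_r\,\|\,f_{\pi_\theta}\big)\ \text{(forward KL)}.
$$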

Dependencies

python >= 3.9.0
transformers >= 4.34.1
deepspeed >= 0.11.2

It is recommended to precompile the required extensions when installing DeepSpeed (for more details, refer to the DeepSpeed installation guide):

DS_BUILD_FUSED_ADAM=1 DS_BUILD_TRANSFORMER=1 DS_BUILD_TRANSFORMER_INFERENCE=1 pip install deepspeed 

General Guideline

EXO supports two settings: (i) training directly on the preference data and (ii) training with supervision provided by a learned reward model. The pipeline consists of three stages, described in the sections below: SFT, reward modeling (RM, optional), and alignment.

To train on a custom dataset, create a dataset class that inherits from the base class PromptRawDataset in src/utils/data/raw_datasets.py, then add a few lines to the get_raw_dataset method in src/utils/data/data_utils.py to register the custom dataset.
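For illustration, here is a minimal sketch of such a class. It assumes the base class exposes accessor methods in the common DeepSpeed-Chat style; the constructor signature, method names, and file paths are placeholders that should be checked against `PromptRawDataset` in `src/utils/data/raw_datasets.py`.

```python
# Hypothetical addition to src/utils/data/raw_datasets.py (defined alongside
# PromptRawDataset, so no extra import of the base class is needed).
# The constructor signature and accessor names follow a DeepSpeed-Chat-style
# convention; verify them against the base class in this repository.
from datasets import load_dataset


class CustomPrefDataset(PromptRawDataset):
    def __init__(self, output_path, seed, local_rank, dataset_name):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "custom-data/pref"
        # placeholder paths; point these at your local JSON files
        self.raw_datasets = load_dataset(
            "json",
            data_files={
                "train": "/local/path/to/pref/data/train.json",
                "test": "/local/path/to/pref/data/test.json",
            },
        )

    def get_train_data(self):
        return self.raw_datasets["train"]

    def get_eval_data(self):
        return self.raw_datasets["test"]

    def get_prompt(self, sample):
        return sample["prompt"]

    def get_chosen(self, sample):
        return sample["chosen"]

    def get_rejected(self, sample):
        return sample["rejected"]
```

The new class then needs a matching branch in `get_raw_dataset` in `src/utils/data/data_utils.py` that returns `CustomPrefDataset` when the dataset name is `custom-data/pref`.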

SFT Stage

Training

In the SFT stage, the LM is fine-tuned with supervised MLE on data that ideally comes from the same distribution as the preference data. If no such data is available, one can simply fine-tune on the chosen texts of the preference data.

SFT data format:

```json
{
  "prompt": "prompt",
  "chosen": "chosen text"
}
```
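If only preference data is available, the chosen texts can be extracted into this SFT format with a few lines of Python. This is a sketch: the file paths are placeholders and it assumes the data is stored as a JSON list.

```python
import json

# placeholder input/output paths for illustration
with open("/local/path/to/pref/data/train.json") as f:
    pref_data = json.load(f)

# keep only the prompt and the preferred (chosen) completion
sft_data = [{"prompt": ex["prompt"], "chosen": ex["chosen"]} for ex in pref_data]

with open("/local/path/to/sft/data/train.json", "w") as f:
    json.dump(sft_data, f, indent=2)
```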
Training script:

```bash
# Any causal HuggingFace model (`AutoModelForCausalLM` class)
INIT_MODEL_NAME=custom-model
# local path to the checkpoint of the initial model
INIT_MODEL_PATH=/local/path/to/init/model
# type of the model
MODEL_TYPE=sft
# name of the sft data, default format: "name/sft", should be added to `src/utils/data/data_utils.py`
DATA_NAME=custom-data/sft
# local path to the sft data
DATA_PATH=/local/path/to/sft/data

bash exp/custom_exp/train_sft.sh $INIT_MODEL_NAME $INIT_MODEL_PATH $MODEL_TYPE $DATA_NAME $DATA_PATH
```

Other hyperparameters for training can be specified in `exp/custom_exp/train_sft.sh`. The SFT model will be saved in `models/custom-model_custom-data/sft`.

Inference

(Optional but recommended) To utilize the supervision of a reward model for alignment, one needs to sample from the SFT model and later use the reward model to score the sampled completions.

Inference script:

```bash
# comma separated device ids
DEVICE_IDS=0,1,2,3
# data name and data path concatenated by colon
DATA_NAME_PATH=custom-data/sft:/local/path/to/sft/data
# local path to SFT model
MODEL_PATH=models/custom-model_custom-data/sft

# inference on train set
SPLIT=train
bash exp/custom_exp/inference_sft.sh $DEVICE_IDS $DATA_NAME_PATH $SPLIT $MODEL_PATH

# inference on test set
SPLIT=test
bash exp/custom_exp/inference_sft.sh $DEVICE_IDS $DATA_NAME_PATH $SPLIT $MODEL_PATH
```

Other hyperparameters for decoding can be specified in `exp/custom_exp/inference_sft.sh`. The inference results will be saved under the same root directory as the SFT data.
SFT generated data format:

```json
{
  "prompt": "prompt",
  "completions": ["text A", "text B", ...]
}
```

RM Stage

Training

(Optional but recommended) To utilize a continuous preference signal, one can train a reward model on the preference data to predict human preferences.

Preference data format:

```json
{
  "prompt": "prompt",
  "chosen": "chosen text",
  "rejected": "rejected text"
}
```
Training script:

```bash
# Any HuggingFace model (`AutoModel` class); the last position of the sequence is used for prediction
INIT_MODEL_NAME=custom-model
# local path to the checkpoint of the initial model
INIT_MODEL_PATH=/local/path/to/init/model
# type of the model
MODEL_TYPE=rm
# name of the preference data, default format: "name/pref", should be added to `src/utils/data/data_utils.py`
DATA_NAME=custom-data/pref
# local path to the pref data
DATA_PATH=/local/path/to/pref/data

bash exp/custom_exp/train_rm.sh $INIT_MODEL_NAME $INIT_MODEL_PATH $MODEL_TYPE $DATA_NAME $DATA_PATH
```

Other hyperparameters for training can be specified in `exp/custom_exp/train_rm.sh`. The reward model will be saved in `models/custom-model_custom-data/rm`.
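Reward models of this kind are typically trained with a pairwise ranking (Bradley-Terry style) loss on the chosen/rejected scores. The sketch below illustrates that standard objective; it is not necessarily the exact loss implemented in `exp/custom_exp/train_rm.sh`.

```python
import torch
import torch.nn.functional as F

def pairwise_rm_loss(chosen_rewards: torch.Tensor,
                     rejected_rewards: torch.Tensor) -> torch.Tensor:
    """Bradley-Terry style ranking loss.

    chosen_rewards / rejected_rewards: shape (batch,), scalar scores read from
    the last position of the chosen / rejected sequences. Minimizing this loss
    pushes the reward of the chosen text above that of the rejected text.
    """
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()
```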

Inference

(Optional but recommended) Then use the reward model to score the SFT generated data with continuous rewards.

Inference script:

```bash
# comma separated device ids
DEVICE_IDS=0,1,2,3
# local path to the sft generated data
DATA_PATH=/local/path/to/sft/gen/data
# local path to the reward model
MODEL_PATH=models/custom-model_custom-data/rm

# inference on train set
SPLIT=train
bash exp/custom_exp/inference_rm.sh $DEVICE_IDS $DATA_PATH $SPLIT $MODEL_PATH

# inference on test set
SPLIT=test
bash exp/custom_exp/inference_rm.sh $DEVICE_IDS $DATA_PATH $SPLIT $MODEL_PATH
```

Other hyperparameters for inference can be specified in `exp/custom_exp/inference_rm.sh`. The inference results will be saved under the same root directory as the SFT data.
RM labeled data format:

```json
{
  "prompt": "prompt",
  "completions": ["text A", "text B", ...],
  "rewards": [reward A, reward B, ...]
}
```

Alignment Stage

EXO & DPO

In the alignment stage, the SFT model is fine-tuned to align with human preferences by training on either the preference data or the RM labeled data.

Before training, the preference dataset should be converted to the same format as the RM labeled data:

python src/utils/data/pref_to_rw.py /local/path/to/preference/data
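Conceptually, the conversion maps each preference pair to the completions/rewards format, with the rewards marking which completion was preferred. The sketch below illustrates this mapping only; the exact reward values and file layout produced by `pref_to_rw.py` may differ, and the paths are placeholders.

```python
import json

# placeholder path; pass the same directory you would give to pref_to_rw.py
with open("/local/path/to/preference/data/train.json") as f:
    pref_data = json.load(f)

rw_data = [
    {
        "prompt": ex["prompt"],
        "completions": [ex["chosen"], ex["rejected"]],
        # binary rewards marking the preferred completion (illustrative values)
        "rewards": [1.0, 0.0],
    }
    for ex in pref_data
]

with open("/local/path/to/preference/data/train_rw.json", "w") as f:
    json.dump(rw_data, f, indent=2)
```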
Training script: to train the policy using the EXO algorithm, run the following commands:

```bash
# Any causal HuggingFace model (`AutoModelForCausalLM` class)
INIT_MODEL_NAME=custom-model
# local path to the SFT model
INIT_MODEL_PATH=/local/path/to/sft/model
# type of the model
MODEL_TYPE=align
# name of the reward data, default format: "name/rw", should be added to `src/utils/data/data_utils.py`
DATA_NAME=custom-data/rw
# local path to the reward data or preference data
DATA_PATH=/local/path/to/rw/data
# supported loss type: exo-pref / exo-rw / dpo-pref / dpo-rw
LOSS_TYPE="exo-pref"
# number of contrastive samples, should not be greater than the number of completion candidates in the dataset
NUM_CONTRASTIVE=2

bash exp/custom_exp/train_exo.sh $INIT_MODEL_NAME $INIT_MODEL_PATH $MODEL_TYPE $DATA_NAME $DATA_PATH $LOSS_TYPE $NUM_CONTRASTIVE
```

Other hyperparameters for training can be specified in `exp/custom_exp/train_exo.sh`. To train the policy using the DPO algorithm, simply change `LOSS_TYPE` to either `dpo-pref` or `dpo-rw`.
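For intuition, here is a schematic PyTorch sketch of an EXO-style loss over the K contrastive completions of a single prompt, matching the reverse-KL view described in the overview. It is an illustration only; the temperatures, label smoothing, and batching in the repository's implementation may differ.

```python
import torch
import torch.nn.functional as F

def exo_style_loss(policy_logps: torch.Tensor,
                   ref_logps: torch.Tensor,
                   rewards: torch.Tensor,
                   beta: float = 0.1) -> torch.Tensor:
    """Reverse KL between the K-way distributions induced by the policy and the reward.

    policy_logps / ref_logps: shape (K,), sequence log-probabilities of the K
    candidate completions under the policy and the SFT reference model.
    rewards: shape (K,), RM scores (or binary preference labels) of the candidates.
    """
    # distribution induced by the policy-to-reference log-ratios
    log_f_pi = F.log_softmax((policy_logps - ref_logps) / beta, dim=-1)
    # distribution induced by the rewards
    log_f_r = F.log_softmax(rewards / beta, dim=-1)
    # reverse KL: KL(f_pi || f_r)
    return torch.sum(log_f_pi.exp() * (log_f_pi - log_f_r), dim=-1)
```

Here `policy_logps - ref_logps` plays the role of an implicit reward, scaled by the inverse temperature `beta`.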

Reproducing experiments in the paper

To facilitate future study, we also provide guidelines for reproducing the experiments in the paper on three public datasets: IMDB, TL;DR, and Anthropic-HH.

Citing

@inproceedings{Ji2024TowardsExact,
  title={Towards Efficient Exact Optimization of Language Model Alignment},
  author={Haozhe Ji and Cheng Lu and Yilin Niu and Pei Ke and Hongning Wang and Jun Zhu and Jie Tang and Minlie Huang},
  booktitle={The Forty-first International Conference on Machine Learning},
  year={2024},
  url={https://arxiv.org/abs/2402.00856}
}

Please kindly cite our work if you find the paper or this repository useful :)