
Simple Preference Optimization (SimPO)

This repository contains the code and released models for our paper SimPO: Simple Preference Optimization with a Reference-Free Reward. We propose a simpler and more effective preference optimization algorithm than DPO (Direct Preference Optimization) without using a reference model. SimPO outperforms DPO and its latest variants across AlpacaEval 2, MT-Bench, and Arena-Hard benchmarks under various settings.

🆕 Changelog

🔗 Quick Links

Tips for Running SimPO

Given the various inquiries about SimPO, we provide a list of tips to help you reproduce our paper results and achieve good results when running SimPO on your own tasks.

Hyperparameter tuning

Hyperparameter tuning is crucial for SimPO (and other preference optimization algorithms in general). The three main hyperparameters of SimPO to focus on are learning_rate, beta, and gamma.

We used the following hyperparameters for training the released models (note that in our latest update, we changed the hyperparameter gamma to gamma_beta_ratio, as the latter is normalized and easier to tune under different beta values).

| Setting | β | γ/β | Learning rate |
|---|---|---|---|
| Mistral-Base | 2.0 | 0.8 | 3e-7 |
| Mistral-Instruct | 2.5 | 0.1 | 5e-7 |
| Llama3-Base | 2.0 | 0.5 | 6e-7 |
| Llama3-Instruct | 2.5 | 0.55 | 1e-6 |
| Llama3-Instruct v0.2 | 10 | 0.3 | 1e-6 |
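For reference, here is a minimal PyTorch sketch of how beta and gamma_beta_ratio enter the SimPO objective (a sketch only; see the paper and the training scripts in this repository for the authoritative implementation). It assumes `policy_chosen_logps` and `policy_rejected_logps` are the length-normalized (average per-token) log probabilities of the chosen and rejected responses under the policy model:

```python
import torch
import torch.nn.functional as F

def simpo_loss(policy_chosen_logps: torch.Tensor,
               policy_rejected_logps: torch.Tensor,
               beta: float = 2.5,
               gamma_beta_ratio: float = 0.55) -> torch.Tensor:
    # Reference-free implicit rewards: length-normalized log probabilities scaled by beta.
    chosen_rewards = beta * policy_chosen_logps
    rejected_rewards = beta * policy_rejected_logps
    # Target reward margin gamma, expressed through the normalized ratio gamma / beta.
    gamma = gamma_beta_ratio * beta
    # Bradley-Terry style objective with the margin term.
    logits = chosen_rewards - rejected_rewards - gamma
    return -F.logsigmoid(logits).mean()
```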
For DPO, we use the following hyperparameters for training.

| Setting | β | Learning rate |
|---|---|---|
| Mistral-Base | 0.01 | 5e-7 |
| Mistral-Instruct | 0.01 | 5e-7 |
| Llama3-Base | 0.01 | 5e-7 |
| Llama3-Instruct | 0.01 | 7e-7 |
| Llama3-Instruct v0.2 | 0.01 | 3e-7 |

Training and evaluation consistency in BOS

Our released Llama3 models use the initial version of the Llama3 tokenizer (prior to this PR). We have found that the updated Llama3 tokenizer with vLLM occasionally introduces two BOS tokens, which can affect evaluation results. Therefore, please ensure that only one BOS token is included in the prompt after applying the Llama3 chat template during any evaluation.

Notably, if you are training Llama3 and evaluating the trained models on AlpacaEval 2 and Arena-Hard using the templates provided in this repo, please make sure to use the pre-update Llama3 tokenizer (i.e., the one before the PR).
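A quick sanity check (a small sketch, assuming a Hugging Face tokenizer and one of the released checkpoints) is to apply the chat template and count the BOS tokens in the resulting prompt:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("princeton-nlp/Llama-3-Instruct-8B-SimPO")
messages = [{"role": "user", "content": "Hello!"}]

# The Llama3 chat template already prepends a BOS token to the prompt string.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Tokenize without re-adding special tokens, then count how many BOS tokens appear.
input_ids = tokenizer(prompt, add_special_tokens=False).input_ids
num_bos = sum(1 for tok in input_ids if tok == tokenizer.bos_token_id)
assert num_bos == 1, f"Expected exactly one BOS token, found {num_bos}"
```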

Adding an extra SFT loss

We have observed that, in some cases, adding an additional SFT loss can help improve results. These findings were initially validated in the CPO_SIMPO repository, and we are currently working on integrating this improvement into our main repository.
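For illustration only, a combined objective could look like the sketch below. It reuses the `simpo_loss` sketch above; `chosen_nll_loss` and `sft_weight` are hypothetical names, not part of this repository's current interface:

```python
import torch

def simpo_with_sft_loss(policy_chosen_logps: torch.Tensor,
                        policy_rejected_logps: torch.Tensor,
                        chosen_nll_loss: torch.Tensor,
                        sft_weight: float = 1.0,
                        beta: float = 2.5,
                        gamma_beta_ratio: float = 0.55) -> torch.Tensor:
    # Preference term (the simpo_loss sketch above) plus a weighted negative
    # log-likelihood (SFT) term computed on the chosen responses.
    preference_term = simpo_loss(policy_chosen_logps, policy_rejected_logps,
                                 beta=beta, gamma_beta_ratio=gamma_beta_ratio)
    return preference_term + sft_weight * chosen_nll_loss
```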

Released Models

v0.1

We used the HuggingFaceH4/ultrafeedback_binarized dataset to train the Mistral Base and Llama3 Base models, the princeton-nlp/mistral-instruct-ultrafeedback dataset to train the Mistral Instruct models, and the princeton-nlp/llama3-ultrafeedback dataset to train the Llama3 Instruct models. The latter two datasets were annotated with the llm-blender/PairRM model.
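These preference datasets can be inspected directly with the Hugging Face `datasets` library; the minimal sketch below prints the available splits and columns rather than assuming them:

```python
from datasets import load_dataset

# Preference data used to train the Llama3 Instruct models (annotated with PairRM).
ds = load_dataset("princeton-nlp/llama3-ultrafeedback")
print(ds)  # shows the available splits, columns, and number of examples
```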

Below is the complete list of models evaluated in our preprint. AE2 LC and AE2 WR denote the AlpacaEval 2 length-controlled and raw win rates, and AH denotes the Arena-Hard win rate (all in %).

| Model | Hugging Face model | AE2 LC | AE2 WR | AH |
|---|---|---|---|---|
| Mistral Base 7B SFT | alignment-handbook/zephyr-7b-sft-full | 8.4 | 6.2 | 1.3 |
| Mistral Base 7B RRHF | princeton-nlp/Mistral-7B-Base-SFT-RRHF | 11.6 | 10.2 | 6.9 |
| Mistral Base 7B SLiC-HF | princeton-nlp/Mistral-7B-Base-SFT-SLiC-HF | 10.9 | 8.9 | 7.3 |
| Mistral Base 7B DPO (Zephyr) | princeton-nlp/Mistral-7B-Base-SFT-DPO | 15.1 | 12.5 | 10.4 |
| Mistral Base 7B IPO | princeton-nlp/Mistral-7B-Base-SFT-IPO | 11.8 | 9.4 | 7.5 |
| Mistral Base 7B CPO | princeton-nlp/Mistral-7B-Base-SFT-CPO | 9.8 | 8.9 | 6.9 |
| Mistral Base 7B KTO | princeton-nlp/Mistral-7B-Base-SFT-KTO | 13.1 | 9.1 | 5.6 |
| Mistral Base 7B ORPO | kaist-ai/mistral-orpo-beta | 14.7 | 12.2 | 7.0 |
| Mistral Base 7B R-DPO | princeton-nlp/Mistral-7B-Base-SFT-RDPO | 17.4 | 12.8 | 9.9 |
| Mistral Base 7B SimPO | princeton-nlp/Mistral-7B-Base-SFT-SimPO | 21.4 | 20.8 | 16.6 |
| Mistral Instruct 7B SFT | mistralai/Mistral-7B-Instruct-v0.2 | 17.1 | 14.7 | 12.6 |
| Mistral Instruct 7B RRHF | princeton-nlp/Mistral-7B-Instruct-RRHF | 25.3 | 24.8 | 18.1 |
| Mistral Instruct 7B SLiC-HF | princeton-nlp/Mistral-7B-Instruct-SLiC-HF | 24.1 | 24.6 | 18.9 |
| Mistral Instruct 7B DPO | princeton-nlp/Mistral-7B-Instruct-DPO | 26.8 | 24.9 | 16.3 |
| Mistral Instruct 7B IPO | princeton-nlp/Mistral-7B-Instruct-IPO | 20.3 | 20.3 | 16.2 |
| Mistral Instruct 7B CPO | princeton-nlp/Mistral-7B-Instruct-CPO | 23.8 | 28.8 | 22.6 |
| Mistral Instruct 7B KTO | princeton-nlp/Mistral-7B-Instruct-KTO | 24.5 | 23.6 | 17.9 |
| Mistral Instruct 7B ORPO | princeton-nlp/Mistral-7B-Instruct-ORPO | 24.5 | 24.9 | 20.8 |
| Mistral Instruct 7B R-DPO | princeton-nlp/Mistral-7B-Instruct-RDPO | 27.3 | 24.5 | 16.1 |
| Mistral Instruct 7B SimPO | princeton-nlp/Mistral-7B-Instruct-SimPO | 32.1 | 34.8 | 21.0 |
| Llama3 Base 8B SFT | princeton-nlp/Llama-3-Base-8B-SFT | 6.2 | 4.6 | 3.3 |
| Llama3 Base 8B RRHF | princeton-nlp/Llama-3-Base-8B-RRHF | 10.8 | 8.1 | 6.6 |
| Llama3 Base 8B SLiC-HF | princeton-nlp/Llama-3-Base-8B-SLiC-HF | 12.1 | 10.1 | 10.3 |
| Llama3 Base 8B DPO | princeton-nlp/Llama-3-Base-8B-SFT-DPO | 18.2 | 15.5 | 15.9 |
| Llama3 Base 8B IPO | princeton-nlp/Llama-3-Base-8B-SFT-IPO | 14.4 | 14.2 | 17.8 |
| Llama3 Base 8B CPO | princeton-nlp/Llama-3-Base-8B-SFT-CPO | 10.8 | 8.1 | 5.8 |
| Llama3 Base 8B KTO | princeton-nlp/Llama-3-Base-8B-SFT-KTO | 14.2 | 12.4 | 12.5 |
| Llama3 Base 8B ORPO | princeton-nlp/Llama-3-Base-8B-SFT-ORPO | 12.2 | 10.6 | 10.8 |
| Llama3 Base 8B R-DPO | princeton-nlp/Llama-3-Base-8B-SFT-RDPO | 17.6 | 14.4 | 17.2 |
| Llama3 Base 8B SimPO | princeton-nlp/Llama-3-Base-8B-SFT-SimPO | 22.0 | 20.3 | 23.4 |
| Llama3 Instruct 8B SFT | meta-llama/Meta-Llama-3-8B-Instruct | 26.0 | 25.3 | 22.3 |
| Llama3 Instruct 8B RRHF | princeton-nlp/Llama-3-Instruct-8B-RRHF | 31.3 | 28.4 | 26.5 |
| Llama3 Instruct 8B SLiC-HF | princeton-nlp/Llama-3-Instruct-8B-SLiC-HF | 26.9 | 27.5 | 26.2 |
| Llama3 Instruct 8B DPO | princeton-nlp/Llama-3-Instruct-8B-DPO | 40.3 | 37.9 | 32.6 |
| Llama3 Instruct 8B IPO | princeton-nlp/Llama-3-Instruct-8B-IPO | 35.6 | 35.6 | 30.5 |
| Llama3 Instruct 8B CPO | princeton-nlp/Llama-3-Instruct-8B-CPO | 33.1 | 31.8 | 26.4 |
| Llama3 Instruct 8B KTO | princeton-nlp/Llama-3-Instruct-8B-KTO | 33.1 | 31.8 | 26.4 |
| Llama3 Instruct 8B ORPO | princeton-nlp/Llama-3-Instruct-8B-ORPO | 28.5 | 27.4 | 25.8 |
| Llama3 Instruct 8B R-DPO | princeton-nlp/Llama-3-Instruct-8B-RDPO | 41.1 | 37.8 | 33.1 |
| Llama3 Instruct 8B SimPO | princeton-nlp/Llama-3-Instruct-8B-SimPO | 44.7 | 40.5 | 33.8 |

v0.2

We found that using a strong reward model to annotate preference optimization datasets is crucial. In this iteration, we reannotated the preference data with a more powerful reward model, RLHFlow/ArmoRM-Llama3-8B-v0.1, producing the dataset princeton-nlp/llama3-ultrafeedback-armorm. As a result, the v0.2 models demonstrate significantly improved performance compared to the v0.1 models.

| Model | Hugging Face model | AE2 LC | AE2 WR | AH |
|---|---|---|---|---|
| Llama 3 Instruct 8B RRHF v0.2 | princeton-nlp/Llama-3-Instruct-8B-RRHF-v2.0 | 37.9 | 31.6 | 28.8 |
| Llama 3 Instruct 8B SLiC-HF v0.2 | princeton-nlp/Llama-3-Instruct-8B-SLiC-HF-v2.0 | 33.9 | 32.5 | 29.3 |
| Llama 3 Instruct 8B DPO v0.2 | princeton-nlp/Llama-3-Instruct-8B-DPO-v0.2 | 48.2 | 47.5 | 35.2 |
| Llama 3 Instruct 8B IPO v0.2 | princeton-nlp/Llama-3-Instruct-8B-IPO-v0.2 | 46.8 | 42.4 | 36.6 |
| Llama 3 Instruct 8B CPO v0.2 | princeton-nlp/Llama-3-Instruct-8B-CPO-v0.2 | 34.1 | 36.4 | 30.9 |
| Llama 3 Instruct 8B KTO v0.2 | princeton-nlp/Llama-3-Instruct-8B-KTO-v0.2 | 34.1 | 32.1 | 27.3 |
| Llama 3 Instruct 8B ORPO v0.2 | princeton-nlp/Llama-3-Instruct-8B-ORPO-v0.2 | 38.1 | 33.8 | 28.2 |
| Llama 3 Instruct 8B R-DPO v0.2 | princeton-nlp/Llama-3-Instruct-8B-RDPO-v0.2 | 48.0 | 45.8 | 35.1 |
| Llama 3 Instruct 8B SimPO v0.2 | princeton-nlp/Llama-3-Instruct-8B-SimPO-v0.2 | 53.7 | 47.5 | 36.5 |

Use our models for inference

Please refer to the generate.py script for detailed instructions on loading the model with the appropriate chat template.
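For a quick start, the sketch below (not a replacement for generate.py) loads one of the released checkpoints with transformers and generates a reply through the tokenizer's chat template:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "princeton-nlp/Llama-3-Instruct-8B-SimPO"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

messages = [{"role": "user", "content": "Explain preference optimization in one sentence."}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

output = model.generate(input_ids, max_new_tokens=128, do_sample=False)
print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))
```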

Install Requirements

Our codebase is built upon the alignment-handbook repo. The following steps will guide you through the installation process.

First, create a Python virtual environment using e.g. Conda:

conda create -n handbook python=3.10 && conda activate handbook

Next, install PyTorch v2.2.2. Since this is hardware-dependent, we direct you to the PyTorch Installation Page.

You can then install the remaining package dependencies of alignment-handbook as follows:

git clone https://github.com/huggingface/alignment-handbook.git
cd ./alignment-handbook/
python -m pip install .

You will also need Flash Attention 2 installed, which can be done by running:

python -m pip install flash-attn --no-build-isolation
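As an optional sanity check, the short sketch below confirms from Python that the expected PyTorch version is installed and that flash-attn imports correctly:

```python
import torch
import flash_attn

print(torch.__version__)          # expected: 2.2.2, as recommended above
print(torch.cuda.is_available())  # flash-attn kernels require a CUDA-capable GPU
print(flash_attn.__version__)
```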

Training Scripts

We provide four training config files for the four training setups reported in our paper. The training configs are set up for 4xH100 GPUs. You may need to adjust num_processes and per_device_train_batch_size based on your computation environment.

Evaluation

We follow the official implementations for evaluation on AlpacaEval 2, Arena-Hard, and MT-Bench; more details can be found under the eval directory.

Bugs or Questions?

If you have any questions related to the code or the paper, feel free to email Yu (yumeng5@virginia.edu). If you encounter any problems when using the code, or want to report a bug, feel free to open an issue! Please describe the problem in as much detail as possible so that we can help you quickly and effectively.

Citation

Please cite our paper if you find the repo helpful in your work:

@article{meng2024simpo,
  title={{SimPO}: Simple Preference Optimization with a Reference-Free Reward},
  author={Meng, Yu and Xia, Mengzhou and Chen, Danqi},
  journal={arXiv preprint arXiv:2405.14734},
  year={2024}
}