
🌾 OAT: Online AlignmenT for LLMs
https://arxiv.org/pdf/2411.01493

OAT


Installation | Usage | Examples | Benchmarking | Citation


Introduction

Oat 🌾 is a simple yet efficient system for running online LLM alignment algorithms. Its key features include:

LLM alignment as contextual dueling bandits

LLM alignment is essentially an online learning and decision-making problem in which the agent (e.g., the LLM policy, optionally with a built-in reward model) interacts with the environment (i.e., humans) to achieve one of two distinct objectives: minimizing cumulative regret in the Explore & Exploit setting, or minimizing anytime regret in the Best Arm Identification setting.
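
For intuition, a common way to define the cumulative regret minimized in this dueling setting (a sketch of the standard contextual dueling bandit notion; the paper's exact formulation may differ) is

$$R_T = \sum_{t=1}^{T} \left[ \tfrac{1}{2} P\!\left(y_t^\star \succ y_t^1 \mid x_t\right) + \tfrac{1}{2} P\!\left(y_t^\star \succ y_t^2 \mid x_t\right) - \tfrac{1}{2} \right],$$

where at round $t$ the agent proposes two responses $y_t^1$ and $y_t^2$ for prompt $x_t$, $y_t^\star$ denotes the best response under the preference oracle $P(\cdot \succ \cdot \mid x_t)$, and the per-round regret vanishes only when both proposals are as good as $y_t^\star$.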

In our paper, we formalize LLM alignment as a contextual dueling bandit (CDB) problem (see illustration below) and propose a sample-efficient alignment approach based on Thompson sampling.
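
At a high level, a Thompson-sampling actor step for this setting can be sketched as follows. This is a minimal illustrative sketch, not oat's actual API: policy.generate, the reward_ensemble list of reward models (an approximate posterior over reward functions), and oracle.compare are hypothetical placeholders.

import random

def duel_via_thompson_sampling(prompt, policy, reward_ensemble, oracle, num_samples=10):
    # Illustrative sketch only; all names below are hypothetical, not oat's API.
    # 1. Generate a pool of candidate responses for the prompt.
    candidates = policy.generate(prompt, n=num_samples)

    # 2. Draw one reward model from the (approximate) posterior and pick the
    #    candidate it scores highest.
    rm_a = random.choice(reward_ensemble)
    first = max(candidates, key=lambda y: rm_a.score(prompt, y))

    # 3. Draw a second posterior sample and pick its preferred candidate
    #    (it may differ from the first when the posterior is still uncertain).
    rm_b = random.choice(reward_ensemble)
    second = max(candidates, key=lambda y: rm_b.score(prompt, y))

    # 4. Query the preference oracle and return the labeled pair, which the
    #    learner can then use for a direct alignment update (e.g., SimPO).
    winner, loser = oracle.compare(prompt, first, second)
    return winner, loser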

The CDB framework necessitates an efficient online training system to validate the proposed method and compare it with other baselines. Oat 🌾 is developed as part of this research initiative.

Using the CDB framework, existing LLM alignment paradigms can be summarized as follows:

For more details, please check out our paper!

Installation

In a Python environment with a supported version (>=3.8, <=3.10), you can install oat via PyPI:

pip install vllm==0.6.2 && pip install oat-llm

Alternatively, install it in "editable" mode for local development:

git clone git@github.com:sail-sg/oat.git
cd oat
pip install vllm==0.6.2 && pip install -e .

Usage

Below is an example that aligns a 1B Pythia SFT model on the TL;DR dataset using online SimPO, with PairRM as the preference oracle:

[!WARNING] Aligning with PairRM provides a lightweight example setup. To reproduce results from the paper or to develop custom online alignment algorithms, we recommend using stronger reward models (or GPT-as-a-judge) as the preference oracle, which better approximates the ideal case of querying a human population. See the examples.

python -m oat.experiment.main \
    --gpus 2 \
    --collocate \
    --dap-algo SimPO \
    --beta 2 \
    --preference-oracle pairrm \
    --pretrain trl-lib/pythia-1b-deduped-tldr-sft \
    --prompt-data lkevinzc/tldr-with-sft-reference \
    --output_key pythia-1b-reference \
    --sync-params-every 1 \
    --rollout-batch-size-per-device 64 \
    --pi-buffer-maxlen-per-device 64 \
    --train-batch-size-per-device 8 \
    --use-wb \
    --wb-run-name 1b_pairrm_simpo_online

This example completes in less than two hours on two A100-40G GPUs!

To run an offline SimPO baseline for comparison, we disable weight synchronization from the learner to the actors by adjusting the --sync-params-every argument:

python -m oat.experiment.main \
    --gpus 2 \
    --collocate \
    --dap-algo SimPO \
    --beta 2 \
    --preference-oracle pairrm \
    --pretrain trl-lib/pythia-1b-deduped-tldr-sft \
    --prompt-data lkevinzc/tldr-with-sft-reference \
    --output_key pythia-1b-reference \
-   --sync-params-every 1 \
+   --sync-params-every 9999 \ # any number > total gradient step (50000//128=390)
    --rollout-batch-size-per-device 64 \
    --pi-buffer-maxlen-per-device 64 \
    --train-batch-size-per-device 8 \
    --use-wb \
-   --wb-run-name 1b_pairrm_simpo_online
+   --wb-run-name 1b_pairrm_simpo_offline

Finally, we run SEA SimPO (with $\gamma=1$; see here for the meaning of $\gamma$) to verify its capability of sample-efficient alignment. This experiment uses 4 GPUs, with a reduced per-device training batch size to accommodate the training of an additional epistemic reward model. The per-device rollout batch size and buffer length are adjusted to keep the global batch size at 128 (32 per device × 4 GPUs). Additionally, 10 response candidates are generated for exploration using BAI Thompson sampling.

python -m oat.experiment.main \
-   --gpus 2 \
+   --gpus 4 \
    --dap-algo SimPO \
    --beta 2 \
    --preference-oracle pairrm \
    --pretrain trl-lib/pythia-1b-deduped-tldr-sft \
    --prompt-data lkevinzc/tldr-with-sft-reference \
    --output_key pythia-1b-reference \
    --sync-params-every 1 \
-   --rollout-batch-size-per-device 64 \
-   --pi-buffer-maxlen-per-device 64 \
-   --train-batch-size-per-device 8 \
+   --rollout-batch-size-per-device 32 \
+   --pi-buffer-maxlen-per-device 32 \
+   --train-batch-size-per-device 1 \
+   --learn-rm \
+   --exp-method EnnBAITS \
+   --num_samples 10 \
    --use-wb \
-   --wb-run-name 1b_pairrm_simpo_online
+   --wb-run-name 1b_pairrm_simpo_sea

Check out this tutorial for more examples covering:

Benchmarking

The benchmarking compares oat with the online DPO implementation from huggingface/trl. Below, we outline the configurations used for oat and present the benchmarking results. Notably, oat 🌾 is up to 2.5x more computationally efficient than trl 🤗.

Please refer to Appendix C of our paper for a detailed discussion of the benchmarking methods and results.

Citation

If you find this work useful for your research, please consider citing

@article{
  liu2024sea,
  title={Sample-Efficient Alignment for LLMs},
  author={Zichen Liu and Changyu Chen and Chao Du and Wee Sun Lee and Min Lin},
  journal={arXiv preprint arXiv:2411.01493},
  year={2024}
}

License

oat is distributed under the terms of the Apache License 2.0.

Acknowledgement

We thank the following awesome projects that have contributed to the development of oat:

Disclaimer

This is not an official Sea Limited or Garena Online Private Limited product.