
lm-human-preference-details

This repo aims to produce a blog post similar to The 37 Implementation Details of Proximal Policy Optimization, but for the RLHF techniques used in https://github.com/openai/lm-human-preferences.

Warning: This repo is a WIP, made public because it's easier for me to share pointers with collaborators. I'll remove this warning when the repo is ready for public consumption.

The goals of this repo are 1) to provide a simple-to-read and minimal reference implementation of RLHF and 2) to create rigorous benchmarks that match the learning curves of openai/lm-human-preferences.

This repo is intended for educational / learning purposes; for more advanced use cases, https://github.com/lvwerra/trl would be a great choice.
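For orientation, the core recipe in openai/lm-human-preferences is to optimize a learned reward with PPO while paying a per-token KL penalty toward the frozen initial model. Here is a minimal conceptual sketch of that KL-penalized reward (not this repo's actual code; the function name and tensor shapes are illustrative):

```python
import torch

def kl_penalized_rewards(score, logprobs, ref_logprobs, kl_coef):
    # Per-token KL estimate between the policy and the frozen reference model.
    kl = logprobs - ref_logprobs
    # Every generated token pays a KL penalty; the reward-model score
    # is added on the final token of the response.
    rewards = -kl_coef * kl
    rewards[:, -1] += score
    return rewards

# Illustrative shapes: a batch of 2 responses, 5 generated tokens each.
score = torch.tensor([1.3, -0.2])    # scalar reward-model scores
logprobs = torch.randn(2, 5)         # policy log-probs per token
ref_logprobs = torch.randn(2, 5)     # reference log-probs per token
print(kl_penalized_rewards(score, logprobs, ref_logprobs, kl_coef=0.15))
```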

Get started

The commands below install the dependencies and launch end-to-end training (reward model, then policy). The first launch uses the default label dataset (sentiment); the second points at the descriptiveness labels:

poetry install
poetry shell
accelerate launch \
    --num_processes 8 \
    lm_human_preference_details/train_both_accelerate.py \
    --reward.track --policy.track
accelerate launch \
    --num_processes 8 \
    lm_human_preference_details/train_both_accelerate.py \
    --reward.track \
    --reward.label_dataset=descriptiveness/offline_5k.json \
    --policy.track

You can also run each stage individually. For example, to train the reward model, run:

accelerate launch \
    --num_processes 8 \
    lm_human_preference_details/train_reward_accelerate.py \
    --track
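For intuition, the reward model in this setup is trained on human labels that pick the best of several candidate completions, and the loss is a cross-entropy over the candidates' scalar rewards. Below is a minimal conceptual sketch of that loss, not the repo's actual code; the function name and shapes are illustrative:

```python
import torch
import torch.nn.functional as F

def reward_model_loss(candidate_rewards, preferred):
    # candidate_rewards: (batch, num_candidates) scalar reward per completion
    # preferred:         (batch,) index of the human-preferred completion
    # Treating the rewards as logits, maximize the log-probability
    # assigned to the human-chosen candidate.
    return F.cross_entropy(candidate_rewards, preferred)

# Illustrative usage: 4 candidate completions per prompt, as in the 2019 setup.
candidate_rewards = torch.randn(8, 4)
preferred = torch.randint(0, 4, (8,))
print(reward_model_loss(candidate_rewards, preferred))
```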

To train the policy model, run:

accelerate launch \
    --num_processes 8 \
    lm_human_preference_details/train_policy_accelerate.py \
    --track
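The policy stage is PPO against the learned reward, and the KL penalty coefficient is adapted online so the measured KL stays near a target. A minimal sketch of that adaptive controller, following the 2019 paper (hyperparameter values below are illustrative):

```python
class AdaptiveKLController:
    def __init__(self, init_kl_coef, target, horizon):
        self.kl_coef = init_kl_coef  # current KL penalty coefficient
        self.target = target         # desired KL between policy and reference
        self.horizon = horizon       # how quickly the coefficient adapts

    def update(self, current_kl, n_steps):
        # Clipped proportional error keeps the multiplicative update stable.
        error = max(min(current_kl / self.target - 1.0, 0.2), -0.2)
        self.kl_coef *= 1.0 + error * n_steps / self.horizon

# Illustrative usage: update after a PPO batch of 512 episodes.
ctl = AdaptiveKLController(init_kl_coef=0.15, target=6.0, horizon=10000)
ctl.update(current_kl=8.0, n_steps=512)
print(ctl.kl_coef)
```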

⚠️ NOTE: You can install specific versions of torch or JAX (and the latest huggingface_hub) with the following commands:

poetry run pip install torch==2.0.1
poetry run pip install "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
poetry run pip install git+https://github.com/huggingface/huggingface_hub@main

Current status

For reproduction, I use the same dataset, data-processing pipeline, hyperparameters, and initial model architecture and weights (the pretrained 124M gpt2 model) as openai/lm-human-preferences.

The following charts show the learning curves of various metrics for the sentiment and descriptiveness tasks, each run with 10 random seeds at commit 46725b.

pip install openrlbenchmark==0.2.1a4
python -m openrlbenchmark.rlops_multi_metrics \
    --filters '?we=openrlbenchmark&wpn=lm-human-preferences&xaxis=_step&ceik=task_id&cen=task.value.policy.initial_model&metrics=ppo/objective/score&metrics=ppo/objective/kl&metrics=ppo/objective/entropy&metrics=ppo/objective/score_total&metrics=ppo/objective/kl_coef&metrics=ppo/ppo/loss/total&metrics=ppo/ppo/loss/value&metrics=ppo/ppo/loss/policy&metrics=ppo/ppo/policy/clipfrac&metrics=ppo/ppo/policy/entropy&metrics=ppo/ppo/returns/mean&metrics=ppo/ppo/policy/approxkl&metrics=ppo/ppo/val/clipfrac&metrics=ppo/ppo/val/error&metrics=ppo/ppo/val/mean&metrics=ppo/ppo/returns/var&metrics=ppo/ppo/val/vpred' \
        '124M' \
    --filters '?we=openrlbenchmark&wpn=lm_human_preference_details&xaxis=_step&ceik=rewards.value.label_dataset&cen=exp_name&metrics=objective/scores&metrics=objective/kl&metrics=objective/entropy&metrics=objective/score_total&metrics=objective/kl_coef&metrics=ppo/loss/total&metrics=ppo/loss/value&metrics=ppo/loss/policy_avg&metrics=ppo/policy/clipfrac_avg&metrics=ppo/policy/entropy_avg&metrics=ppo/returns/mean&metrics=ppo/policy/approxkl_avg&metrics=ppo/val/clipfrac_avg&metrics=ppo/val/error&metrics=ppo/val/mean&metrics=ppo/returns/var&metrics=ppo/val/vpred' \
        'train_policy_accelerate?tag=v0.1.0-58-g4f42012&tag=tf_adam&tag=gpt2&cl=tf_adam,gpt2' \
    --env-ids sentiment descriptiveness \
    --env-ids sentiment/offline_5k.json  descriptiveness/offline_5k.json \
    --no-check-empty-runs \
    --pc.ncols 6 \
    --pc.ncols-legend 1 \
    --output-filename static/0compare \
    --scan-history  --report
# note: the `--report` flag (included above) generates a wandb report

The wandb report is available at https://wandb.ai/costa-huang/cleanrl/reports/Regression-Report-train_policy_accelerate--Vmlldzo1MTEwMzQw. Feel free to check out the run logs for sample outputs.

Sentiment

(learning-curve charts for the sentiment task)

Descriptiveness

(learning-curve charts for the descriptiveness task)

Learning curves of openai/lm-human-preferences

The wandb report is here: https://wandb.ai/costa-huang/cleanrl/reports/Regression-Report-124M--Vmlldzo0ODM3NTI5

pip install openrlbenchmark==0.2.1a4
python -m openrlbenchmark.rlops_multi_metrics \
    --filters '?we=openrlbenchmark&wpn=lm-human-preferences&ceik=task_id&cen=task.value.policy.initial_model&metrics=ppo/objective/score&metrics=ppo/objective/kl&metrics=ppo/ppo/loss/policy&metrics=ppo/ppo/val/mean&metrics=ppo/ppo/policy/entropy&metrics=ppo/ppo/policy/approxkl&metrics=ppo/ppo/val/error&metrics=ppo/ppo/loss/total&metrics=ppo/ppo/returns/mean&metrics=train_reward/minibatch/loss&metrics=ppo/ppo/val/vpred&metrics=ppo/ppo/loss/value&metrics=ppo/ppo/val/var_explained&metrics=ppo/objective/score_total&metrics=train_reward/minibatch/error&metrics=ppo/elapsed/fps&metrics=ppo/global_step&metrics=ppo/ppo/policy/clipfrac&metrics=ppo/ppo/val/var&metrics=ppo/ppo/val/clipfrac&metrics=ppo/objective/entropy&metrics=ppo/ppo/returns/var&metrics=ppo/objective/kl_coef&metrics=ppo/elapsed/time' \
        '124M' \
    --env-ids sentiment descriptiveness tldr \
    --check-empty-runs \
    --pc.ncols 5 \
    --pc.ncols-legend 1 \
    --output-filename static/0compare \
    --scan-history --report

Acknowledgement

This work is supported by 🤗 Hugging Face's Big Science A100 cluster.