vwxyzjn / lm-human-preference-details

RLHF implementation details of OAI's 2019 codebase
MIT License

Creating a jax implementation #5

Closed vwxyzjn closed 1 year ago

vwxyzjn commented 1 year ago

Currently, we can already get near-perfect learning-curve matching with the PyTorch implementation.

image

It would be helpful to replicate this result in JAX. Here is a useful snippet for JAX policy training: https://gist.github.com/vwxyzjn/7005a81ba39deb3bc8043041bd715be1

It's probably easier to start by translating https://github.com/vwxyzjn/lm-human-preference-details/blob/main/lm_human_preference_details/train_reward_accelerate.py to JAX and trying to match the reward model learning curves.

pip install openrlbenchmark==0.2.1a4
python -m openrlbenchmark.rlops_multi_metrics \
    --filters '?we=openrlbenchmark&wpn=lm-human-preferences&xaxis=_step&ceik=task_id&cen=task.value.policy.initial_model&metrics=train_reward/minibatch/error' \
        '124M' \
    --filters '?we=openrlbenchmark&wpn=lm_human_preference_details&xaxis=_step&ceik=label_dataset&cen=exp_name&metrics=train/loss' \
        'train_reward_accelerate?tag=v0.1.0-58-g4f42012&tag=tf_adam&tag=gpt2&cl=tf_adam,gpt2' \
    --env-ids sentiment descriptiveness \
    --env-ids sentiment/offline_5k.json descriptiveness/offline_5k.json \
    --no-check-empty-runs \
    --pc.ncols 2 \
    --pc.ncols-legend 1 \
    --pc.max_steps 250 \
    --output-filename static/0compare \
    --scan-history
image
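
As a starting point for the reward model translation mentioned above, here is a rough JAX sketch of a reward-model training step. It assumes the loss is a softmax cross-entropy over each query's candidate responses, with the label being the index of the response the labeler preferred. Everything else is a stand-in chosen for illustration and not taken from `train_reward_accelerate.py`: the linear reward head on random `features` replaces the GPT-2 backbone, plain SGD replaces Adam, and the names `reward_loss`, `train_step`, and `LEARNING_RATE` are hypothetical.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # assumption: plain SGD step size, picked only for this toy example


def reward_loss(params, features, labels):
    """Cross-entropy over candidate responses, given the index the labeler preferred."""
    # features: (batch, num_responses, hidden), hypothetical per-response features
    # standing in for the language model's output; labels: (batch,) integer indices.
    rewards = features @ params["w"]  # (batch, num_responses): one scalar reward per response
    log_probs = jax.nn.log_softmax(rewards, axis=-1)
    chosen = jnp.take_along_axis(log_probs, labels[:, None], axis=-1)[:, 0]
    return -jnp.mean(chosen)


@jax.jit
def train_step(params, features, labels):
    # Compute the loss and its gradients, then apply one SGD update (not Adam).
    loss, grads = jax.value_and_grad(reward_loss)(params, features, labels)
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


# Synthetic data: 64 queries, 4 candidate responses each, 8-dimensional features.
key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
features = jax.random.normal(key_x, (64, 4, 8))
true_w = jax.random.normal(key_w, (8,))
labels = jnp.argmax(features @ true_w, axis=-1)  # "preferred" response per query

params = {"w": jnp.zeros(8)}
initial_loss = reward_loss(params, features, labels)
for _ in range(100):
    params, loss = train_step(params, features, labels)
final_loss = reward_loss(params, features, labels)
print(float(initial_loss), float(final_loss))
```

With zero-initialized weights all four rewards are equal, so the initial loss is log(4); a real port would swap the linear head for a Flax GPT-2 model with a scalar head and then compare the resulting loss curve against the PyTorch runs.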