vwxyzjn / lm-human-preference-details

RLHF implementation details of OAI's 2019 codebase
MIT License
145 stars · 7 forks

Jax reward learning #13

Closed · liutianlin0121 closed this issue 1 year ago

liutianlin0121 commented 1 year ago

Hey Costa,

Thanks for your help on the jax dependency---I can now successfully run jax & pytorch reward learning within the same poetry environment!

Here are my current reproduction results (pytorch vs jax) for reward learning: https://wandb.ai/tliu/lm_human_preference_details

[Screenshot (2023-08-25): reward-learning reproduction results, PyTorch vs JAX]
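(Aside, for context on what these curves measure: in the OAI 2019 setup, the reward model is trained with a softmax cross-entropy over the rewards of the candidate completions, with the labeler's pick as the target. Below is a minimal JAX sketch of that loss, using illustrative shapes and names rather than the repo's actual identifiers.)

import jax
import jax.numpy as jnp

# Sketch of the reward-learning loss being reproduced here: softmax
# cross-entropy over per-candidate rewards, with the human-preferred
# candidate as the target. Shapes and names are illustrative only.
#   rewards: [batch, num_candidates]  scalar reward per candidate completion
#   labels:  [batch]                  index of the labeler-preferred candidate
def reward_loss(rewards, labels):
    log_probs = jax.nn.log_softmax(rewards, axis=-1)
    chosen = jnp.take_along_axis(log_probs, labels[:, None], axis=-1)
    return -jnp.mean(chosen)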

A few comments:

Let me know if you have any suggestions or feedback! If you have 8 spare GPUs, it would be great if you could run the JAX version on them and check whether the results match the earlier PyTorch runs. Thanks!

Tianlin

liutianlin0121 commented 1 year ago

Thanks for your comments! I'll work on them.

liutianlin0121 commented 1 year ago

Thanks for the comments! I ran the newly updated PyTorch and JAX versions, with and without tf-style Adam, each with 5 random seeds. Below are the results: the JAX and PyTorch versions match nicely in both cases (with tf-style Adam and without)!

[Screenshots (2023-08-27): reward-learning results with and without tf-style Adam, PyTorch vs JAX]
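(For readers unfamiliar with the term: as I understand it, "tf-style Adam" here means the TensorFlow 1.x Adam update used by OAI's original codebase, which folds the bias correction into the step size so that epsilon interacts with the uncorrected second moment. A rough JAX sketch of one update step, not the repo's actual implementation:)

import jax.numpy as jnp

# One TF1-style Adam step (sketch). PyTorch/optax Adam instead divides the
# bias-corrected m_hat by sqrt(v_hat) + eps, which treats epsilon slightly
# differently early in training.
def tf_style_adam_step(param, grad, m, v, t, lr=1e-4, b1=0.9, b2=0.999, eps=1e-8):
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad ** 2
    lr_t = lr * jnp.sqrt(1 - b2 ** t) / (1 - b1 ** t)  # bias correction folded into lr
    param = param - lr_t * m / (jnp.sqrt(v) + eps)
    return param, m, v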

Here are the full wandb logs. You may want to group the runs by exp_name and use_tensorflow_adam.

A few more comments:

Thanks!

vwxyzjn commented 1 year ago

> This will bring the two versions closer at the expense of introducing more abstractions in the PyTorch version, so I didn't do it for now.

Sounds good. Let's keep the current version as is for now :)

> As mentioned in a previous comment, for now I used args.world_size = len(jax.devices())

I changed it to

args.batch_size = int(args.local_batch_size * len(jax.local_devices()) * args.world_size)
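(A quick note on the JAX topology queries involved here, plus an illustrative sizing sketch; local_batch_size below is a made-up value, not the repo's default:)

import jax

# jax.local_devices() -> devices attached to the current process (host)
# jax.devices()       -> all devices across every process in the job
# jax.process_count() -> number of processes (typically one per host)
local_batch_size = 4  # hypothetical per-device batch size, for illustration
per_host_batch = local_batch_size * len(jax.local_devices())
global_batch = per_host_batch * jax.process_count()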

I'm running some benchmark experiments now.

> Here are the full wandb logs.

Awesome! Btw there is actually a more automated way to do this.

export WANDB_ENTITY=openrlbenchmark
# sentiment
WANDB_TAGS="tf_adam,gpt2" python benchmark/benchmark.py \
    --command "python lm_human_preference_details/train_reward_jax.py --save_path '' --track --wandb_project_name=lm_human_preference_details" \
    --num-seeds 10 \
    --start-seed 1 \
    --workers 10 \
    --slurm-gpus-per-task 8 \
    --slurm-ntasks 1 \
    --slurm-total-cpus 64 \
    --slurm-template-path benchmark/trl.slurm_template
# descriptiveness
WANDB_TAGS="tf_adam,gpt2" python benchmark/benchmark.py \
    --command "python lm_human_preference_details/train_reward_jax.py --save_path '' --label_dataset=descriptiveness/offline_5k.json --track --wandb_project_name=lm_human_preference_details"  \
    --num-seeds 10 \
    --start-seed 1 \
    --workers 10 \
    --slurm-gpus-per-task 8 \
    --slurm-ntasks 1 \
    --slurm-total-cpus 64 \
    --slurm-template-path benchmark/trl.slurm_template
python -m openrlbenchmark.rlops_multi_metrics \
    --filters '?we=openrlbenchmark&wpn=lm-human-preferences&xaxis=_step&ceik=task_id&cen=task.value.policy.initial_model&metrics=train_reward/minibatch/error' \
        '124M' \
    --filters '?we=openrlbenchmark&wpn=lm_human_preference_details&xaxis=_step&ceik=label_dataset&cen=exp_name&metrics=train/loss' \
        'train_reward_accelerate?tag=v0.1.0-76-gfbf1f0c&tag=tf_adam&tag=gpt2&cl=tf_adam,gpt2' \
        'train_reward_jax?tag=v0.1.0-75-g8cc6065&tag=tf_adam&tag=gpt2&cl=jax,tf_adam,gpt2' \
    --env-ids sentiment  \
    --env-ids sentiment/offline_5k.json  \
    --no-check-empty-runs \
    --pc.ncols 2 \
    --pc.ncols-legend 1 \
    --pc.max_steps 250 \
    --output-filename static/0compare \
    --scan-history --report

https://wandb.ai/costa-huang/cleanrl/reports/Regression-Report-train_reward_jax--Vmlldzo1MjQ0ODIw


https://colab.research.google.com/drive/1ZsZ2f440kBq74ZioJLInjwLlGb9QYVTH?usp=sharing

vwxyzjn commented 1 year ago

Gonna go out for errands now. train_policy_accelerate.py seems broken; I will try to fix it when I come back. Meanwhile, the results look amazing!

You can probably start working on train_policy_jax.py :) https://gist.github.com/vwxyzjn/7005a81ba39deb3bc8043041bd715be1 should be very helpful.

We can merge the PR with or without train_policy_jax.py. Either way works for me.

liutianlin0121 commented 1 year ago

Awesome! Just merged. I'll proceed to work on the policy learning part 😀