Thanks for your comments! I'll work on them.
Thanks for the comments! I ran the newly updated pytorch and jax versions with and without tf-style adam, each with 5 random seeds. Below are the results. The results from the jax and pytorch versions match nicely in both cases (with tf-style adam and without)!
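For reference, here is a minimal sketch of what "tf-style adam" refers to, i.e. how TF1's `AdamOptimizer` update differs from `torch.optim.Adam`. This is an illustration of the update rule only (the hyperparameter defaults are placeholders, not the values used in the scripts):

```python
import jax.numpy as jnp


def tf_style_adam_update(param, grad, m, v, t, lr=1e-5, b1=0.9, b2=0.999, eps=1e-8):
    """One TF1-style Adam step."""
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad**2
    # TF1 folds the bias correction into the step size...
    lr_t = lr * jnp.sqrt(1 - b2**t) / (1 - b1**t)
    # ...and adds eps to the *uncorrected* sqrt(v). torch.optim.Adam instead
    # adds eps to sqrt(v_hat) after bias-correcting v, which effectively
    # rescales eps by sqrt(1 - b2**t); the two differ most early in training.
    return param - lr_t * m / (jnp.sqrt(v) + eps), m, v
```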
Here are the full wandb logs. You may want to group the runs by `exp_name` and `use_tensorflow_adam`.
A few more comments:
- I edited `train_reward_jax.py` and the pytorch version `train_reward_accelerate.py` to align the two versions more closely. Feel free to modify them directly or suggest further edits. There are more things that could be done, but I figured it is better to ask for your opinion first. One thing we could do is to incorporate `prepare_left_padded_query_responses_with_labels` from the jax version into the pytorch version (a rough sketch of the left-padding idea is below). This would bring the two versions closer at the expense of introducing more abstractions in the pytorch version, so I didn't do it for now.
- I ran `pre-commit run --all-files` as you suggested, so there are changes in files other than `train_reward_jax.py` and `train_reward_accelerate.py`. Pre-commit looks very useful, and I didn't know about it before!
- For now I used `args.world_size = len(jax.devices())`, and subsequently `args.batch_size = int(args.local_batch_size * args.world_size)` to adjust the batch size based on the number of GPUs. This makes the pytorch and jax versions equivalent in practice. That said, I agree that `world_size` can be a bit confusing here: as you pointed out, for jax, if we only have one machine, then `world_size` should technically be 1, whereas for now we have assigned `world_size` to the number of GPUs the machine has. If you have any suggestions for a name change, do let me know! Thanks!
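On the left-padding point in the first bullet, here's a minimal sketch of the idea; the actual `prepare_left_padded_query_responses_with_labels` also carries the labels, so the hypothetical `left_pad_and_concat` below is a simplification:

```python
import numpy as np


def left_pad_and_concat(queries, responses, pad_token_id):
    """Left-pad queries to a common length, then append responses,
    so that every response starts at the same column."""
    max_q = max(len(q) for q in queries)
    rows = [
        [pad_token_id] * (max_q - len(q)) + list(q) + list(r)
        for q, r in zip(queries, responses)
    ]
    return np.array(rows)  # (batch, max_q + response_len)


# Example: both responses start at column max_q == 3.
tokens = left_pad_and_concat([[5, 6], [7, 8, 9]], [[1, 2], [3, 4]], pad_token_id=0)
# -> [[0, 5, 6, 1, 2],
#     [7, 8, 9, 3, 4]]
```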
> This would bring the two versions closer at the expense of introducing more abstractions in the pytorch version, so I didn't do it for now.

Sounds good. Let's keep the current version as is for now :)
> As mentioned in a previous comment, for now I used `args.world_size = len(jax.devices())`

I changed it to `args.batch_size = int(args.local_batch_size * len(jax.local_devices()) * args.world_size)`.
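To spell out that bookkeeping, here's a minimal sketch (the `Args` container and its default value are placeholders, and I'm assuming `world_size` is meant to count processes, i.e. `jax.process_count()`):

```python
from dataclasses import dataclass

import jax


@dataclass
class Args:
    local_batch_size: int = 64  # per-device batch size (placeholder value)
    world_size: int = 1
    batch_size: int = 0


args = Args()
# On a single machine, jax.process_count() == 1 and len(jax.local_devices())
# equals the number of local GPUs/TPUs, so this reduces to the earlier
# local_batch_size * len(jax.devices()).
args.world_size = jax.process_count()
args.batch_size = int(args.local_batch_size * len(jax.local_devices()) * args.world_size)
```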
I'm running some benchmark experiments now.
> Here are the full wandb logs.

Awesome! Btw there is actually a more automated way to do this:
```bash
export WANDB_ENTITY=openrlbenchmark

# sentiment
WANDB_TAGS="tf_adam,gpt2" python benchmark/benchmark.py \
    --command "python lm_human_preference_details/train_reward_jax.py --save_path '' --track --wandb_project_name=lm_human_preference_details" \
    --num-seeds 10 \
    --start-seed 1 \
    --workers 10 \
    --slurm-gpus-per-task 8 \
    --slurm-ntasks 1 \
    --slurm-total-cpus 64 \
    --slurm-template-path benchmark/trl.slurm_template

# descriptiveness
WANDB_TAGS="tf_adam,gpt2" python benchmark/benchmark.py \
    --command "python lm_human_preference_details/train_reward_jax.py --save_path '' --label_dataset=descriptiveness/offline_5k.json --track --wandb_project_name=lm_human_preference_details" \
    --num-seeds 10 \
    --start-seed 1 \
    --workers 10 \
    --slurm-gpus-per-task 8 \
    --slurm-ntasks 1 \
    --slurm-total-cpus 64 \
    --slurm-template-path benchmark/trl.slurm_template
```
The plots in the report below were then generated with:

```bash
python -m openrlbenchmark.rlops_multi_metrics \
    --filters '?we=openrlbenchmark&wpn=lm-human-preferences&xaxis=_step&ceik=task_id&cen=task.value.policy.initial_model&metrics=train_reward/minibatch/error' \
        '124M' \
    --filters '?we=openrlbenchmark&wpn=lm_human_preference_details&xaxis=_step&ceik=label_dataset&cen=exp_name&metrics=train/loss' \
        'train_reward_accelerate?tag=v0.1.0-76-gfbf1f0c&tag=tf_adam&tag=gpt2&cl=tf_adam,gpt2' \
        'train_reward_jax?tag=v0.1.0-75-g8cc6065&tag=tf_adam&tag=gpt2&cl=jax,tf_adam,gpt2' \
    --env-ids sentiment \
    --env-ids sentiment/offline_5k.json \
    --no-check-empty-runs \
    --pc.ncols 2 \
    --pc.ncols-legend 1 \
    --pc.max_steps 250 \
    --output-filename static/0compare \
    --scan-history --report
```
Regression report: https://wandb.ai/costa-huang/cleanrl/reports/Regression-Report-train_reward_jax--Vmlldzo1MjQ0ODIw
Colab: https://colab.research.google.com/drive/1ZsZ2f440kBq74ZioJLInjwLlGb9QYVTH?usp=sharing
Gonna go out for errands now. `train_policy_accelerate.py` seems broken; I will try to fix it after I come back. Meanwhile, the results look amazing!
You can probably start working on `train_policy_jax.py` :) https://gist.github.com/vwxyzjn/7005a81ba39deb3bc8043041bd715be1 should be very helpful.

We can merge the PR with or without `train_policy_jax.py`. Either way works for me.
Awesome! Just merged. I'll proceed to work on the policy learning part 😀
Hey Costa,
Thanks for your help with the jax dependency! I can now successfully run jax & pytorch reward learning within the same poetry environment.
Here are my current reproduction results (pytorch vs jax) for reward learning: https://wandb.ai/tliu/lm_human_preference_details
A few comments:

- You may want to group the runs by `exp_name`.
- I also ran with `use_tensorflow_adam=False` to check this.
- `pmap` the reward model normalization step. For now, the normalization happens in one process. To do that, I think we need to replace `pretrained_model.generate` with a stateless function (a rough sketch is in the PS below). I'll look into this, as it should be helpful for policy learning as well, where we need to sample responses.

Let me know if you have any suggestions or feedback! If you have 8 spare GPUs, it would be great if you could run the jax version on them and see if the results check out with the previous pytorch results. Thanks!
Tianlin
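PS: On the `pmap` point above, here is a rough sketch of the kind of stateless sampling step I have in mind as a replacement for `pretrained_model.generate`. Everything here is a stand-in (in particular `apply_fn` is a dummy forward pass, not the actual Flax model), so treat it as an illustration only:

```python
import functools

import jax
import jax.numpy as jnp

VOCAB = 16  # toy vocab size; the real model would use the tokenizer's


def apply_fn(params, tokens):
    # Dummy stand-in for the Flax module's apply:
    # (batch, seq) int32 tokens -> (batch, seq, vocab) logits.
    return params["emb"][tokens]


def sample_step(params, queries, rng, max_new_tokens=8):
    """Stateless sampling with fixed shapes, so it can be jit/pmap-ed
    (unlike the stateful `pretrained_model.generate`)."""

    def body(carry, _):
        tokens, rng = carry
        logits = apply_fn(params, tokens)[:, -1, :]  # next-token logits
        rng, key = jax.random.split(rng)
        next_tok = jax.random.categorical(key, logits)
        # Shift left so the carry keeps a fixed shape under lax.scan.
        tokens = jnp.concatenate([tokens[:, 1:], next_tok[:, None]], axis=-1)
        return (tokens, rng), next_tok

    _, sampled = jax.lax.scan(body, (queries, rng), None, length=max_new_tokens)
    return sampled.T  # (batch, max_new_tokens)


# Toy usage; a pure function of (params, queries, rng) can then be pmap-ed:
params = {"emb": jnp.zeros((VOCAB, VOCAB))}
queries = jnp.zeros((2, 4), dtype=jnp.int32)
responses = sample_step(params, queries, jax.random.PRNGKey(0))
p_sample = jax.pmap(functools.partial(sample_step, max_new_tokens=8))
```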