Closed vwxyzjn closed 1 year ago
Currently, we can already achieve near-perfect learning-curve matching with the PyTorch implementation.
It would be helpful to replicate this success using JAX. Here is a helpful snippet for JAX policy training: https://gist.github.com/vwxyzjn/7005a81ba39deb3bc8043041bd715be1
It's probably easier to start by translating https://github.com/vwxyzjn/lm-human-preference-details/blob/main/lm_human_preference_details/train_reward_accelerate.py to JAX and trying to match the reward model learning curves.
```shell
pip install openrlbenchmark==0.2.1a4
python -m openrlbenchmark.rlops_multi_metrics \
    --filters '?we=openrlbenchmark&wpn=lm-human-preferences&xaxis=_step&ceik=task_id&cen=task.value.policy.initial_model&metrics=train_reward/minibatch/error' \
        '124M' \
    --filters '?we=openrlbenchmark&wpn=lm_human_preference_details&xaxis=_step&ceik=label_dataset&cen=exp_name&metrics=train/loss' \
        'train_reward_accelerate?tag=v0.1.0-58-g4f42012&tag=tf_adam&tag=gpt2&cl=tf_adam,gpt2' \
    --env-ids sentiment descriptiveness \
    --env-ids sentiment/offline_5k.json descriptiveness/offline_5k.json \
    --no-check-empty-runs \
    --pc.ncols 2 \
    --pc.ncols-legend 1 \
    --pc.max_steps 250 \
    --output-filename static/0compare \
    --scan-history
```
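As a rough starting point for the JAX port, here is a minimal sketch of a preference-based reward loss in JAX. The function name `reward_loss` and the pairwise Bradley-Terry form are illustrative assumptions, not the repo's exact loss (the reference implementation trains on labeled comparisons with a cross-entropy over multiple candidate responses):

```python
import jax
import jax.numpy as jnp

def reward_loss(reward_preferred, reward_rejected):
    """Pairwise preference loss: -log sigmoid(r_preferred - r_rejected).

    Assumption: this is a simplified Bradley-Terry stand-in; the actual
    repo uses a softmax cross-entropy over labeled candidate responses.
    """
    return -jnp.mean(jax.nn.log_sigmoid(reward_preferred - reward_rejected))

# Gradient w.r.t. the reward-model outputs; in a real training step this
# would be composed with the model's apply function and an optax optimizer.
grad_fn = jax.grad(reward_loss)
```

A larger reward margin between the preferred and rejected responses should drive the loss toward zero, which is a quick sanity check when porting.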