Thanks for the reviews and suggestions! I'll convert the PR to a draft and reopen it when it's ready for review again.
Hey Costa @vwxyzjn,
Thanks for your feedback again. I updated the PR. By pmapping and scanning the update function (rollout + epochs of PPO steps), GPU utilization is indeed higher than before. However, in wall-clock time it is still slightly slower than the pytorch version. Could you review the code again to see if there is any optimization I'm missing? Thank you!
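As a point of reference for the pattern described above, a minimal sketch of pmapping the whole update and scanning over the PPO epochs could look like the following; every name, shape, and constant is an illustrative stand-in, not the PR's actual rollout or loss code.

```python
import jax
import jax.numpy as jnp

NUM_PPO_EPOCHS = 4  # illustrative

def ppo_epoch(params, _):
    # Stand-in for one epoch of PPO minibatch updates on this device's shard.
    grads = jax.grad(lambda p: jnp.sum(p ** 2))(params)
    return params - 1e-6 * grads, jnp.sum(params ** 2)

def train_update(params, input_ids):
    # The rollout (generation + reward scoring) would go here, inside the
    # pmapped function, so rollout and PPO epochs run in one device program.
    params, losses = jax.lax.scan(ppo_epoch, params, xs=None, length=NUM_PPO_EPOCHS)
    return params, losses

p_train_update = jax.pmap(train_update, axis_name="batch")

n_devices = jax.local_device_count()
params = jnp.zeros((n_devices, 8))                          # toy replicated parameters
input_ids = jnp.zeros((n_devices, 2, 16), dtype=jnp.int32)  # [devices, batch/device, seq_len]
params, losses = p_train_update(params, input_ids)          # losses: [devices, NUM_PPO_EPOCHS]
```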
One issue I've encountered is making `right_padding_to_left_padding` jittable. The naive approaches I tried run into dynamic-shape problems. To make it jittable, I resorted to sorting the mask, although this may not be very efficient:
```python
import jax
import jax.numpy as jnp

def right_padding_to_left_padding(tokens, pad_id):
    def pad_row(row):
        mask = 1 - (row == pad_id)  # 1 if not pad_id, 0 if pad_id
        # jnp.argsort is stable by default, so pads (mask == 0) move to the
        # front while the relative order of non-pad tokens is preserved.
        return row[jnp.argsort(mask)]

    return jax.vmap(pad_row)(tokens)
```
Do you see a simpler way to implement a jittable `right_padding_to_left_padding`?
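One possible alternative, sketched here under the assumption that all pad tokens sit on the right of each row, is to roll each row by its pad count; this relies on `jnp.roll` accepting a traced shift, which recent JAX versions support.

```python
import jax
import jax.numpy as jnp

def right_padding_to_left_padding_roll(tokens, pad_id):
    """Hypothetical alternative: roll each row right by its number of pads,
    so the trailing pads wrap around to the front. Assumes right padding."""
    def pad_row(row):
        num_pads = jnp.sum(row == pad_id)
        return jnp.roll(row, num_pads)
    return jax.vmap(pad_row)(tokens)

# Tiny check with made-up token ids (pad_id = 0):
tokens = jnp.array([[5, 6, 0, 0], [7, 0, 0, 0]])
print(jax.jit(right_padding_to_left_padding_roll)(tokens, 0))
# [[0 0 5 6]
#  [0 0 0 7]]
```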
I'll provide a test case for GAE soon. (Edit: just added a test case for GAE :-))
For now, the main thing I'm puzzled by is the implementation's speed... please share any suggestions you might have!
Tianlin
Thanks for the suggestions! I made the edits accordingly and updated the PR.
The speed, however, is still only comparable to the pytorch version. I'll try to dig deeper...
There is also profiling :) https://jax.readthedocs.io/en/latest/profiling.html
Let me also take a look :)
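For reference, a minimal way to capture such a trace (the log directory is arbitrary and the matmul is just a stand-in workload):

```python
import jax
import jax.numpy as jnp

# Capture a trace of a small workload; open the resulting directory in
# TensorBoard's profiler tab as described in the linked docs.
with jax.profiler.trace("/tmp/jax-trace"):
    x = jnp.ones((4096, 4096))
    (x @ x).block_until_ready()
```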
OK, I took a quick look. Some suggestions for diagnosing. Looking at the current GPU utilization, there is a gap somewhere.
I tried to isolate the cause by doing something like https://gist.github.com/vwxyzjn/2687ecad96f3b32df539bbbe223e7f42. Basically, I commented out the metrics calculation and did not shard the data; the GPU utilization then stays near 100%. So the issue is with either 1) the metrics calculation, 2) data loading, or 3) `common_utils.shard`.
```python
data = next(iter_dataloader)
input_ids = common_utils.shard(data["input_ids"].numpy())
for update in range(1, args.ppo.num_updates + 1):
    global_step += args.ppo.batch_size
    policy_state, rollout_stats, rl_stats, samples_to_print = p_train_update(
        policy_state=policy_state,
        input_ids=input_ids,
        rl_stats=rl_stats,
        kl_ctl_value=np.array([kl_ctl.value] * input_ids.shape[0]),
    )
```
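Regarding suspect 3), `common_utils.shard` is essentially a reshape of the leading batch axis into a per-device axis; a quick way to see its effect (assuming the batch size is divisible by the number of local devices):

```python
import numpy as np
import jax
from flax.training import common_utils

# shard reshapes [batch, ...] into [local_device_count, batch // local_device_count, ...]
batch = np.zeros((64, 128), dtype=np.int32)
sharded = common_utils.shard(batch)
print(jax.local_device_count(), sharded.shape)
```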
Thanks! That's very informative. In my 2-GPU setting, though, the utilization of both GPUs is very close to 100% even with the current implementation.
Added `numpy_collate` to avoid converting pytorch tensors to jnp.ndarrays in data loading. The effect is small on my end, though.
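For context, a `numpy_collate` along the lines of the standard JAX data-loading recipe might look like this (a minimal sketch; the PR's actual helper may differ):

```python
import numpy as np

def numpy_collate(batch):
    """Collate samples into numpy arrays instead of torch tensors, so the
    data loader output can be sharded and fed to jax without an extra
    conversion step."""
    sample = batch[0]
    if isinstance(sample, dict):
        return {key: numpy_collate([b[key] for b in batch]) for key in sample}
    return np.stack([np.asarray(b) for b in batch])

# Usage sketch: DataLoader(dataset, batch_size=..., collate_fn=numpy_collate)
```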
If it's already at 100%, that's pretty good.
> The speed, however, is still only comparable to the pytorch version. I'll try to dig deeper...
Which experiments are you comparing, mine or yours? To do a fair comparison, we need to use the same hardware :)
> Which experiments are you comparing, mine or yours? To do a fair comparison, we need to use the same hardware :)
Yes, I ran both the jax and pytorch implementations on my 2-GPU desktop with the same settings (`ppo.gradient-accumulation-steps=4`, `ppo.lr=1e-6`, and otherwise the default config).
So there seem to be two separate problems. (1) When running on 2 GPUs, GPU utilization is high, yet the jax implementation still does not outperform pytorch in speed. (2) When running on 8 GPUs, GPU utilization is low in the jax implementation.

For (1), I think the sorting involved in the jittable `right_padding_to_left_padding` can indeed cause a slowdown; I'll try to do a more proper speed comparison soon. I don't have an idea for (2) yet.
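For that speed comparison, one way to isolate the padding function's cost is a small micro-benchmark along these lines; the batch size, sequence length, and pad id are made-up placeholders, and it reuses the `right_padding_to_left_padding` defined earlier in the thread.

```python
import time
import jax
import jax.numpy as jnp

pad_id = 50257
# 512 rows, 600 tokens each, with the last 100 positions set to pad_id.
tokens = jnp.full((512, 600), pad_id, dtype=jnp.int32).at[:, :500].set(1)
padded_fn = jax.jit(lambda t: right_padding_to_left_padding(t, pad_id))

padded_fn(tokens).block_until_ready()  # warm up / compile once
start = time.perf_counter()
for _ in range(100):
    padded_fn(tokens).block_until_ready()
print(f"avg per call: {(time.perf_counter() - start) / 100 * 1e3:.3f} ms")
```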
Maybe another related thing: there is no flash attention in jax, whereas pytorch does have it.
Hi Costa @vwxyzjn. I am wondering if we can merge the PR. Although the speed of the jax policy learning is only comparable with pytorch, the metrics match closely. In particular, the `policy_avg` metrics now go to zero in both implementations (figure below); I think the earlier jax version didn't converge to zero because of a glitch in the learning-rate annealing.

But of course, please feel free to keep the PR open if you think that's better, and feel free to suggest further edits before the merge :-)
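For reference, a linearly annealed learning rate with optax looks roughly like the sketch below; all constants are illustrative, and the key point is that the schedule is indexed by optimizer steps (gradient applications), not by outer PPO updates.

```python
import optax

num_updates = 1000                   # outer PPO updates (illustrative)
ppo_epochs, num_minibatches = 4, 8   # inner loop sizes (illustrative)
total_steps = num_updates * ppo_epochs * num_minibatches

# The schedule is consumed per optimizer step, so transition_steps should
# count gradient applications rather than PPO updates.
lr_schedule = optax.linear_schedule(init_value=1e-6, end_value=0.0,
                                    transition_steps=total_steps)
optimizer = optax.adamw(learning_rate=lr_schedule)
```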
Awesome, thanks! Just a note for further benchmarking: as discussed, all metrics in jax policy learning are averaged on the batch level (as opposed to the microbatch level for a subset of metrics in pytorch policy learning). For this reason, some metrics from jax may be smoother than their current pytorch counterparts. Below is an example.
Hi Costa,
This PR contains the jax implementation of policy learning. You can find the comparison between jax and pytorch runs here. Some metrics are in the screenshots below.
Some comments:
- Due to memory constraints on my 2-GPU desktop, for now I only ran jax and pytorch policy learning with a specific configuration: `ppo.gradient-accumulation-steps=4` and `ppo.lr=1e-6`. I opted for a lower learning rate (1e-6) than the default (1e-5) to counteract the instability caused by the small-batch effect (the batch size is tied to the number of GPUs). Furthermore, I only ran both implementations with 1 random seed: since each run takes about 9 hours, running with multiple seeds is difficult for me 😅. If you have access to 8 spare GPUs, running with the standard configs and multiple seeds would be immensely helpful! I hope both versions are still comparable in the 8-GPU setting.
- The jax policy learning is slightly slower than the pytorch version. To make the jax version faster, I think we need to `pmap` a larger chunk of the policy-rollout code, but this will introduce more abstractions. If you have any tips for accelerating the jax implementation, please let me know!
- Regarding the metrics, I follow the existing ones in `policy_learning_accelerate.py`. The current metrics are comprehensive and helpful, but there are repetitions that I think might complicate interpretation. Specifically, we have metrics that are aggregated at the batch level, at the microbatch level, or at both levels. I've summarized the PPO-related metrics below, excluding those related to the rollout since they don't exhibit the batch-vs-microbatch confusion.

| Microbatch-level metric | Batch-level counterpart |
| --- | --- |
| ppo/policy/approxkl | ppo/policy/approxkl_avg |
| ppo/policy/clipfrac | ppo/policy/clipfrac_avg |
| ppo/policy/entropy | ppo/policy/entropy_avg |
| ppo/loss/policy | ppo/loss/policy_avg |
| ppo/loss/value | ppo/loss/value_avg |
| ppo/val/clipfrac | ppo/val/clipfrac_avg |
| ppo/val/error | N/A |
| ppo/val/ratio | N/A |
| ppo/val/ratio_hist | N/A |
| ppo/loss/total | N/A |
| ppo/val/vpred | N/A |
Perhaps we can simplify these metrics by keeping only the batch-level statistics? The microbatch-level statistics tend to be noisy and have high variance, making cross-run comparisons less reliable. Moreover, their batch-level counterparts already contain the same information (up to an average). For the microbatch-level statistics that are currently not recorded at the batch level (the N/As in the table), we can add batch-level versions.
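To make the batch-vs-microbatch distinction concrete, here is a toy illustration with made-up values, assuming 4 gradient-accumulation microbatches per batch:

```python
import numpy as np

# Made-up approxkl values for the 4 microbatches of one PPO batch.
approxkl_microbatch = np.array([0.012, 0.015, 0.011, 0.014])

# Microbatch-level logging (e.g. ppo/policy/approxkl) emits 4 noisy points per
# update, while the batch-level metric (e.g. ppo/policy/approxkl_avg) logs
# their single mean, which carries the same information up to the average.
print(approxkl_microbatch)         # 4 values per update
print(approxkl_microbatch.mean())  # 1 value per update: 0.013
```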
Additionally, I'm unsure about the meaning of two metrics, `ppo/val/ratio_var` and `ppo/val/advantage_var`. It seems they describe the variance of the ratio or advantage across devices, not across samples? Not sure if that's intentional. Indeed, it looks like the `var()` is evaluated after an `accelerator.gather` of per-device means (a tiny numerical illustration of the difference is sketched after this message). In practice, both metrics have very small values, one less than 1e-5 and the other less than 1e-14. I didn't incorporate `ppo/val/ratio_var` and `ppo/val/advantage_var` in the jax implementation, but I can add them back if they are needed.

Happy to hear your feedback. Thank you! 🤗
Tianlin
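To illustrate the per-device-means point above, here is a small numerical sketch; all numbers are made up.

```python
import numpy as np

# 8 devices x 64 samples of a hypothetical PPO ratio.
rng = np.random.default_rng(0)
ratio = rng.normal(loc=1.0, scale=0.1, size=(8, 64))

per_device_means = ratio.mean(axis=1)  # what gathering per-device means yields
print(per_device_means.var())          # variance across device means: roughly 0.1**2 / 64
print(ratio.var())                     # variance across all samples: roughly 0.1**2
```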