vwxyzjn / lm-human-preference-details

RLHF implementation details of OAI's 2019 codebase
MIT License

Jax policy learning #18

Closed · liutianlin0121 closed this 1 year ago

liutianlin0121 commented 1 year ago

Hi Costa,

This PR contains the jax implementation of policy learning. You can find the comparison between jax and pytorch runs here. Some metrics are in the screenshots below.

(Screenshots from 2023-09-05: training metrics from the jax and pytorch runs.)

Some comments:

Regarding the metrics, I follow the existing ones in policy_learning_accelerate.py. The current metrics are comprehensive and helpful, but there is some repetition that I think complicates interpretation: some metrics are aggregated at the microbatch level, some at the batch level, and some at both. I've summarized the PPO-related metrics below, excluding the rollout-related ones, since they don't have this batch-vs-microbatch duplication.

| Microbatch-level metric | Batch-level counterpart |
| --- | --- |
| ppo/policy/approxkl | ppo/policy/approxkl_avg |
| ppo/policy/clipfrac | ppo/policy/clipfrac_avg |
| ppo/policy/entropy | ppo/policy/entropy_avg |
| ppo/loss/policy | ppo/loss/policy_avg |
| ppo/loss/value | ppo/loss/value_avg |
| ppo/val/clipfrac | ppo/val/clipfrac_avg |
| ppo/val/error | N/A |
| ppo/val/ratio | N/A |
| ppo/val/ratio_hist | N/A |
| ppo/loss/total | N/A |
| ppo/val/vpred | N/A |

Perhaps we can simplify these metrics by keeping only the batch-level statistics? The microbatch-level statistics tend to be noisy and have high variance, which makes cross-run comparisons less reliable, and their batch-level counterparts already contain the same information (up to an average). For the microbatch-level statistics that are currently not recorded at the batch level (the N/As in the table), we can add batch-level versions.
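As a toy illustration of the "up to an average" point (made-up numbers, not code from the repo), the batch-level statistic is just the mean of the per-microbatch values collected during one update:

```python
import numpy as np

# hypothetical per-microbatch approxkl values recorded during a single PPO update
approxkl_microbatches = np.array([0.012, 0.018, 0.009, 0.015])

# the batch-level counterpart (e.g. ppo/policy/approxkl_avg) is their mean,
# so logging this one scalar per update preserves the aggregate information
approxkl_avg = approxkl_microbatches.mean()
print(approxkl_avg)  # 0.0135
```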

Additionally, I'm unsure about the meaning of two metrics, ppo/val/ratio_var and ppo/val/advantage_var. It seems they describe the variance of the ratio or advantage across devices, not across samples? I'm not sure if that's intentional. Indeed, it looks like var() is evaluated after an accelerator.gather of the per-device means, like:

```python
writer.add_scalar(
    "ppo/val/advantage_var",
    # gather() collects one per-device mean, so var() is taken across devices, not across samples
    accelerator.gather(advantages.mean()).var().item(),
    update,
)
```

In practice, both metrics have very small values, with one less than 1e-5 and the other less than 1e-14. I didn't incorporate ppo/val/ratio_var and ppo/val/advantage_var in the jax implementation, but I can add them back if they are needed.
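To illustrate the difference between the two quantities (a toy sketch with made-up per-device numbers, not code from the repo):

```python
import jax.numpy as jnp

# hypothetical per-device advantages, shape (num_devices, samples_per_device)
advantages = jnp.array([[0.9, 1.1, 1.0, 1.2],
                        [1.0, 0.8, 1.2, 1.0]])

per_device_means = advantages.mean(axis=1)         # what gather(advantages.mean()) collects
var_across_devices = per_device_means.var()        # current metric: variance of the per-device means (tiny)
var_across_samples = advantages.reshape(-1).var()  # variance over all samples (typically much larger)
print(var_across_devices, var_across_samples)
```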

Happy to hear your feedback. Thank you! 🤗

Tianlin

liutianlin0121 commented 1 year ago

Thanks for the reviews and suggestions! I'll convert the PR to a draft and reopen it when it's ready for review again.

liutianlin0121 commented 1 year ago

Hey Costa @vwxyzjn ,

Thanks again for your feedback. I updated the PR. By pmapping and scanning the update function (rollout + epochs of PPO steps), the GPU utilization is indeed higher than before. However, the wall-clock speed remains slightly slower than the pytorch version. Could you review the code again to see if there are further optimization opportunities? Thank you!
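Roughly, the update has the following structure (a minimal sketch: the dummy ppo_step, the shapes, and the parameter handling are placeholders for illustration, not the PR's actual code):

```python
import jax
import jax.numpy as jnp

def ppo_step(params, minibatch):
    # stand-in for one PPO gradient step; returns (new carry, per-step stat)
    new_params = params - 0.01 * minibatch.mean()
    return new_params, minibatch.mean()

@jax.pmap
def train_update(params, rollout_minibatches):
    # lax.scan keeps the whole inner PPO loop on-device, avoiding host round-trips per step
    return jax.lax.scan(ppo_step, params, rollout_minibatches)

n_dev = jax.local_device_count()
params = jnp.zeros((n_dev,))           # one scalar "parameter" per device replica
minibatches = jnp.ones((n_dev, 4, 8))  # (devices, ppo_steps_per_update, minibatch_size)
new_params, stats = train_update(params, minibatches)
print(stats.shape)  # (n_dev, 4): one stat per device per PPO step
```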

One issue I've encountered is making right_padding_to_left_padding jittable. The naive approaches I tried run into dynamic-shape problems. To make it jittable, I resorted to a method that sorts the mask, although it may not be very efficient:

```python
def right_padding_to_left_padding(tokens, pad_id):
    def pad_row(row):
        mask = 1 - (row == pad_id)  # 1 if not pad_id, 0 if pad_id
        # jnp.argsort is stable by default, so the pad positions move to the front
        # while the relative order of the real tokens is preserved
        return row[jnp.argsort(mask)]

    return jax.vmap(pad_row)(tokens)
```

Do you see a simpler way to implement a jittable right_padding_to_left_padding?
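For reference, one candidate alternative (a sketch I haven't benchmarked, and it assumes pads only ever form a trailing block, i.e., the input really is right-padded) would be to rotate each row by its pad count with jnp.roll:

```python
import jax
import jax.numpy as jnp

def right_padding_to_left_padding_roll(tokens, pad_id):
    def pad_row(row):
        num_pad = jnp.sum(row == pad_id)  # number of trailing pads in this row
        return jnp.roll(row, num_pad)     # rotate the trailing pads to the front
    return jax.vmap(pad_row)(tokens)

# example: [[5, 6, 0, 0]] with pad_id=0 becomes [[0, 0, 5, 6]]
print(right_padding_to_left_padding_roll(jnp.array([[5, 6, 0, 0]]), 0))
```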

~~I'll provide a test case for GAE soon~~ Just added a test case for GAE :-)

For now I'm mostly puzzled by the implementation's speed... please share any suggestions you might have!

Tianlin

liutianlin0121 commented 1 year ago

Thanks for the suggestions! I made the edits accordingly and updated the PR.

The speed is, however, still comparable to the pytorch version. I'll try to dig deeper...

vwxyzjn commented 1 year ago

There is also profiling :) https://jax.readthedocs.io/en/latest/profiling.html Let me also take a look :)

vwxyzjn commented 1 year ago

Ok, I took a quick look. Here are some suggestions for diagnosing. The current GPU utilization looks like this:

https://github.com/vwxyzjn/lm-human-preference-details/assets/5555347/5a35268f-39dc-4e08-9323-343b47646a9a

So, there is a gap somehow.

I tried to isolate the cause by doing something like https://gist.github.com/vwxyzjn/2687ecad96f3b32df539bbbe223e7f42. Basically, I commented out the metrics calculation and did not shard the data. Then the GPU utilization is near 100%. So the issue is with either 1) the metrics calculation, 2) dataloading, or 3) common_utils.shard.

```python
# diagnostic only: load and shard a single batch once, outside the loop,
# so dataloading and common_utils.shard are excluded from the timed region
data = next(iter_dataloader)
input_ids = common_utils.shard(data["input_ids"].numpy())
for update in range(1, args.ppo.num_updates + 1):
    global_step += args.ppo.batch_size

    policy_state, rollout_stats, rl_stats, samples_to_print = p_train_update(
        policy_state=policy_state,
        input_ids=input_ids,
        rl_stats=rl_stats,
        kl_ctl_value=np.array([kl_ctl.value] * input_ids.shape[0]),
    )
```

https://github.com/vwxyzjn/lm-human-preference-details/assets/5555347/c78de6ab-1aa0-47f0-9576-b3ccc171c792

liutianlin0121 commented 1 year ago

Thanks! That's very informative. In my 2-GPU setting, though, the utilization of both GPUs is very close to 100%, even with the current implementation:

(Screenshot from 2023-09-07: GPU utilization on the 2-GPU setup.)

liutianlin0121 commented 1 year ago

Added numpy_collate to avoid converting pytorch tensors to jnp.ndarrays during data loading. The effect is small on my end, though.
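For context, the numpy_collate here follows the standard JAX data-loading recipe (a sketch; the exact implementation in the PR may differ): the DataLoader stacks samples directly into NumPy arrays instead of torch tensors.

```python
import numpy as np

def numpy_collate(batch):
    # recursively stack a list of samples into NumPy arrays (no torch tensors)
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    if isinstance(batch[0], (tuple, list)):
        return type(batch[0])(numpy_collate(samples) for samples in zip(*batch))
    if isinstance(batch[0], dict):
        return {key: numpy_collate([sample[key] for sample in batch]) for key in batch[0]}
    return np.asarray(batch)

# usage with a torch DataLoader (illustrative):
# dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=numpy_collate)
```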

vwxyzjn commented 1 year ago

If it's already at 100%, that's pretty good.

> The speed is, however, still comparable to the pytorch version. I'll try to dig deeper...

Which experiments are you comparing, mine or yours? To do a fair comparison, we need to use the same hardware :)

liutianlin0121 commented 1 year ago

> Which experiments are you comparing, mine or yours? To do a fair comparison, we need to use the same hardware :)

Yes, I ran both the jax and pytorch implementations on my 2-GPU desktop with the same settings (ppo.gradient-accumulation-steps=4, ppo.lr=1e-6, and otherwise the default config).

So there seem to be two separate problems. (1) When running on 2 GPUs, the GPU utilization is high, but the jax implementation still does not outperform pytorch in speed. (2) When using 8 GPUs, the GPU utilization of the jax implementation is low.

For (1), I think the sorting involved in the jittable right_padding_to_left_padding can indeed cause a slowdown. I'll try to do a more careful speed comparison soon. I don't have an idea for (2) yet.

vwxyzjn commented 1 year ago

> Which experiments are you comparing, mine or yours? To do a fair comparison, we need to use the same hardware :)
>
> Yes, I ran both the jax and pytorch implementations on my 2-GPU desktop with the same settings (ppo.gradient-accumulation-steps=4, ppo.lr=1e-6, and otherwise the default config).
>
> So there seem to be two separate problems. (1) When running on 2 GPUs, the GPU utilization is high, but the jax implementation still does not outperform pytorch in speed. (2) When using 8 GPUs, the GPU utilization of the jax implementation is low.
>
> For (1), I think the sorting involved in the jittable right_padding_to_left_padding can indeed cause a slowdown. I'll try to do a more careful speed comparison soon. I don't have an idea for (2) yet.

Maybe another related factor is that Jax has no flash attention, whereas pytorch does.

liutianlin0121 commented 1 year ago

Hi Costa @vwxyzjn. I am wondering if we can merge the PR. Although the speed of the jax policy learning is only comparable to pytorch, the metrics match closely. In particular, the policy_avg metrics now go to zero in both implementations (figure below); I think the earlier jax version didn't converge to zero because of a glitch in the learning-rate annealing.

(Figure: policy_avg metrics converging to zero in both the jax and pytorch runs.)

But of course, please feel free to keep the PR open if you think this way is better, and feel free to suggest further edits before the merge :-)

liutianlin0121 commented 1 year ago

Awesome, thanks! Just a note for further benchmarking: as discussed, all metrics in the jax policy learning are averaged at the batch level (as opposed to the microbatch level used for a subset of metrics in the pytorch policy learning). For this reason, some metrics from jax may look smoother than their current pytorch counterparts. Below is an example.

(Screenshot from 2023-09-12: a jax metric that is smoother than its pytorch counterpart.)