vwxyzjn / cleanrl

High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG)
http://docs.cleanrl.dev

Using jax scan for PPO + atari + envpool XLA #327

Closed · 51616 closed 1 year ago

51616 commented 1 year ago

Description

Modifies the code to use jax.lax.scan for faster compile times and a small runtime speed improvement.
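For context, here is a minimal sketch of the pattern the PR applies. Everything in it is an illustrative stand-in, not the PR's actual code: policy_apply and step_env are dummy substitutes for the Flax policy and the envpool XLA step, and all shapes are made up.

```python
# Illustrative sketch only: `policy_apply` and `step_env` are dummy stand-ins
# for the Flax policy and the envpool XLA step used in the real script.
import jax
import jax.numpy as jnp

def policy_apply(params, obs):
    # Dummy linear policy returning discrete actions.
    return jnp.argmax(obs @ params, axis=-1)

def step_env(env_state, action):
    # Dummy transition: fixed observations, constant rewards.
    obs = jnp.ones((4, 8))
    reward = jnp.ones(4)
    done = jnp.zeros(4, dtype=bool)
    return env_state, obs, reward, done

def rollout_step(carry, _):
    # One environment step; scan threads `carry` through all steps.
    params, env_state, obs = carry
    action = policy_apply(params, obs)
    env_state, next_obs, reward, done = step_env(env_state, action)
    return (params, env_state, next_obs), (obs, action, reward, done)

# lax.scan traces the loop body once, so XLA compiles a single step instead
# of `num_steps` unrolled copies; that is where the compile-time win is.
num_steps = 128
params = jnp.zeros((8, 6))
init = (params, jnp.array(0), jnp.zeros((4, 8)))
_, (obs_buf, actions, rewards, dones) = jax.lax.scan(
    rollout_step, init, None, length=num_steps
)
# Each output buffer now has a leading time dimension of size num_steps.
```

The same scan trick applies equally to minibatch and epoch update loops.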

Checklist:

If you are adding new algorithm variants, or if your change could result in a performance difference, you may need to (re-)run tracked experiments. See https://github.com/vwxyzjn/cleanrl/pull/137 as an example PR.

51616 commented 1 year ago

This code seems to have similar performance (reward-wise) to the original version. If some of the code is inconsistent with the other algorithms or hard to understand, I'm happy to modify those parts. Please let me know if you want a more concrete performance comparison; I could use wandb to track the results.

vwxyzjn commented 1 year ago

This code seems to have similar performance (reward-wise) to the original version. If some of the code is inconsistent with the other algorithms or hard to understand, I'm happy to modify those parts.

Yes, would you mind running the original version and the new version and creating a report of their learning curves, using both steps and relative time as the x-axis?
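One way to set such a report up, as a hedged sketch: log relative wall-clock time as its own metric, so the wandb report can plot returns against either global_step or relative time. The project name and metric values below are hypothetical.

```python
# Hypothetical sketch: log relative wall-clock time alongside each metric so
# a wandb report can use either global_step or relative time as the x-axis.
import time
import wandb

run = wandb.init(project="cleanrl-scan-compare", mode="offline")  # hypothetical project
start_time = time.time()
for global_step, episodic_return in [(1000, 10.0), (2000, 25.0), (3000, 40.0)]:
    wandb.log({
        "charts/episodic_return": episodic_return,
        "charts/relative_time": time.time() - start_time,
    }, step=global_step)
run.finish()
```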

vwxyzjn commented 1 year ago

Two quick comments. First, would you mind creating a different branch with a name other than master? It will make merging easier down the line. Second, I did a quick test, and the performance looks the same. Great job @51616! However, note that the policy_loss looks suspiciously different. This is usually a sign that a subtle bug was introduced, which could have performance implications in other games. Would you mind looking into it?

[screenshot: policy_loss curves for the two versions]
51616 commented 1 year ago

@vwxyzjn The loss in the new file is averaged over minibatches and epochs, whereas the original's comes from the last minibatch, if I'm not mistaken. I think this is where the discrepancy is. I'll look into it later today.
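To make the difference concrete, a toy illustration (the loss values are made up): if losses collects the policy loss from every minibatch across all epochs, the two logging conventions read off different quantities.

```python
import jax.numpy as jnp

# Made-up policy losses with shape (update_epochs, num_minibatches).
losses = jnp.array([[0.12, 0.10, 0.09, 0.11],
                    [0.08, 0.07, 0.07, 0.06]])

averaged = losses.mean()          # what the initial scan version logged
last_minibatch = losses[-1, -1]   # what the original script logs

print(float(averaged), float(last_minibatch))  # the two logged values differ
```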

51616 commented 1 year ago

The code now logs the loss metrics of the last minibatch, and the losses now look the same as the original version's. Green is the original, blue is the jax scan version logging only the last minibatch, and grey is the initial code that logs averaged metrics.

[screenshot: loss curves for the three runs]

This is the learning curve vs relative time. Blue and green use the same num-envs; grey uses 2x as many, so it is a bit faster at collecting 10M samples.

[screenshot: learning curves vs relative time]

First, would you mind creating a different branch with a name other than master? It will make merging easier down the line.

Should I make a new pull request for this?

vwxyzjn commented 1 year ago

Should I make a new pull request for this?

Yes please.

Blue and green use the same num-envs; grey uses 2x as many, so it is a bit faster at collecting 10M samples.

To compare the original and the new lax.scan version, we should use the same num-envs. In my test, both versions took about 40 minutes end-to-end. Since they spent the same wall-clock time but the original spends more of it on compilation, the original must reach an even higher SPS_update once compiled.

[screenshot: end-to-end runtime comparison]
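For anyone reproducing this comparison, one common way to separate compile time from steady-state throughput is to time the first jitted call (which pays XLA compilation) separately from subsequent calls. A minimal sketch with a stand-in update function, not cleanrl's actual PPO update:

```python
# Sketch: time the first jitted call (compile + one step) separately from the
# steady-state per-step cost. `update_fn` is a stand-in, not the PPO update.
import time
import jax
import jax.numpy as jnp

@jax.jit
def update_fn(x):
    return jnp.tanh(x) * 0.99

x = jnp.ones((1024, 1024))

t0 = time.time()
update_fn(x).block_until_ready()  # first call: compile + one step
print("compile + first step:", time.time() - t0, "s")

t0 = time.time()
for _ in range(100):
    x = update_fn(x)
x.block_until_ready()
print("steady-state per step:", (time.time() - t0) / 100, "s")
```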
51616 commented 1 year ago

@vwxyzjn Sorry for the confusion. I'll open a new pull request from a new branch, logging just the last minibatch.