vwxyzjn / cleanrl

High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG)
http://docs.cleanrl.dev

Proof-of-concept: Faster PyTorch #306

Closed · DavidSlayback closed this 1 year ago

DavidSlayback commented 1 year ago

Description

While I share a lot of the excitement about the speed boosts that JAX provides, I've found that additional performance can be extracted from PyTorch. This commit is mostly a proof-of-concept, using PPO as a test case, to show how various levels of optimization can be applied to improve PyTorch speed. There are a few more things I can add, and I still need to do runs with larger networks (where the speed improvement is greater), but this is a start for discussing how much the codebase can be changed to improve speed without losing readability.

Types of changes

Checklist:

If you are adding new algorithms or your change could result in performance difference, you may need to (re-)run tracked experiments. See https://github.com/vwxyzjn/cleanrl/pull/137 as an example PR.

Below are the results of various levels of optimization applied to CartPole. A larger set of environments (using the benchmark utility on the same hardware for all runs) can be found in this wandb report.

[image: CartPole results for the optimization levels described below]

  • L0 is the baseline.
  • L1 uses TorchScript on the Sequential modules and `optimizer.zero_grad(True)`.
  • L2 additionally uses TorchScript for the advantage and normalization functions, as well as in-place activations.
  • L3 additionally uses TorchScript for the probability computations and action sampling.
  • L4 (not shown, needs new runs) uses TorchScript for the full PPO loss function.
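For concreteness, here is a minimal sketch of what the L1-level changes could look like; the network sizes and names are placeholders for illustration, not the actual code in this PR:

```python
# Sketch of the L1-level optimizations: TorchScript the Sequential sub-networks
# and use zero_grad(set_to_none=True). Sizes/names are illustrative only.
import torch
import torch.nn as nn

actor = nn.Sequential(nn.Linear(4, 64), nn.Tanh(), nn.Linear(64, 64), nn.Tanh(), nn.Linear(64, 2))
critic = nn.Sequential(nn.Linear(4, 64), nn.Tanh(), nn.Linear(64, 64), nn.Tanh(), nn.Linear(64, 1))

# torch.jit.script compiles the forward passes, cutting per-call Python overhead.
actor = torch.jit.script(actor)
critic = torch.jit.script(critic)

optimizer = torch.optim.Adam(list(actor.parameters()) + list(critic.parameters()), lr=2.5e-4)

obs = torch.randn(8, 4)                        # dummy batch of observations
loss = actor(obs).sum() + critic(obs).sum()    # stand-in for the real PPO loss
optimizer.zero_grad(set_to_none=True)          # avoids writing zeros into .grad buffers
loss.backward()
optimizer.step()
```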

I'm seeing large benefits from L3 in particular, and it's something I could potentially apply to any of the PyTorch algorithms, but it's also the first level of optimization where the readability really changes. Interested in starting a discussion over what is/is not worth it!
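For context on why L3 changes the readability, here is a hedged sketch of the kind of function it implies (illustrative, not this PR's actual code): torch.distributions.Categorical does not script well, so the sampling and log-prob/entropy math gets written out in a scriptable function.

```python
import torch

# Illustrative, scriptable stand-in for torch.distributions.Categorical.
# Assumes logits has shape (batch, num_actions).
@torch.jit.script
def sample_categorical(logits: torch.Tensor):
    log_probs = torch.log_softmax(logits, dim=-1)
    probs = log_probs.exp()
    actions = torch.multinomial(probs, 1).squeeze(-1)                            # sampled actions
    action_log_probs = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)   # log pi(a|s)
    entropy = -(probs * log_probs).sum(-1)                                       # per-state entropy
    return actions, action_log_probs, entropy
```

The trade-off is exactly the one mentioned above: the convenience of the distribution object goes away, but the whole act/evaluate path becomes scriptable.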

vwxyzjn commented 1 year ago

Thank you @DavidSlayback for these great prototyping efforts. It's really nice that you JITed the action sampling process by implementing your own sampling function! Some thoughts:

  1. If the speed-up is considerable (e.g., 30% overall training time reduction), then it's worth including the implementation.
    • For the sake of maintenance, maybe we can add the result of this PR as a variant like ppo_atari_jit.py. I am a little hesitant to use the JITed implementations in place of our regular implementations because the regular ones are still easier to debug.
    • Once you feel comfortable with the prototypes, would you mind giving the best of them a try in Atari games? I have plenty of data points to compare (see report here).
  2. Would it be possible to JIT the linear learning-rate annealing? Previously with JAX, JITing the linear learning-rate annealing improved the overall speed in MuJoCo by 2x (https://github.com/vwxyzjn/cleanrl/pull/217#issuecomment-1166789791).
  3. New tools are coming, such as https://github.com/pytorch/torchdynamo and https://github.com/metaopt/torchopt. How are they going to affect these optimization techniques?
DavidSlayback commented 1 year ago

No problem!

  1. Yeah, I think it makes sense to do them as separate files. Once I determine the sweet spot of optimizations (performance without becoming unreadable), I was planning on doing a comparable JIT version for each algorithm (I'm particularly interested in recurrent layers). I can definitely try them in some Atari games; I just needed something that would run quickly and demonstrate that the episodic performance matches.
  2. Unfortunately, I don't think that's doable with the base torch.optim optimizers. Like the torch.distributions classes, they don't play nicely with JIT. I could use a built-in LRScheduler, but those seem to do under the hood what you already do manually. I could JIT the annealing function itself, but I would still have to set the underlying learning rate on the optimizer (see the sketch after this list).
  3. I hadn't seen torchopt before. I'm somewhat familiar with functorch and torchdynamo, but I was waiting for them to become more mature. It looks like functorch and torchdynamo have already been moved into PyTorch, so I'll see if I can use those to wrap the optimization.
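To illustrate point 2, here is a minimal sketch under my assumption of the usual cleanrl-style linear annealing (not code from this PR): the fraction arithmetic can be scripted, but assigning the result back into the optimizer's param_groups stays in eager Python.

```python
import torch

# Scriptable annealing math; the assignment to "lr" below cannot be scripted.
@torch.jit.script
def linear_lr(update: int, num_updates: int, initial_lr: float) -> float:
    frac = 1.0 - (update - 1.0) / num_updates
    return frac * initial_lr

model = torch.nn.Linear(4, 2)                                  # placeholder network
optimizer = torch.optim.Adam(model.parameters(), lr=2.5e-4)

num_updates = 1000
for update in range(1, num_updates + 1):
    # Setting the learning rate still happens outside the JIT.
    optimizer.param_groups[0]["lr"] = linear_lr(update, num_updates, 2.5e-4)
    # ... rollout collection and the PPO update would go here ...
```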

So, next steps:

  1. Test torchdynamo/torchopt/functorch techniques and settle on the best "return on optimization".
  2. Apply the chosen techniques to ppo_atari and ppo_atari_lstm.

vwxyzjn commented 1 year ago

Hey @DavidSlayback, thanks for doing the investigation. I did a quick prototype in JAX to see the speed difference.

Not using a GPU

[image: SPS comparison, PyTorch JIT vs JAX JIT, without a GPU]

https://wandb.ai/costa-huang/cleanRL/reports/Pytorch-JIT-vs-JAX-JIT--VmlldzozMjY4MjQ2

JAX source code: https://wandb.ai/costa-huang/cleanRL/runs/j5k5vdl7/code?workspace=user-costa-huang

The SPS of ppo_jax.py can be further improved by removing the compilation time via #328. The real SPS is about 15k.

Using a GPU

[image: SPS comparison, PyTorch JIT vs JAX JIT, with a GPU]

The real SPS is about 8k.

vwxyzjn commented 1 year ago

Hey @DavidSlayback, thanks again for this PR. I was thinking that a more suitable place for these experiments is probably a separate repository, and we would be happy to refer more advanced users to it in our docs :)

Closing this PR now.