vwxyzjn / cleanrl

High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG)
http://docs.cleanrl.dev

Jax c51 contrib #224

Closed · kinalmehta closed this 1 year ago

kinalmehta commented 2 years ago

Description

JAX implementation of C51, for #221
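For context, the heart of any C51 implementation is the categorical projection of the Bellman-updated return distribution back onto the fixed atom support. A minimal NumPy sketch of that step (illustrative only, not the PR's JAX code; the function name and shapes are my own):

```python
import numpy as np

def project_distribution(next_probs, rewards, dones, gamma, atoms):
    """Project the Bellman-backed-up categorical distribution onto the fixed support.

    next_probs: (batch, n_atoms) target-network probabilities for the best action
    atoms:      (n_atoms,) evenly spaced support z_i in [v_min, v_max]
    """
    v_min, v_max = atoms[0], atoms[-1]
    delta_z = atoms[1] - atoms[0]
    # Bellman backup of each atom, clipped to the support
    tz = np.clip(rewards[:, None] + gamma * (1.0 - dones[:, None]) * atoms[None, :],
                 v_min, v_max)
    b = (tz - v_min) / delta_z          # fractional atom index of each backed-up atom
    l = np.floor(b).astype(int)
    u = np.ceil(b).astype(int)
    proj = np.zeros_like(next_probs)
    for i in range(next_probs.shape[0]):
        for j in range(atoms.shape[0]):
            if l[i, j] == u[i, j]:
                # backed-up atom lands exactly on a support atom: give it all the mass
                proj[i, l[i, j]] += next_probs[i, j]
            else:
                # otherwise split the mass between the two neighbouring atoms
                proj[i, l[i, j]] += next_probs[i, j] * (u[i, j] - b[i, j])
                proj[i, u[i, j]] += next_probs[i, j] * (b[i, j] - l[i, j])
    return proj
```

The cross-entropy between this projected distribution and the online network's predicted distribution is the C51 loss; the JAX version implements the same projection with vectorized `jnp` ops rather than Python loops.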

Types of changes

Checklist:

If you are adding new algorithms or your change could result in performance difference, you may need to (re-)run tracked experiments. See https://github.com/vwxyzjn/cleanrl/pull/137 as an example PR.

kinalmehta commented 2 years ago

Results on classical gym environments can be checked here: https://wandb.ai/kinalmehta/jax-cleanrl/reports/C51-JAX-vs-Pytorch-on-Classical-Gym-Environments--VmlldzoyNDQ3OTk5

We see a speed-up of roughly 30% in the JAX version compared to the PyTorch one.

kinalmehta commented 2 years ago

Here is the benchmark report on Atari environments: https://wandb.ai/kinalmehta/jax-cleanrl/reports/C51-JAX-vs-Pytorch-on-Atari-Environments--VmlldzoyNjkyNDY0

Important observations:

Need to look in more detail at the differences between the PyTorch and JAX implementations to fix the above-mentioned issues.

joaogui1 commented 2 years ago

How does it compare to Dopamine's version?

kinalmehta commented 2 years ago

> How does it compare to Dopamine's version?

I haven't checked Dopamine yet. I will have a look and update here, though it might take some time.

vwxyzjn commented 2 years ago

FYI dopamine has a benchmark, but its x-axis is not the environment steps... Any clue on how we can compare those results? @joaogui1

[image: Dopamine benchmark plot]
kinalmehta commented 1 year ago

Atari Fixed

After months of procrastination and debugging various aspects, I finally stumbled upon the cause of the performance degradation: an incorrect epsilon value in the Adam optimizer. I missed this detail and used the default value $10^{-8}$ from optax, whereas the C51 PyTorch version uses ${0.01}/{batch\_size}$. However, I couldn't find any motivation for using this value.

Reading up more on this led to the conclusion that this is a common issue in NLP and CV as well. More about this hyperparameter can be read here.
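To make the mismatch concrete, here is the magnitude gap between the two epsilon settings (batch size 128 is a typical Atari default in these scripts; treat it as illustrative):

```python
batch_size = 128                             # illustrative Atari batch size

optax_default_eps = 1e-8                     # optax.adam's default epsilon
pytorch_c51_eps = 0.01 / batch_size          # value used by the PyTorch C51 script

print(pytorch_c51_eps)                       # 7.8125e-05
print(pytorch_c51_eps / optax_default_eps)   # ~7.8e3, nearly four orders of magnitude
```

In the JAX script the fix amounts to passing epsilon explicitly, e.g. `optax.adam(learning_rate, eps=0.01 / batch_size)`, rather than relying on the default.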

Benchmarking classical envs on CPU

I have updated the plots for the classical gym environments (CartPole, Acrobot, MountainCar) by benchmarking on CPU. We see a significant speed-up compared to the PyTorch version on CPU.

Comparison with dopamine

Based on the BeamRider plot shared above, the table below summarizes the final scores:

| Implementation | Score |
| --- | --- |
| dopamine | 5000-7000 |
| cleanrl-pytorch | ~9500 |
| cleanrl-jax-old | ~2500 |
| cleanrl-jax-fixed | ~9500 |

Reports link

Conclusion

The updated plots are available at the links above. The PR looks ready to be merged once the documentation is updated. Anything else I am missing here, @vwxyzjn?

vwxyzjn commented 1 year ago

The results look incredible. Great job @kinalmehta. Thanks for chasing down the cause of the issue. The code also looks great to me. Feel free to start adding documentation. You should also move the experiments to the openrlbenchmark/cleanrl namespace.

kinalmehta commented 1 year ago

I've added the documentation, and now I believe this PR is ready for the final review.