Closed kinalmehta closed 1 year ago
Results on classical Gym environments can be checked here: https://wandb.ai/kinalmehta/jax-cleanrl/reports/C51-JAX-vs-Pytorch-on-Classical-Gym-Environments--VmlldzoyNDQ3OTk5
We see a speed-up of about 30% in the JAX version compared to PyTorch.
Here is the benchmark report on Atari environments: https://wandb.ai/kinalmehta/jax-cleanrl/reports/C51-JAX-vs-Pytorch-on-Atari-Environments--VmlldzoyNjkyNDY0
Important observations:
I need to look in more detail at the differences between the PyTorch and JAX implementations to fix the above-mentioned issues.
How does it compare to Dopamine's version?
> How does it compare to Dopamine's version?

I haven't checked Dopamine yet. I will have a look and update here, though it might take some time.
FYI, Dopamine has a benchmark, but its x-axis is not environment steps... Any clue on how we can compare those results? @joaogui1
After months of procrastination and debugging various aspects, I finally stumbled upon the cause of the performance degradation: an incorrect epsilon value in the optimizer. I missed this detail and used the optax default of $10^{-8}$, whereas the C51 PyTorch version uses $0.01 / \text{batch\_size}$. However, I couldn't find any motivation for using this value.
Reading up more on this led to the conclusion that this is a common issue in NLP and CV as well. More about this hyperparameter can be read here.
I have updated the plots for the classical Gym environments (CartPole, Acrobot, MountainCar) by benchmarking on CPU. We see a significant speed-up compared to the PyTorch version on CPU.
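To make the epsilon mismatch described above more concrete, here is a minimal sketch of how the two values change the size of an Adam update. It assumes the optimizer is Adam (the thread only names optax and its $10^{-8}$ default), and the learning rate, batch size, and the helper `adam_first_step` are hypothetical, chosen for illustration only. It is not the code from this PR.

```python
import math


def adam_first_step(grad, lr, eps, b1=0.9, b2=0.999):
    """Size of the first bias-corrected Adam update for a scalar gradient.

    Hypothetical helper for illustration; not part of CleanRL or optax.
    """
    m = (1.0 - b1) * grad          # first-moment estimate after one step
    v = (1.0 - b2) * grad * grad   # second-moment estimate after one step
    m_hat = m / (1.0 - b1)         # bias correction, recovers grad
    v_hat = v / (1.0 - b2)         # bias correction, recovers grad**2
    return lr * m_hat / (math.sqrt(v_hat) + eps)


LEARNING_RATE = 2.5e-4               # assumed value
BATCH_SIZE = 32                      # assumed value
EPS_OPTAX_DEFAULT = 1e-8             # default mentioned in the thread
EPS_PYTORCH_C51 = 0.01 / BATCH_SIZE  # value used by the PyTorch version

for grad in (1e-1, 1e-3, 1e-5):
    small_eps = adam_first_step(grad, LEARNING_RATE, EPS_OPTAX_DEFAULT)
    large_eps = adam_first_step(grad, LEARNING_RATE, EPS_PYTORCH_C51)
    print(f"grad={grad:.0e}  step(eps=1e-8)={small_eps:.3e}  "
          f"step(eps=0.01/B)={large_eps:.3e}")
```

With the tiny epsilon, every update has a magnitude close to the learning rate no matter how small the gradient is. With $0.01 / \text{batch\_size}$, updates for gradients well below epsilon are strongly damped. In optax, the corresponding change would presumably be passing `eps=0.01 / batch_size` to `optax.adam`, though the exact line changed in this PR is not shown in the thread.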
Based on the BeamRider plot shared above, the table below summarizes the final score comparison:

implementation | score |
---|---|
dopamine | 5000-7000 |
cleanrl-pytorch | ~9500 |
cleanrl-jax-old | ~2500 |
cleanrl-jax-fixed | ~9500 |
The updated plots are available at the links above. The PR looks good to be merged once the documentation is updated. Is there anything else I am missing here, @vwxyzjn?
The results look incredible. Great job @kinalmehta. Thanks for chasing down the cause of the issue. The code also looks great to me. Feel free to start adding documentation. You should also move the experiments to the openrlbenchmark/cleanrl namespace.
I've added the documentation, and now I believe this PR is ready for the final review.
Description
JAX implementation for C51.
Implementation for #221.
Types of changes
Checklist:
- … `pre-commit run --all-files` passes (required).
- … `mkdocs serve`.

If you are adding new algorithms or your change could result in performance difference, you may need to (re-)run tracked experiments. See https://github.com/vwxyzjn/cleanrl/pull/137 as an example PR.

- … `--capture-video` flag toggled on (required).
- … `mkdocs serve`.
- … (`width=500` and `height=300`).