Added value loss clipping to the JAX PPO implementation. It was already present in the PyTorch implementation but missing from the JAX one, and adding it made a substantial performance difference on the environments I tested.
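For reviewers unfamiliar with the technique, here is a minimal sketch of value loss clipping in JAX, assuming the common PPO2-style formulation (clip the new value prediction to within `clip_coef` of the rollout-time prediction, then take the elementwise maximum of the clipped and unclipped squared errors). The function and argument names are hypothetical and chosen for illustration; this is not the exact code changed in this PR.

```python
import jax.numpy as jnp


def clipped_value_loss(new_values, old_values, returns, clip_coef=0.2):
    """Illustrative PPO2-style clipped value loss (hypothetical helper).

    new_values: value predictions from the current parameters
    old_values: value predictions recorded during the rollout
    returns:    bootstrapped return targets
    clip_coef:  clipping range (assumed here to match the policy clip range)
    """
    # Plain squared error of the current prediction.
    v_loss_unclipped = (new_values - returns) ** 2
    # Keep the new prediction within clip_coef of the rollout-time prediction.
    v_clipped = old_values + jnp.clip(new_values - old_values, -clip_coef, clip_coef)
    v_loss_clipped = (v_clipped - returns) ** 2
    # Elementwise maximum gives the pessimistic (larger) of the two errors.
    return 0.5 * jnp.maximum(v_loss_unclipped, v_loss_clipped).mean()


def unclipped_value_loss(new_values, returns):
    """Value loss without clipping, shown for comparison."""
    return 0.5 * ((new_values - returns) ** 2).mean()
```

Because of the elementwise maximum, the clipped loss is never smaller than the unclipped one; it penalizes value updates that move far from the rollout-time prediction.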
[ ] I have performed RLops with python -m openrlbenchmark.rlops.
For new feature or bug fix:
[ ] I have used the RLops utility to understand the performance impact of the changes and confirmed there is no regression.
For new algorithm:
[ ] I have created a table comparing my results against those from reputable sources (i.e., the original paper or other reference implementation).
[ ] I have added the learning curves generated by the python -m openrlbenchmark.rlops utility to the documentation.
[ ] I have added links to the tracked experiments in W&B, generated by python -m openrlbenchmark.rlops ....your_args... --report, to the documentation.
Checklist:
[ ] pre-commit run --all-files passes (required).
[ ] mkdocs serve.
If you need to run benchmark experiments for a performance-impacting change:
[ ] --capture-video.