vwxyzjn / cleanrl

High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG)
http://docs.cleanrl.dev

Parallel-envs-friendly ppo_continuous_action.py #348

Open vwxyzjn opened 1 year ago

vwxyzjn commented 1 year ago

Description

This PR modifies ppo_continuous_action.py to make it more parallel-envs-friendly. CC @kevinzakka.

The version of ppo_continuous_action.py in this PR is different from that in the master branch in the following ways:

  1. use a different set of hyperparameters that leverage more simulation environments (e.g., 64 parallel environments) https://github.com/vwxyzjn/cleanrl/blob/703cd3ba1214a15d2fc6ce9157f8c094d627c07b/cleanrl/ppo_continuous_action.py#L37-L71
  2. use gym.vector.AsyncVectorEnv instead of gym.vector.SyncVectorEnv to speed things up further https://github.com/vwxyzjn/cleanrl/blob/703cd3ba1214a15d2fc6ce9157f8c094d627c07b/cleanrl/ppo_continuous_action.py#L163-L165
  3. apply the normalize wrappers at the parallel-envs level instead of the individual-env level, meaning the running mean and std for the obs and returns are calculated over the whole batch of obs and rewards (a minimal sketch follows this list). In my experience, this is usually preferable to maintaining the normalize wrappers in each sub-env. When N=1, it should not cause any performance difference https://github.com/vwxyzjn/cleanrl/blob/703cd3ba1214a15d2fc6ce9157f8c094d627c07b/cleanrl/ppo_continuous_action.py#L166-L170
    • one thing that would be worth trying is to remove the normalize wrappers entirely, which should improve SPS. Or, in the case of JAX, re-writing and jitting the normalize wrappers may improve SPS as well.
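Below is a minimal sketch of points 2 and 3, not the PR's exact code: the environments are vectorized with gym.vector.AsyncVectorEnv and the normalization wrappers are applied once to the vector env, so the running obs/return statistics are computed over the whole batch. It assumes gym (not gymnasium), whose NormalizeObservation/NormalizeReward wrappers accept vector envs; the env id and env count are illustrative.

```python
import gym
import numpy as np

def make_env(env_id):
    def thunk():
        env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)  # per-episode return/length logging
        env = gym.wrappers.ClipAction(env)
        return env
    return thunk

env_id, num_envs, gamma = "HalfCheetah-v4", 64, 0.99  # illustrative values
envs = gym.vector.AsyncVectorEnv([make_env(env_id) for _ in range(num_envs)])
# Normalization at the vector-env level: one running mean/std shared by all sub-envs.
envs = gym.wrappers.NormalizeObservation(envs)
envs = gym.wrappers.TransformObservation(envs, lambda obs: np.clip(obs, -10, 10))
envs = gym.wrappers.NormalizeReward(envs, gamma=gamma)
envs = gym.wrappers.TransformReward(envs, lambda r: np.clip(r, -10, 10))
```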

I also added a JAX variant that reached the same level of performance:

[figure: training curves showing the JAX variant reaching the same level of performance]
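On the JAX side, the jitting idea from the sub-bullet above could look roughly like the following: a running mean/std observation normalizer kept as a pytree and updated inside jitted code. This is only a sketch; the names (RunningStats, update_and_normalize) are hypothetical and not part of this PR.

```python
import jax
import jax.numpy as jnp
from flax import struct

@struct.dataclass
class RunningStats:
    mean: jnp.ndarray
    var: jnp.ndarray
    count: jnp.ndarray  # scalar pseudo-count

def init_stats(obs_dim: int) -> RunningStats:
    return RunningStats(mean=jnp.zeros(obs_dim), var=jnp.ones(obs_dim), count=jnp.array(1e-4))

@jax.jit
def update_and_normalize(stats: RunningStats, obs: jnp.ndarray):
    # Parallel (Welford-style) update of the running statistics over the whole
    # batch of observations coming from the vectorized envs.
    batch_mean = obs.mean(axis=0)
    batch_var = obs.var(axis=0)
    batch_count = obs.shape[0]
    delta = batch_mean - stats.mean
    tot_count = stats.count + batch_count
    new_mean = stats.mean + delta * batch_count / tot_count
    m2 = (stats.var * stats.count + batch_var * batch_count
          + jnp.square(delta) * stats.count * batch_count / tot_count)
    new_stats = RunningStats(mean=new_mean, var=m2 / tot_count, count=tot_count)
    norm_obs = (obs - new_mean) / jnp.sqrt(new_stats.var + 1e-8)
    return new_stats, jnp.clip(norm_obs, -10.0, 10.0)
```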

Types of changes

Checklist:

If you are adding new algorithm variants or your change could result in performance difference, you may need to (re-)run tracked experiments. See https://github.com/vwxyzjn/cleanrl/pull/137 as an example PR.

vercel[bot] commented 1 year ago

The latest updates on your projects.

Name | Status | Preview | Updated
---- | ------ | ------- | -------
cleanrl | ✅ Ready | Visit Preview | Jan 13, 2023 at 2:25 PM (UTC)
kevinzakka commented 1 year ago

Thank you @vwxyzjn! I'll give this a spin.

varadVaidya commented 3 months ago

Hello. Thanks a lot for implementing PPO in JAX in such a clean fashion. However, while reproducing the results, I am facing the following issue.

Traceback (most recent call last):
  File "/scratch/vaidya/mujoco_sims/gym_mujoco_drones/gym_mujoco_drones/cleanrl_jax_ppo.py", line 199, in <module>
    agent_state = TrainState.create(
  File "/scratch/vaidya/miniconda3/envs/sbx-gpu/lib/python3.10/site-packages/flax/training/train_state.py", line 127, in create
    params['params'] if OVERWRITE_WITH_GRADIENT in params else params
TypeError: argument of type 'AgentParams' is not iterable
Exception ignored in: <function AsyncVectorEnv.__del__ at 0x7f6aa6d89630>
Traceback (most recent call last):
  File "/scratch/vaidya/miniconda3/envs/sbx-gpu/lib/python3.10/site-packages/gymnasium/vector/async_vector_env.py", line 549, in __del__
  File "/scratch/vaidya/miniconda3/envs/sbx-gpu/lib/python3.10/site-packages/gymnasium/vector/vector_env.py", line 272, in close
  File "/scratch/vaidya/miniconda3/envs/sbx-gpu/lib/python3.10/site-packages/gymnasium/vector/async_vector_env.py", line 465, in close_extras
AttributeError: 'NoneType' object has no attribute 'TimeoutError'

Since I am currently new to JAX, I am unable to debug the issue of AgentParams not being iterable on my own. I understand that this is a work in progress, but I would appreciate any pointers to solve this. Thanks.