vwxyzjn / cleanrl

High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG)
http://docs.cleanrl.dev

add tianshou-like JAX+PPO+Mujoco #355

Open quangr opened 1 year ago

quangr commented 1 year ago

Description

Add Tianshou-like JAX+PPO+MuJoCo code, tested on Hopper-v3 and HalfCheetah-v3.

11-seed test results:

Hopper-v3 (Tianshou 1M:2609.3+-700.8 ; 3M:3127.7+-413.0) my result: Hopper-v3

HalfCheetah-v3 (Tianshou 1M:5783.9+-1244.0 ; 3M:7337.4+-1508.2) my result: HalfCheetah-v3

This implementation uses a customized EnvWrapper class to wrap the environment. Unlike a traditional Gym-style wrapper, which exposes step and reset methods, EnvWrapper requires three methods: recv, send, and reset. These methods need to be pure functions so they can be transformed by JAX. The recv method modifies what is received from the env after an action step, and the send method modifies the action sent to the env.
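To illustrate the idea (this is only a sketch, not the exact class in this PR; ActionClipWrapper and WrapperState are made-up names), the interface looks roughly like this:

```python
import jax
import jax.numpy as jnp
from typing import NamedTuple


class WrapperState(NamedTuple):
    # whatever per-wrapper state has to be threaded through jit/scan
    low: jnp.ndarray
    high: jnp.ndarray


class ActionClipWrapper:
    """Hypothetical wrapper in the recv/send/reset style described above."""

    def reset(self, action_low, action_high) -> WrapperState:
        # build the initial wrapper state; no hidden side effects
        return WrapperState(jnp.asarray(action_low), jnp.asarray(action_high))

    def send(self, state: WrapperState, action: jnp.ndarray) -> jnp.ndarray:
        # modify the action before it is sent to the env
        return jnp.clip(action, state.low, state.high)

    def recv(self, state: WrapperState, obs, reward, done):
        # modify what comes back from the env after a step (identity here)
        return state, (obs, reward, done)


# Because send/recv are pure functions of (state, inputs), they can be jitted
# or used inside jax.lax.scan:
wrapper = ActionClipWrapper()
state = wrapper.reset(-1.0, 1.0)
clipped = jax.jit(wrapper.send)(state, jnp.array([1.5, -0.2]))
```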


Checklist:

If you are adding new algorithm variants or your change could result in performance difference, you may need to (re-)run tracked experiments. See https://github.com/vwxyzjn/cleanrl/pull/137 as an example PR.

quangr commented 1 year ago

Hi @quangr, thank you for your contribution! Being able to use JAX+PPO+MuJoCo+EnvPool will be a game changer for a lot of people! This PR would also make #217 unnecessary.

A few comments and thoughts:

  • Would you mind sharing your wandb username so I can add you to the openrlbenchmark entity? It would be great if you could contribute tracked experiments there, so we can use our CLI utility (https://github.com/openrlbenchmark/openrlbenchmark) to plot the charts.

Thanks to @51616 and @vwxyzjn for the code review ❤️! My wandb username is quangr. I will go through these comments and improve my code soon.

quangr commented 1 year ago

I have submitted new commits addressing most of the comments, and here are my answers to some of the remaining ones. If anything is missing, please let me know.

Why values[1:] * (1.0 - dones[1:])? Maybe this should be handled within the compute_gae_once

I'm masking the values at done steps because Tianshou does so: https://github.com/thu-ml/tianshou/blob/774d3d8e833a1b1c1ed320e0971ab125161f4264/tianshou/policy/base.py#L288.

You're right; I have moved it into the compute_gae_once function.
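For reference, here is a minimal sketch of how that masking can live inside a compute_gae_once step driven by jax.lax.scan (names and indexing conventions are illustrative, not necessarily identical to the PR):

```python
import jax
import jax.numpy as jnp

gamma, gae_lambda = 0.99, 0.95


def compute_gae_once(advantages, inp):
    # one backward step of GAE; the done mask lives here instead of outside
    next_done, next_value, value, reward = inp
    next_nonterminal = 1.0 - next_done
    delta = reward + gamma * next_value * next_nonterminal - value
    advantages = delta + gamma * gae_lambda * next_nonterminal * advantages
    return advantages, advantages


def compute_gae(rewards, values, dones, next_value, next_done):
    # rewards/values/dones: [T, num_envs]; next_value/next_done: [num_envs]
    next_values = jnp.concatenate([values[1:], next_value[None]], axis=0)
    next_dones = jnp.concatenate([dones[1:], next_done[None]], axis=0)
    init = jnp.zeros_like(next_value)
    _, advantages = jax.lax.scan(
        compute_gae_once, init, (next_dones, next_values, values, rewards), reverse=True
    )
    return advantages, advantages + values  # advantages and returns
```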

API change

The XLA API provided by EnvPool is not a pure function. The handle passed to the send function is just a fat pointer to the EnvPool class.

When we keep all state inside a handle tuple and the environment is reset, the pointer remains unchanged, and other parts (like new statistics state) also need a reset, so I think the EnvPool API has to change.

To make the API less confusing, maybe we could remove handle from the return values of envs.xla(). I can't think of a way to keep things consistent with EnvPool for now.
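For context, this is roughly how the EnvPool XLA interface is consumed today (a sketch based on the existing EnvPool XLA examples; the exact return signatures depend on the EnvPool version and on the gym vs. gymnasium API):

```python
import envpool

envs = envpool.make("HalfCheetah-v4", env_type="gym", num_envs=8)
handle, recv, send, step_env = envs.xla()


def rollout_step(carry, action):
    handle, last_obs = carry
    # `handle` must be passed in and returned so XLA sees a data dependency,
    # but it is essentially a fat pointer to the underlying C++ EnvPool object:
    # resetting the envs on the Python side does not produce a "new" handle,
    # which is why pure wrapper state (e.g. normalization statistics) cannot
    # live inside it.
    handle, (obs, reward, done, info) = step_env(handle, action)
    return (handle, obs), (reward, done)
```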

Observation Normalization and gym or gymnasium

https://github.com/vwxyzjn/cleanrl/blob/c79d66f45c8bd0f567ea269a8ceedd157ad69b87/cleanrl/ppo_continuous_action_envpool_xla_jax_scan.py#L313-L317

Observation normalization is implemented as a wrapper, and I think it is mandatory for achieving high scores in the MuJoCo envs.

In this observation normalization wrapper I also convert the gym API into the gymnasium API:

https://github.com/vwxyzjn/cleanrl/blob/c79d66f45c8bd0f567ea269a8ceedd157ad69b87/cleanrl/ppo_continuous_action_envpool_xla_jax_scan.py#L177-L187

This is because, when I wrote this code, I used the latest EnvPool version (0.8.1), which uses the gymnasium API.
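For readers unfamiliar with the trick, here is a minimal sketch of the usual running mean/variance observation normalization (the wrapper in this PR is the authoritative version; the names below are illustrative):

```python
import jax.numpy as jnp
from typing import NamedTuple


class RunningMeanStd(NamedTuple):
    mean: jnp.ndarray
    var: jnp.ndarray
    count: jnp.ndarray


def update_rms(rms: RunningMeanStd, batch: jnp.ndarray) -> RunningMeanStd:
    # parallel-variance (Chan et al.) update from a batch of observations
    batch_mean = batch.mean(axis=0)
    batch_var = batch.var(axis=0)
    batch_count = batch.shape[0]

    delta = batch_mean - rms.mean
    tot_count = rms.count + batch_count
    new_mean = rms.mean + delta * batch_count / tot_count
    m_a = rms.var * rms.count
    m_b = batch_var * batch_count
    m2 = m_a + m_b + jnp.square(delta) * rms.count * batch_count / tot_count
    return RunningMeanStd(new_mean, m2 / tot_count, tot_count)


def normalize_obs(rms: RunningMeanStd, obs: jnp.ndarray, clip: float = 10.0) -> jnp.ndarray:
    return jnp.clip((obs - rms.mean) / jnp.sqrt(rms.var + 1e-8), -clip, clip)
```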

reward normalization

This reward normalization looks bizarre to me too, but this is how Tianshou implements it, and it really works.
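My reading of the Tianshou scheme (hedged, from their policy code): the episodic returns are divided by a running standard deviation of the returns, without subtracting the mean, roughly like this sketch with made-up names:

```python
import numpy as np


class RunningVar:
    """Tiny running-variance tracker for illustration only."""

    def __init__(self):
        self.mean, self.var, self.count = 0.0, 1.0, 1e-4

    def update(self, x: np.ndarray) -> None:
        bm, bv, bc = x.mean(), x.var(), x.size
        delta = bm - self.mean
        tot = self.count + bc
        m2 = self.var * self.count + bv * bc + delta**2 * self.count * bc / tot
        self.mean, self.var, self.count = self.mean + delta * bc / tot, m2 / tot, tot


ret_rms = RunningVar()


def scale_returns(unnormalized_returns: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    # divide by the running std of the returns only -- no mean subtraction;
    # this is the part that looks odd but matches my reading of Tianshou
    scaled = unnormalized_returns / np.sqrt(ret_rms.var + eps)
    ret_rms.update(unnormalized_returns)
    return scaled
```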

HalfCheetah-v3 (W&B chart, 2023-01-31)

Hopper-v3 (W&B chart, 2023-01-31)

quangr commented 1 year ago

I have run experiments for Ant-v4, HalfCheetah-v4, Hopper-v4, Walker2d-v4, Swimmer-v4, Humanoid-v4, Reacher-v4, InvertedPendulum-v4, and InvertedDoublePendulum-v4. Here is the report: https://wandb.ai/openrlbenchmark/cleanrl/reports/MuJoCo-jax-EnvPool--VmlldzozNDczNDkz.

Comparing the results to Tianshou

I notice that Tianshou uses 10 envs to evaluate performance from the reset state every epoch as their benchmark; I wonder if that is a problem for the comparison.

Comparing with ppo_continuous_action_8M

  • Better: Ant-v4, HalfCheetah-v4
  • Similar: Hopper-v4, Walker2d-v4, Humanoid-v4, Reacher-v4
  • Worse: Swimmer-v4

In the Tianshou benchmark their parameters also perform poorly on Swimmer, so this is a reasonable result.

As for InvertedDoublePendulum-v4 and InvertedPendulum-v4, every agent in my version reaches a score of 1000, which does not happen in ppo_continuous_action_8M. The score starts to decline afterward, though, and the same decline curve can be observed in Tianshou's training data: https://drive.google.com/drive/folders/1tQvgmsBbuLPNU3qo5thTBi03QzGXygXf https://drive.google.com/drive/folders/1ns2cGnAn_39wqCItmhDZIxihLi8-DBei

vwxyzjn commented 1 year ago

Thanks for running the results! They look great. Feel free to click resolve conversation as you resolve the PR comments and let me know when it's ready for another review.

Meanwhile, you might find the following tool handy

pip install openrlbenchmark
python -m openrlbenchmark.rlops \
    --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' 'ppo_continuous_action_8M?tag=v1.0.0-13-gcbd83f6'  \
    --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/avg_episodic_return' 'ppo_continuous_action_envpool_xla_jax_scan?tag=v1.0.0-jax-ca-be3113b'   \
    --env-ids Ant-v4 InvertedDoublePendulum-v4 Reacher-v4  Hopper-v4 HalfCheetah-v4 Swimmer-v4  Humanoid-v4 InvertedPendulum-v4 Walker2d-v4 \
    --check-empty-runs False \
    --ncols 4 \
    --ncols-legend 3 \
    --output-filename ppo_continuous_action_envpool_xla_jax_scan \
    --scan-history --report

Which generates

(charts: ppo_continuous_action_envpool_xla_jax_scan and ppo_continuous_action_envpool_xla_jax_scan-time)

and the wandb report:

https://wandb.ai/costa-huang/cleanRL/reports/Regression-Report-ppo_continuous_action_envpool_xla_jax_scan--VmlldzozNDgzMzM4

Couple of notes:

  • Would you mind running the experiments for HumanoidStandup-v4, Pusher-v4 as well?
  • FWIW, it is possible to get high scores in Humanoid-v4 as well.

See

python -m openrlbenchmark.rlops \
    --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' 'ppo_continuous_action_8M?tag=v1.0.0-13-gcbd83f6'  \
    --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/avg_episodic_return' 'ppo_continuous_action_envpool_xla_jax_scan?tag=v1.0.0-jax-ca-be3113b'   \
    --filters '?we=openrlbenchmark&wpn=envpool-cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' 'ppo_continuous_action_envpool' \
    --env-ids HalfCheetah-v4 Walker2d-v4 Hopper-v4 InvertedPendulum-v4 Humanoid-v4  \
    --check-empty-runs False \
    --ncols 4 \
    --ncols-legend 3 \
    --output-filename ppo_continuous_action_envpool_xla_jax_scan \
    --scan-history

(chart: ppo_continuous_action_envpool_xla_jax_scan)

The hyperparams used are in https://github.com/vwxyzjn/envpool-cleanrl/blob/880552a168c08af334b5e5d8868bfbe5ea881445/ppo_continuous_action_envpool.py#L40-L75, with no obs or reward normalization.

quangr commented 1 year ago


Thanks for updating me. I'll be ready for the code review once the documentation is finished. I'm also happy to run the experiments you suggested.

quangr commented 1 year ago

I have documented the questions you brought up and am now ready for the code review. I would be happy to hear your feedback.