vwxyzjn / cleanrl

High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG)
http://docs.cleanrl.dev

Add gymnasium support for DQN #370

Closed vcharraut closed 1 year ago

vcharraut commented 1 year ago

Description

This PR updates the DQN files to the latest version of gymnasium, replacing gym.
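
For readers unfamiliar with what a gym-to-gymnasium migration typically touches: Gymnasium's `env.step` returns five values, `(obs, reward, terminated, truncated, info)`, where older gym (before 0.26) returned four with a single `done` flag. The sketch below is not code from this PR. `to_legacy_step` and `td_target` are hypothetical helper names, used only to show how the two flags relate and why a DQN target would usually stop bootstrapping on `terminated` but not on a time-limit `truncated`.

```python
def to_legacy_step(step_result):
    """Collapse a Gymnasium-style 5-tuple into the older 4-tuple with one `done` flag."""
    obs, reward, terminated, truncated, info = step_result
    return obs, reward, terminated or truncated, info


def td_target(reward, next_q_max, terminated, gamma=0.99):
    """One-step DQN target: bootstrap from the next state unless the episode truly ended."""
    bootstrap = 0.0 if terminated else 1.0
    return reward + gamma * next_q_max * bootstrap


# A truncated (time-limit) step ends the episode for logging purposes...
legacy = to_legacy_step(("obs", 1.0, False, True, {}))
# ...but the target still bootstraps, because the state was not terminal.
target_truncated = td_target(reward=1.0, next_q_max=10.0, terminated=False)
target_terminal = td_target(reward=1.0, next_q_max=10.0, terminated=True)
```

Keeping the two flags separate until the target computation is the design point here: merging them early would treat a time limit as a genuine terminal state.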

Regression report

python -m openrlbenchmark.rlops \
    --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \
        'dqn_atari_jax?tag=rlops-pilot' \
        'dqn_atari_jax?tag=pr-370-atari-jax' \
    --env-ids Breakout-v5 BeamRider-v5 Pong-v5 \
    --check-empty-runs False \
    --ncols 5 \
    --ncols-legend 2 \
    --output-filename figures/0compare \
    --scan-history \
    --report
────────────────────────────────────────────────────────────────────────────────────── Runtime (m) (mean ± std) ──────────────────────────────────────────────────────────────────────────────────────
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Environment  ┃ openrlbenchmark/cleanrl/dqn_atari_jax ({'tag': ['rlops-pilot']}) ┃ openrlbenchmark/cleanrl/dqn_atari_jax ({'tag': ['pr-370-atari-jax']}) ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ Breakout-v5  │ 270.1473263136972                                                │ 538.7802477303775                                                     │
│ BeamRider-v5 │ 271.7741639644951                                                │ 538.6782197420808                                                     │
│ Pong-v5      │ 261.6593977599932                                                │ 522.4641281567034                                                     │
└──────────────┴──────────────────────────────────────────────────────────────────┴───────────────────────────────────────────────────────────────────────┘
──────────────────────────────────────────────────────────────────────────────────── Episodic Return (mean ± std) ────────────────────────────────────────────────────────────────────────────────────
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Environment  ┃ openrlbenchmark/cleanrl/dqn_atari_jax ({'tag': ['rlops-pilot']}) ┃ openrlbenchmark/cleanrl/dqn_atari_jax ({'tag': ['pr-370-atari-jax']}) ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ Breakout-v5  │ 365.77 ± 15.64                                                   │ 356.66 ± 5.64                                                         │
│ BeamRider-v5 │ 5888.53 ± 185.09                                                 │ 6058.41 ± 116.74                                                      │
│ Pong-v5      │ 20.39 ± 0.17                                                     │ 20.39 ± 0.02                                                          │
└──────────────┴──────────────────────────────────────────────────────────────────┴───────────────────────────────────────────────────────────────────────┘
──────────────────────────────────────────────────────────────────────────────────────── Runtime (m) Average ─────────────────────────────────────────────────────────────────────────────────────────
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Environment                                                           ┃ Average Runtime   ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ openrlbenchmark/cleanrl/dqn_atari_jax ({'tag': ['rlops-pilot']})      │ 267.8602960127285 │
│ openrlbenchmark/cleanrl/dqn_atari_jax ({'tag': ['pr-370-atari-jax']}) │ 533.3075318763872 │
└───────────────────────────────────────────────────────────────────────┴───────────────────┘

https://wandb.ai/costa-huang/cleanrl/reports/Regression-Report-dqn_atari_jax--Vmlldzo0MjQ5OTA2

python -m openrlbenchmark.rlops \
    --filters '?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return' \
        'dqn?tag=pr-370' \
        'dqn_jax?tag=pr-370-jax' \
        'dqn?tag=rlops-pilot' \
        'dqn_jax?tag=rlops-pilot' \
    --env-ids CartPole-v1 Acrobot-v1 MountainCar-v0 \
    --check-empty-runs False \
    --ncols 3 \
    --ncols-legend 2 \
    --output-filename figures/0compare \
    --scan-history \
    --report
────────────────────────────────────────────────────────────────────────────────────── Runtime (m) (mean ± std) ──────────────────────────────────────────────────────────────────────────────────────
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃                ┃ openrlbenchmark/cleanrl/dqn ({'tag':       ┃ openrlbenchmark/cleanrl/dqn_jax ({'tag':   ┃ openrlbenchmark/cleanrl/dqn ({'tag':       ┃ openrlbenchmark/cleanrl/dqn_jax ({'tag':   ┃
┃ Environment    ┃ ['pr-370']})                               ┃ ['pr-370-jax']})                           ┃ ['rlops-pilot']})                          ┃ ['rlops-pilot']})                          ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ CartPole-v1    │ 3.099431800075442                          │ 1.8901799905559769                         │ 2.229200170565302                          │ 2.0570977331846896                         │
│ Acrobot-v1     │ 4.185325574186605                          │ 3.3383588646594835                         │ 3.2403913728341207                         │ 3.005497937894226                          │
│ MountainCar-v0 │ 3.5431891388538053                         │ 2.2788801149391746                         │ 2.5699978012313105                         │ 2.3790336879432625                         │
└────────────────┴────────────────────────────────────────────┴────────────────────────────────────────────┴────────────────────────────────────────────┴────────────────────────────────────────────┘
──────────────────────────────────────────────────────────────────────────────────── Episodic Return (mean ± std) ────────────────────────────────────────────────────────────────────────────────────
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃                ┃ openrlbenchmark/cleanrl/dqn ({'tag':       ┃ openrlbenchmark/cleanrl/dqn_jax ({'tag':   ┃ openrlbenchmark/cleanrl/dqn ({'tag':       ┃ openrlbenchmark/cleanrl/dqn_jax ({'tag':   ┃
┃ Environment    ┃ ['pr-370']})                               ┃ ['pr-370-jax']})                           ┃ ['rlops-pilot']})                          ┃ ['rlops-pilot']})                          ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ CartPole-v1    │ 486.82 ± 8.32                              │ 324.99 ± 212.99                            │ 486.82 ± 8.32                              │ 499.26 ± 1.05                              │
│ Acrobot-v1     │ -90.20 ± 1.84                              │ -90.81 ± 1.94                              │ -90.20 ± 1.84                              │ -90.44 ± 0.99                              │
│ MountainCar-v0 │ -194.73 ± 7.30                             │ -191.72 ± 9.33                             │ -194.73 ± 7.30                             │ -169.26 ± 23.75                            │
└────────────────┴────────────────────────────────────────────┴────────────────────────────────────────────┴────────────────────────────────────────────┴────────────────────────────────────────────┘
──────────────────────────────────────────────────────────────────────────────────────── Runtime (m) Average ─────────────────────────────────────────────────────────────────────────────────────────
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┓
┃ Environment                                                ┃ Average Runtime    ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━┩
│ openrlbenchmark/cleanrl/dqn ({'tag': ['pr-370']})          │ 3.6093155043719505 │
│ openrlbenchmark/cleanrl/dqn_jax ({'tag': ['pr-370-jax']})  │ 2.502472990051545  │
│ openrlbenchmark/cleanrl/dqn ({'tag': ['rlops-pilot']})     │ 2.679863114876911  │
│ openrlbenchmark/cleanrl/dqn_jax ({'tag': ['rlops-pilot']}) │ 2.4805431196740595 │
└────────────────────────────────────────────────────────────┴────────────────────┘

https://wandb.ai/costa-huang/cleanrl/reports/Regression-Report-dqn_jax--Vmlldzo0MjUwMDM1
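
The `--filters` arguments in the commands above are written as URL-style query strings. As a rough illustration of how such a string breaks down, the sketch below parses two of them with Python's standard library. `parse_filter` is a hypothetical helper, not part of openrlbenchmark, and the reading of the keys (`we` as W&B entity, `wpn` as W&B project name, `ceik` as the env-id config key, `cen` as the experiment-name config key) is an assumption that should be checked against the openrlbenchmark documentation.

```python
from urllib.parse import parse_qs, urlsplit


def parse_filter(spec):
    """Split a filter like 'name?key=value&...' into (name, {key: value})."""
    parts = urlsplit(spec)
    # parse_qs returns a list per key; these filters use each key once.
    params = {key: values[0] for key, values in parse_qs(parts.query).items()}
    return parts.path, params


# Strings copied from the commands in the PR description.
common_name, common = parse_filter(
    "?we=openrlbenchmark&wpn=cleanrl&ceik=env_id&cen=exp_name&metric=charts/episodic_return"
)
run_name, run = parse_filter("dqn_atari_jax?tag=pr-370-atari-jax")

print(common_name, common)
print(run_name, run)
```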

pseudo-rnd-thoughts commented 1 year ago

The error is due to this requiring Stable-Baselines3 version 2.

vwxyzjn commented 1 year ago

No sign of regression, as shown in the PR description. Merging now.
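
One rough way to read the episodic-return table for `dqn_atari_jax` in the PR description is to ask whether the two means differ by more than the combined standard deviations. The sketch below applies that heuristic to the numbers reported there. It is only an illustration: `within_noise` is a hypothetical helper, and this is not the criterion openrlbenchmark or the maintainers used.

```python
# (mean, std) of episodic return, copied from the dqn_atari_jax table in the PR description.
baseline = {
    "Breakout-v5": (365.77, 15.64),
    "BeamRider-v5": (5888.53, 185.09),
    "Pong-v5": (20.39, 0.17),
}
candidate = {
    "Breakout-v5": (356.66, 5.64),
    "BeamRider-v5": (6058.41, 116.74),
    "Pong-v5": (20.39, 0.02),
}


def within_noise(base, cand):
    """Crude heuristic: the means differ by no more than the sum of the two stds."""
    (base_mean, base_std), (cand_mean, cand_std) = base, cand
    return abs(base_mean - cand_mean) <= base_std + cand_std


flags = {env: within_noise(baseline[env], candidate[env]) for env in baseline}
for env, ok in flags.items():
    print(f"{env}: {'within noise' if ok else 'possible regression'}")
```

With only a few seeds per configuration, a check like this can at most flag large differences; it says nothing about runtime.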

ronuchit commented 1 year ago

Hi @vwxyzjn @vcharraut, I'm wondering what part of this change forced us to add the following line:

assert args.num_envs == 1, "vectorized envs are not supported at the moment"

Vectorization was a useful feature earlier. Thank you!

vwxyzjn commented 1 year ago

@ronuchit I think this is because SB3's replay buffer doesn't support num_envs>1.

ronuchit commented 1 year ago

I believe it does, actually: https://github.com/DLR-RM/stable-baselines3/blame/master/stable_baselines3/common/buffers.py#L162

We would just need to pass in n_envs=args.num_envs when we instantiate the ReplayBuffer. Perhaps there are other issues at play here?
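
To make concrete what a replay buffer with an `n_envs` parameter has to do, here is a minimal pure-Python sketch. `ToyVecReplayBuffer` is a hypothetical stand-in and not SB3's `ReplayBuffer`. Dividing the capacity by `n_envs` so that the total number of stored transitions stays near `buffer_size` is an assumption of this sketch, not a statement about SB3's internals.

```python
import random


class ToyVecReplayBuffer:
    """Hypothetical toy buffer: each slot holds one transition per parallel env."""

    def __init__(self, buffer_size, n_envs=1):
        self.n_envs = n_envs
        # Assumption of this sketch: keep total capacity close to buffer_size.
        self.capacity = max(buffer_size // n_envs, 1)
        self.slots = [None] * self.capacity
        self.pos = 0
        self.full = False

    def add(self, obs, next_obs, actions, rewards, dones):
        """Store one vectorized step; every argument is a sequence of length n_envs."""
        batch = list(zip(obs, next_obs, actions, rewards, dones))
        assert len(batch) == self.n_envs, "expected one transition per env"
        self.slots[self.pos] = batch
        self.pos = (self.pos + 1) % self.capacity
        self.full = self.full or self.pos == 0

    def __len__(self):
        """Number of individual transitions currently stored."""
        return (self.capacity if self.full else self.pos) * self.n_envs

    def sample(self, batch_size, rng=random):
        """Draw transitions uniformly over filled slots and env indices."""
        filled = self.capacity if self.full else self.pos
        return [
            self.slots[rng.randrange(filled)][rng.randrange(self.n_envs)]
            for _ in range(batch_size)
        ]


buf = ToyVecReplayBuffer(buffer_size=8, n_envs=4)
buf.add(
    obs=[0, 1, 2, 3],
    next_obs=[1, 2, 3, 4],
    actions=[0, 1, 0, 1],
    rewards=[1.0] * 4,
    dones=[False] * 4,
)
batch = buf.sample(batch_size=2)
```

In the actual script, the equivalent change would be the one suggested above: passing `n_envs=args.num_envs` when the SB3 `ReplayBuffer` is constructed, and making sure each `add` call receives batched arrays.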

vwxyzjn commented 1 year ago

I see. That’s interesting. Would you be interested in making a PR that optionally supports num_envs>1?

ronuchit commented 1 year ago

sure, done: https://github.com/vwxyzjn/cleanrl/pull/395