vwxyzjn / cleanrl

High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG)
http://docs.cleanrl.dev

Reproduction of Muesli #350

Closed vwxyzjn closed 8 months ago

vwxyzjn commented 1 year ago

Problem Description

Muesli is a next-generation policy gradient algorithm from DeepMind that performs exceptionally well. Notably, it matches MuZero’s SOTA results on Atari without using deep search such as MCTS. Beyond its more robust objective, Muesli also learns a model and can handle off-policy data and value inaccuracies.

It would be incredibly useful to reproduce Muesli in CleanRL, making it more accessible and easier to use. One possible application is fine-tuning LLMs, which could help https://github.com/CarperAI/trlx (cc @LouisCastricato). I also think Muesli's ability to handle off-policy data makes it of special interest for human-in-the-loop RL (cc @cloderic, @saikrishna-1996): it is similar to PPO but deals more gracefully with off-policy data, such as the retroactive rewards enabled by cogment.

With that said, this issue describes a roadmap to reproduce Muesli and associated challenges and opportunities.

Reproduction Analysis

Muesli does look really impressive, and I like the paper a lot! However, I'd expect reproduction to be challenging for two main reasons. First, Muesli is not open-sourced, and there are no other independent reference implementations, so we're going to do this from scratch. Second, Muesli uses DeepMind's Podracer Architecture, which is arguably harder to reproduce.

Roadmap

I think the best way to reproduce it is to first build a semi-synchronous version of Muesli, similar to how OpenAI first reproduced A3C as the synchronous A2C. "Semi-synchronous" means using EnvPool to collect rollouts asynchronously while learning synchronously. This architecture has many benefits: it is easier to reason about, reasonably efficient, and applies to LLM tuning in a fairly straightforward way. We then test and iterate on Atari. The rough steps are as follows:

I will dive into a bit of detail in the next section.

Baseline

We first need to understand the Atari baseline. Muesli's Atari setup can be summarized as follows:

```python
import envpool

# env_id, num_envs, async_batch_size, and seed are defined by the training script
envs = envpool.make(
    env_id,
    env_type="gym",
    num_envs=num_envs,
    batch_size=async_batch_size,
    stack_num=4,  # Hessel et al. 2021, Muesli paper, Table 10
    img_height=96,  # Hessel et al. 2021, Muesli paper, Table 4
    img_width=96,  # Hessel et al. 2021, Muesli paper, Table 4
    episodic_life=False,  # Hessel et al. 2021, Muesli paper, Table 4
    repeat_action_probability=0.25,  # Hessel et al. 2021, Muesli paper, Table 4
    noop_max=0,  # Hessel et al. 2021, Muesli paper, Table 4
    full_action_space=True,  # Hessel et al. 2021, Muesli paper, Table 4
    max_episode_steps=int(108000 / 4),  # Hessel et al. 2021, Muesli paper, Table 4; divided by 4 because of the skipped frames
    reward_clip=True,
    seed=seed,
)
```
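As a quick sanity check of the frame/step bookkeeping, here is the arithmetic as a small Python sketch. It assumes the frame-skip of 4 implied by the "divide by 4" comment; `frames_to_steps` is a hypothetical helper, not an envpool function.

```python
# Frame accounting under a frame-skip of 4 (assumed, per the "divide by 4" comment).
FRAME_SKIP = 4  # each agent step repeats the chosen action for 4 emulator frames


def frames_to_steps(frames: int, frame_skip: int = FRAME_SKIP) -> int:
    """Convert emulator frames to agent steps (integer division)."""
    return frames // frame_skip


# 108,000 frames is 30 minutes of play at 60 fps.
max_episode_steps = frames_to_steps(108_000)
# Muesli's Atari budget is quoted in frames; the learner sees 4x fewer steps.
total_steps = frames_to_steps(200_000_000)

print(max_episode_steps)  # 27000
print(total_steps)        # 50000000
```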

The agent uses an IMPALA CNN with channels (64, 128, 128, 64) and an LSTM. Muesli and its PPO baseline reach ~562% and ~300% median HNS (human-normalized score), respectively, in 200M frames (50M steps).
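For reference, HNS is usually computed per game as (agent - random) / (human - random) and then aggregated with the median across games, so a value of 1.0 means human level and openrlbenchmark's "median hns: 1.5679..." below corresponds to roughly 157%. A minimal sketch; the game names and reference scores are made up for illustration and are not the official random/human tables.

```python
from statistics import median


def human_normalized_score(score: float, random_score: float, human_score: float) -> float:
    """HNS = (agent - random) / (human - random); 1.0 means human level."""
    return (score - random_score) / (human_score - random_score)


# Hypothetical reference scores, for illustration only (NOT the official tables).
reference = {
    "GameA": {"random": 0.0, "human": 100.0},
    "GameB": {"random": -20.0, "human": 20.0},
    "GameC": {"random": 10.0, "human": 60.0},
}
agent_scores = {"GameA": 250.0, "GameB": 0.0, "GameC": 160.0}

hns = {
    game: human_normalized_score(agent_scores[game], ref["random"], ref["human"])
    for game, ref in reference.items()
}
print(hns)  # {'GameA': 2.5, 'GameB': 0.5, 'GameC': 3.0}
print(f"median HNS: {median(hns.values()):.0%}")  # median HNS: 250%
```

The median is preferred over the mean because a few games with huge normalized scores (e.g. Asterix or VideoPinball) can dominate a mean.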

My current best setting to replicate Muesli's PPO is as follows (#338):

```diff
 envs = envpool.make(
     env_id,
     env_type="gym",
     num_envs=num_envs,
     batch_size=async_batch_size,
     stack_num=4, # Hessel et al. 2021, Muesli paper, Table 10
+    img_height=86, # Hessel et al. 2021, Muesli paper, Table 4
+    img_width=86, # Hessel et al. 2021, Muesli paper, Table 4
     episodic_life=False,  # Hessel et al. 2021, Muesli paper, Table 4
     repeat_action_probability=0.25,  # Hessel et al. 2021, Muesli paper, Table 4
+    noop_max=1,  # Hessel et al. 2021, Muesli paper, Table 4
     full_action_space=True,  # Hessel et al. 2021, Muesli paper, Table 4
     max_episode_steps=int(108000 / 4),  # Hessel et al. 2021, Muesli paper, Table 4, we divide by 4 because of the skipped frames
     reward_clip=True,
     seed=seed,
 )
```

#338 uses an 86x86 image size to be consistent with existing work such as IMPALA, and noop_max=1 because of https://github.com/sail-sg/envpool/issues/234. I also used the smaller IMPALA CNN (16, 32, 32) because my RTX 3060 Ti can only fit (16, 32, 32), and I did not bother to implement the LSTM yet. My reproduction in #338 yields:

```bash
pip install openrlbenchmark --upgrade
# expect the following command to run for hours
python -m openrlbenchmark.rlops \
    --filters '?we=openrlbenchmark&wpn=envpool-atari&ceik=env_id&cen=exp_name&metric=charts/avg_episodic_return' 'ppo_atari_envpool_xla_jax_truncation_machado' 'ppo_atari_envpool_async_jax_scan_impalanet_machado'  \
    --env-ids Alien-v5 Amidar-v5 Assault-v5 Asterix-v5 Asteroids-v5 Atlantis-v5 BankHeist-v5 BattleZone-v5 BeamRider-v5 Berzerk-v5 Bowling-v5 Boxing-v5 Breakout-v5 Centipede-v5 ChopperCommand-v5 CrazyClimber-v5 Defender-v5 DemonAttack-v5 DoubleDunk-v5 Enduro-v5 FishingDerby-v5 Freeway-v5 Frostbite-v5 Gopher-v5 Gravitar-v5 Hero-v5 IceHockey-v5 PrivateEye-v5 Qbert-v5 Riverraid-v5 RoadRunner-v5 Robotank-v5 Seaquest-v5 Skiing-v5 Solaris-v5 SpaceInvaders-v5 StarGunner-v5 Surround-v5 Tennis-v5 TimePilot-v5 Tutankham-v5 UpNDown-v5 Venture-v5 VideoPinball-v5 WizardOfWor-v5 YarsRevenge-v5 Zaxxon-v5 Jamesbond-v5 Kangaroo-v5 Krull-v5 KungFuMaster-v5 MontezumaRevenge-v5 MsPacman-v5 NameThisGame-v5 Phoenix-v5 Pitfall-v5 Pong-v5 \
    --check-empty-runs False \
    --ncols 5 \
    --ncols-legend 2 \
    --output-filename machado_50M_impala \
    --scan-history

python -m openrlbenchmark.hns --files machado_50M_impala.csv
```

Outputs:

```
openrlbenchmark/envpool-atari/ppo_atari_envpool_xla_jax_truncation_machado ({})
┣━━ median hns: 1.5679929625118418
┣━━ mean hns: 8.352308370550299
openrlbenchmark/envpool-atari/ppo_atari_envpool_async_jax_scan_impalanet_machado ({})
┣━━ median hns: 1.5167741935483872
┣━━ mean hns: 11.038219990985528
```

[Figure: machado_50M_impala]

| Environment | ppo_atari_envpool_xla_jax_truncation_machado | ppo_atari_envpool_async_jax_scan_impalanet_machado |
|---|---|---|
| Alien-v5 | 2626.45 ± 0.00 | 4354.95 ± 0.00 |
| Amidar-v5 | 1323.18 ± 0.00 | 1026.85 ± 0.00 |
| Assault-v5 | 3225.26 ± 0.00 | 5073.31 ± 0.00 |
| Asterix-v5 | 11081.75 ± 0.00 | 333898.08 ± 0.00 |
| Asteroids-v5 | 1788.17 ± 0.00 | 10818.26 ± 0.00 |
| Atlantis-v5 | 775537.50 ± 0.00 | 875300.00 ± 0.00 |
| BankHeist-v5 | 1172.79 ± 0.00 | 1167.93 ± 0.00 |
| BattleZone-v5 | 33153.75 ± 0.00 | 32018.12 ± 0.00 |
| BeamRider-v5 | 3161.69 ± 0.00 | 9482.59 ± 0.00 |
| Berzerk-v5 | 814.51 ± 0.00 | 633.49 ± 0.00 |
| Bowling-v5 | 50.89 ± 0.00 | 29.95 ± 0.00 |
| Boxing-v5 | 96.49 ± 0.00 | 99.59 ± 0.00 |
| Breakout-v5 | 373.85 ± 0.00 | 470.53 ± 0.00 |
| Centipede-v5 | 3082.35 ± 0.00 | 3854.42 ± 0.00 |
| ChopperCommand-v5 | 12176.00 ± 0.00 | 22109.44 ± 0.00 |
| CrazyClimber-v5 | 135736.62 ± 0.00 | 124531.36 ± 0.00 |
| Defender-v5 | 57146.62 ± 0.00 | 68088.55 ± 0.00 |
| DemonAttack-v5 | 12115.09 ± 0.00 | 63779.84 ± 56545.29 |
| DoubleDunk-v5 | -0.68 ± 0.00 | -0.23 ± 0.00 |
| Enduro-v5 | 1734.94 ± 0.00 | 2334.03 ± 0.00 |
| FishingDerby-v5 | 42.35 ± 0.00 | 39.52 ± 0.00 |
| Freeway-v5 | 33.63 ± 0.00 | 33.61 ± 0.00 |
| Frostbite-v5 | 269.59 ± 0.00 | 269.77 ± 0.00 |
| Gopher-v5 | 16318.98 ± 0.00 | 21028.99 ± 0.00 |
| Gravitar-v5 | 2695.44 ± 0.00 | 512.07 ± 0.00 |
| Hero-v5 | 33900.23 ± 0.00 | 13957.64 ± 0.00 |
| IceHockey-v5 | -4.38 ± 0.00 | 1.26 ± 0.00 |
| PrivateEye-v5 | 72.25 ± 0.00 | 0.00 ± 0.00 |
| Qbert-v5 | 22940.53 ± 0.00 | 18828.40 ± 0.00 |
| Riverraid-v5 | 11789.25 ± 0.00 | 21917.63 ± 0.00 |
| RoadRunner-v5 | 57660.25 ± 0.00 | 44616.17 ± 0.00 |
| Robotank-v5 | 27.58 ± 0.00 | 28.30 ± 0.00 |
| Seaquest-v5 | 1882.95 ± 0.00 | 955.04 ± 0.00 |
| Skiing-v5 | -29998.00 ± 0.00 | -12812.83 ± 0.00 |
| Solaris-v5 | 2067.62 ± 0.00 | 1913.79 ± 0.00 |
| SpaceInvaders-v5 | 2849.32 ± 0.00 | 34592.05 ± 0.00 |
| StarGunner-v5 | 34239.50 ± 0.00 | 116080.91 ± 0.00 |
| Surround-v5 | 6.12 ± 0.00 | 3.36 ± 0.00 |
| Tennis-v5 | -0.35 ± 0.00 | -0.29 ± 0.00 |
| TimePilot-v5 | 10997.38 ± 0.00 | 31976.45 ± 0.00 |
| Tutankham-v5 | 305.59 ± 0.00 | 226.72 ± 0.00 |
| UpNDown-v5 | 263615.69 ± 0.00 | 359771.62 ± 0.00 |
| Venture-v5 | 0.00 ± 0.00 | 0.00 ± 0.00 |
| VideoPinball-v5 | 412265.15 ± 0.00 | 433614.56 ± 0.00 |
| WizardOfWor-v5 | 11283.50 ± 0.00 | 7176.12 ± 0.00 |
| YarsRevenge-v5 | 97739.87 ± 0.00 | 97599.37 ± 0.00 |
| Zaxxon-v5 | 16688.62 ± 0.00 | 0.00 ± 0.00 |
| Jamesbond-v5 | 522.06 ± 0.00 | 643.03 ± 0.00 |
| Kangaroo-v5 | 14603.50 ± 0.00 | 14197.66 ± 0.00 |
| Krull-v5 | 9884.79 ± 0.00 | 8577.02 ± 0.00 |
| KungFuMaster-v5 | 31035.50 ± 0.00 | 34155.31 ± 0.00 |
| MontezumaRevenge-v5 | 0.00 ± 0.00 | 0.00 ± 0.00 |
| MsPacman-v5 | 4838.56 ± 0.00 | 4524.04 ± 0.00 |
| NameThisGame-v5 | 11958.65 ± 0.00 | 13315.72 ± 0.00 |
| Phoenix-v5 | 5685.30 ± 0.00 | 45747.97 ± 0.00 |
| Pitfall-v5 | 0.00 ± 0.00 | -15.12 ± 0.00 |
| Pong-v5 | 16.10 ± 0.00 | 20.19 ± 0.00 |

The performance does not match the PPO results reported in the Muesli paper (median HNS of ~157% and ~152% here versus ~300% reported). Here are some ideas to help us match the baseline:

Prototype

In any case, our PPO in #338 is a good starting point for implementing the semi-synchronous version of Muesli. I'd suggest cloning ppo_atari_envpool_async_jax_scan_impalanet_machado.py into a new file called muesli_atari_envpool_async_jax_scan_machado.py, implementing Muesli there, and iterating on experiments in Breakout-v5 to see if we can replicate the game score of 791. If that succeeds, we can run a more comprehensive benchmark and proceed with our contribution process.


shermansiu commented 1 year ago

Actually, there's an implementation of Muesli in https://github.com/YuriCat/MuesliJupyterExample, though I'm not sure how good it is.

And Muesli is just MPO + a learned latent dynamics/prediction network model (predicts $r$, $V$, and $\pi$), so we can refer to implementations of MPO (https://github.com/daisatojp/mpo, https://github.com/theogruner/rl_pro_telu) and MuZero, I guess.
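To make the "MPO" part more concrete: the Muesli paper's CMPO (clipped MPO) target reweights a prior policy by exp(clip(Â(s, a), -c, c)) and renormalizes, and the learned policy is regularized towards that target with KL(π_CMPO || π). A minimal single-state sketch in plain Python, assuming discrete actions, already-normalized advantages, and c = 1; the prior and advantages are hypothetical numbers.

```python
import math


def cmpo_target(prior_probs, advantages, clip=1.0):
    """Clipped-MPO target: prior(a) * exp(clip(adv(a), -c, c)), renormalized."""
    weights = [
        p * math.exp(max(-clip, min(clip, adv)))
        for p, adv in zip(prior_probs, advantages)
    ]
    z = sum(weights)  # normalizer over actions
    return [w / z for w in weights]


def kl_divergence(p, q):
    """KL(p || q) for discrete distributions; terms with p(a) = 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


prior = [0.25, 0.25, 0.25, 0.25]  # e.g. the policy of a target network
adv = [2.0, 0.5, -0.5, -3.0]      # hypothetical, already-normalized advantages
target = cmpo_target(prior, adv)
print([round(t, 3) for t in target])  # [0.509, 0.309, 0.114, 0.069]

# Regularizer that would pull the current policy (here: the prior) towards the target.
print(kl_divergence(target, prior))
```

Because of the clipping, the ratio π_CMPO(a)/π_prior(a) differs between any two actions by at most a factor of exp(2c), which keeps the target close to the prior even when advantage estimates are noisy.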

shermansiu commented 1 year ago

This means that improvements from EfficientZero could hypothetically transfer over to Muesli's learned world model, but this would need to be verified experimentally.

xrsrke commented 1 year ago

@vwxyzjn I have just started reading the paper, do you have a plan for when you want to complete it? I will try to finish it on time

vwxyzjn commented 1 year ago

@shermansiu, thanks for the reference! They look like valuable resources. That said, I would like to see more benchmark information. For example, https://github.com/YuriCat/MuesliJupyterExample has a proof-of-concept on Tic-tac-toe, but it might not necessarily work for Atari.

@xrsrke, thanks for your interest! There is no specific timeline at the moment, but I would encourage folks to post related updates here for transparency.

shermansiu commented 1 year ago

Anyway, I've started work on this.

vwxyzjn commented 1 year ago

Btw, I came across @hr0nix's https://github.com/hr0nix/dejax. It may come in handy for Muesli's replay buffer.

Howuhh commented 1 year ago

FYI: New work from DeepMind on online meta-RL also uses Muesli as its base RL algorithm, which is interesting: https://sites.google.com/view/adaptive-agent/

shermansiu commented 1 year ago

Interesting! Taking a look.

shermansiu commented 1 year ago

Also, interestingly enough, Torchbeast does not use an LSTM in the Impala network and was able to match or exceed the performance of Impala.

vwxyzjn commented 1 year ago

@Howuhh, thanks for the reference! @shermansiu, note that there are some caveats.

  1. The torchbeast paper's tfimpala is modified from deepmind/scalable_agent; however, tfimpala does not necessarily reproduce the performance reported in the IMPALA paper. Using Breakout as an example, the IMPALA paper reports a score of 640.43 for the shallow model, while tfimpala reports ~150. It's possible that DeepMind used a different codebase for its Atari experiments, or that a different Atari simulator caused the gap (e.g., DeepMind seems to have used xitari internally).
  2. To measure performance more quantitatively, we need to use human-normalized scores. moolib performs better than torchbeast, but when measured in human-normalized scores, it is a bit lower than expected. See https://github.com/facebookresearch/moolib/issues/30#issuecomment-1040376523 for more detail.
  3. Note that there are additional reproduction issues https://github.com/facebookresearch/torchbeast/issues/37 https://github.com/facebookresearch/torchbeast/issues/25 with monobeast, which is different from polybeast.
shermansiu commented 1 year ago

It turns out that dejax was insufficient for sampling sequences and I couldn't find an existing, alternate implementation. So I implemented my own version.

Using a naive replay buffer that doesn't organize transitions by environment makes sampling sequences difficult. It also makes it hard to implement model rollouts (length 5) over a sequence of length 30...

Surprisingly, implementing a replay buffer that supports envpool's asynchronous updates took longer than expected. The implementation is fully vectorized and mostly jitted, which I am proud of, but sadly it uses SRSWR (simple random sampling with replacement), which can increase the variance of the estimator. Still, I have something ready, with tests.
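A minimal sketch of organizing transitions by environment, assuming envpool's async mode hands back batches tagged with env ids. `PerEnvSequenceBuffer` is a hypothetical name, not the implementation discussed here; it uses Python deques rather than jitted JAX arrays, and samples sequences with replacement as described above.

```python
import random
from collections import deque


class PerEnvSequenceBuffer:
    """Hypothetical sketch: one ring buffer per env id, so transitions that
    arrive in asynchronous batches stay time-ordered within each env."""

    def __init__(self, num_envs: int, capacity_per_env: int):
        self.storage = [deque(maxlen=capacity_per_env) for _ in range(num_envs)]

    def add_batch(self, env_ids, transitions):
        # In async mode each batch may contain a different subset of env ids;
        # route every transition to its own env's ring buffer.
        for env_id, transition in zip(env_ids, transitions):
            self.storage[env_id].append(transition)

    def sample(self, batch_size: int, seq_len: int, rng: random.Random):
        """Sample `batch_size` contiguous sequences uniformly WITH replacement."""
        candidates = [i for i, buf in enumerate(self.storage) if len(buf) >= seq_len]
        if not candidates:
            raise ValueError("no environment has enough steps yet")
        batch = []
        for _ in range(batch_size):
            buf = self.storage[rng.choice(candidates)]
            start = rng.randrange(len(buf) - seq_len + 1)
            batch.append([buf[start + t] for t in range(seq_len)])
        return batch


# Simulate interleaved async arrivals: env 0 every step, env 1 every other step.
buf = PerEnvSequenceBuffer(num_envs=2, capacity_per_env=100)
for t in range(10):
    buf.add_batch([0], [("env0", t)])
    if t % 2 == 0:
        buf.add_batch([1], [("env1", t)])
seqs = buf.sample(batch_size=4, seq_len=3, rng=random.Random(0))
```

This sketch ignores episode boundaries; a real buffer would also store done flags so that a sequence, or a length-5 model unroll inside it, that crosses a reset can be masked. Picking the env first and the start index second is also not uniform over all windows when envs hold different amounts of data.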

hr0nix commented 1 year ago

@shermansiu What functionality are you currently missing in dejax? Perhaps it's something I can add relatively easily.

Do I understand correctly that you need to sample fixed-length chunks of the long trajectories stored in the replay buffer?

vwxyzjn commented 1 year ago

Glad to hear that @shermansiu! Feel free to put your current code in a PR — I might be able to help review or add suggestions.

Are you looking to implement an LSTM?

Also, thanks for the comment @hr0nix!

shermansiu commented 1 year ago

@vwxyzjn Yeah, sure! I think I'm almost done at this point: I'm trying to get something out later today.

@hr0nix, here are the following things I wanted but couldn't get out of dejax:

Here's my implementation of the replay buffer. And again, the PR should be coming out later today (fingers crossed!) https://gist.github.com/shermansiu/b492fddf4127f4214d57a647c0160b8f

shermansiu commented 1 year ago

@vwxyzjn It took longer than expected, but I opened the PR! I still need to debug and test some things, but at least the code is available. And yes, this version uses an LSTM in the representation network!

hr0nix commented 1 year ago

Hi @shermansiu, can you elaborate on it slightly?

Sampling sequences

dejax currently samples the objects stored in the replay buffer, which might themselves be trajectory chunks (that's how I've been using it). If this doesn't work for you, I assume you need to store states but sample segments of consecutive states? Could you store trajectories instead, sample a trajectory, and then sample a subsegment of that trajectory?
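The two-stage suggestion (sample a trajectory, then a subsegment) can be sketched in plain Python; `sample_windows` is a hypothetical name, not a dejax function. Weighting the first stage by the number of windows each trajectory contains makes every fixed-length window equally likely, whereas picking trajectories uniformly would oversample windows from short trajectories.

```python
import random


def sample_windows(trajectories, batch_size, seq_len, rng):
    """Sample fixed-length windows from variable-length stored trajectories.

    Stage 1 picks a trajectory, weighted by how many windows it contains;
    stage 2 picks a start index uniformly inside that trajectory.
    """
    num_windows = [max(0, len(traj) - seq_len + 1) for traj in trajectories]
    if sum(num_windows) == 0:
        raise ValueError("no trajectory is at least seq_len long")
    # Trajectories shorter than seq_len get weight 0 and are never picked.
    picks = rng.choices(range(len(trajectories)), weights=num_windows, k=batch_size)
    windows = []
    for idx in picks:
        start = rng.randrange(num_windows[idx])
        windows.append(trajectories[idx][start:start + seq_len])
    return windows


trajectories = [list(range(5)), list(range(100, 120)), list(range(200, 202))]
batch = sample_windows(trajectories, batch_size=8, seq_len=3, rng=random.Random(0))
```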

jitted operations

All replay buffer operations in dejax can be jitted and there are tests to ensure that it works.

Fully vectorized sampling

Sampling can be vectorized simply by using vmap.

shermansiu commented 1 year ago

@hr0nix

Sampling sequences

I tried storing trajectories, but they have different shapes, which dejax doesn't support (see utils.assert_tree_is_batch_of_tree, which is called by circular_buffer:push). This was the main problem I faced; the other points are mainly minor nitpicks. Adding padding efficiently is challenging with envpool's asynchronous mode, which is why I didn't use dejax. But I'm open to suggestions on how to use it to store vectors of different lengths!
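One common workaround for fixed-shape buffers is to cut each trajectory into fixed-length chunks, pad the last chunk, and keep a validity mask so padded steps are excluded from losses. A plain-Python sketch; `chunk_and_pad` and `masked_mean` are hypothetical helpers, not part of dejax.

```python
def chunk_and_pad(traj, chunk_len, pad_value=0):
    """Split a trajectory into fixed-length chunks; pad the last one and return masks."""
    chunks, masks = [], []
    for start in range(0, len(traj), chunk_len):
        piece = list(traj[start:start + chunk_len])
        valid = len(piece)
        chunks.append(piece + [pad_value] * (chunk_len - valid))
        masks.append([True] * valid + [False] * (chunk_len - valid))
    return chunks, masks


def masked_mean(values, mask):
    """Average only over valid (unpadded) positions; assumes at least one is valid."""
    kept = [v for v, m in zip(values, mask) if m]
    return sum(kept) / len(kept)


chunks, masks = chunk_and_pad(list(range(1, 8)), chunk_len=3)
# chunks == [[1, 2, 3], [4, 5, 6], [7, 0, 0]]
# masks[-1] == [True, False, False]
```

Every chunk now has the same shape, so it can be stored in a buffer that requires identical item shapes. The trade-off is that windows crossing a chunk boundary are lost unless chunks overlap.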

I'm sure I could use dejax.clustered if more features were added, and we could refactor the code to use it once that happens.

jitted operations

Good to know, thanks!

Fully-vectorized sampling

I meant operations expressed directly with native JAX array operations, as opposed to being wrapped in vmap or jax.lax.scan. Again, this is a nitpick and possibly premature optimization.

Sampling from the same data source, with two different buffer heads

Also, the replay buffer and the online queue in my implementation share the same data source, which might be hacky to support in a general-purpose buffer.
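To illustrate what two heads over one data source could look like: a single circular storage where an online head consumes each new item once, while a replay head samples uniformly from everything still stored. `SharedStorageBuffer` is a hypothetical plain-Python sketch, not the implementation from the PR.

```python
import random


class SharedStorageBuffer:
    """Hypothetical sketch: one circular storage read by two heads.

    - the online head consumes each newly written item once, oldest first;
    - the replay head samples uniformly (with replacement) from stored items.
    """

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.data = [None] * capacity
        self.write_pos = 0   # total number of items ever written
        self.online_pos = 0  # total number of items consumed by the online head

    def push(self, item):
        self.data[self.write_pos % self.capacity] = item
        self.write_pos += 1
        # If the writer laps the online head, skip the overwritten items.
        self.online_pos = max(self.online_pos, self.write_pos - self.capacity)

    def pop_online(self, n):
        """Return up to n items not yet seen by the online head, oldest first."""
        end = min(self.online_pos + n, self.write_pos)
        items = [self.data[i % self.capacity] for i in range(self.online_pos, end)]
        self.online_pos = end
        return items

    def sample_replay(self, n, rng):
        """Uniform sampling with replacement; raises if the buffer is empty."""
        size = min(self.write_pos, self.capacity)
        return [self.data[rng.randrange(size)] for _ in range(n)]


buf = SharedStorageBuffer(capacity=4)
for x in range(6):
    buf.push(x)                 # items 0 and 1 get overwritten
first = buf.pop_online(3)       # [2, 3, 4]
rest = buf.pop_online(10)       # [5]
replay = buf.sample_replay(5, random.Random(0))
```

Keeping a single storage avoids writing every transition twice; the cost is that the online head must handle being lapped by the writer, which a general-purpose buffer would have to expose as policy.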

nit: Buffer operations aren't methods of the buffer state itself, even though jnp ndarrays have convenient methods.

In my implementation, all of the functions are implemented as methods of the buffer. dejax's implementation could be changed to use this, but then the API would not be backwards compatible.

nit: The buffer methods in dejax have the suffix _fn.

This is inconsistent with the naming conventions used in the rest of the JAX ecosystem, where the functions are simply named init or update.
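For comparison, a method-style functional API (the state is immutable, operations are methods that return a new state, and the constructor is named `init`) might look like the sketch below. `RingState` is a hypothetical name and uses plain tuples to stay dependency-free; in JAX, a NamedTuple is already a pytree, so a version with array fields could be passed through `jax.jit`.

```python
from typing import NamedTuple, Tuple


class RingState(NamedTuple):
    """Immutable ring-buffer state; operations are methods returning a new state."""
    data: Tuple
    write_pos: int

    @classmethod
    def init(cls, capacity: int, fill=None) -> "RingState":
        return cls(data=(fill,) * capacity, write_pos=0)

    def push(self, item) -> "RingState":
        idx = self.write_pos % len(self.data)
        data = self.data[:idx] + (item,) + self.data[idx + 1:]
        return self._replace(data=data, write_pos=self.write_pos + 1)

    def size(self) -> int:
        return min(self.write_pos, len(self.data))


s0 = RingState.init(capacity=3)
s2 = s0.push("a").push("b")
s4 = s2.push("c").push("d")  # wraps around and overwrites "a"
print(s0.size(), s2.size(), s4.data)  # 0 2 ('d', 'b', 'c')
```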


TL;DR: The main reason I didn't use dejax is that trajectories have different shapes, and dejax was difficult to use with envpool's asynchronous API (that was one of the first things I tried). I'm open to more suggestions on how to make it work with dejax, as I'm not a fan of reinventing the wheel.

All in all, I'd say dejax is a pretty good package and shoutout to @hr0nix for his awesome work on it!

hr0nix commented 1 year ago

Thanks for clarifying! I'll think about how to address the issues you listed.

shermansiu commented 1 year ago

Well, I have an implementation that runs. I'm just not sure whether the returns (about 1-2) are normal, though.

shermansiu commented 1 year ago

At least the returns are steadily going up.

shermansiu commented 1 year ago

Figuring out why the r_loss and v_loss values are NaN.

shermansiu commented 1 year ago

Solved the NaN loss issues, as well as a few other loss-related bugs.

Now, the reward, value, and policy model losses remain relatively constant, and the CMPO regularization term goes up slightly. Only the policy gradient loss goes down. The returns are still about 1-2.

vwxyzjn commented 1 year ago

FWIW, I have been playing with a prototype of the Podracer architecture used in Muesli. It might come in handy: if this prototype is successful, we can port #354 to it.