vwxyzjn / cleanrl

High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG)
http://docs.cleanrl.dev

Qdagger: Reincarnate RL #344

Closed vwxyzjn closed 1 year ago

vwxyzjn commented 1 year ago

Description

https://github.com/google-research/reincarnating_rl

Preliminary result

image

Need more contributors on this.

If you are adding new algorithm variants or your change could result in performance difference, you may need to (re-)run tracked experiments. See https://github.com/vwxyzjn/cleanrl/pull/137 as an example PR.

vwxyzjn commented 1 year ago

Starting a thread on this: @richa-verma has expressed interest in helping out with this PR. Welcome, Richa! I will try to put some information below to help you get started, and happy to help further with anything you need.

The main things we are looking for are 1) single file implementations (minimal lines of code), 2) documentation explaining notable implementation details, 3) benchmarking and matching the performance of reference implementations. Please check out our contribution guide for the usual process, and https://github.com/vwxyzjn/cleanrl/pull/331 is a good example of how new algorithms are contributed end-to-end.

With that said, let me share more detail on the current status of this PR.

model loading: As you know, Reincarnate RL relies on prior models for training. Luckily, we already have pre-trained models on huggingface with #292. See the docs for more detail, and the colab notebook has a good demo on how to load the models.

jax vs pytorch: we have both JAX- and PyTorch-trained models on github for DQN and Atari. Feel free to work with whichever you prefer.

qdagger: I have implemented qdagger_dqn_atari_jax_impalacnn.py (which uses JAX) as a proof-of-concept. Its rough flow looks as follows:

  1. Load the teacher model from huggingface https://github.com/vwxyzjn/cleanrl/blob/87a30fb1e900d83a5e2e6e18fabab316515c142e/cleanrl/qdagger_dqn_atari_jax_impalacnn.py#L226-L233
  2. Evaluate the teacher model to get its average episodic return G_T, which will be useful later https://github.com/vwxyzjn/cleanrl/blob/87a30fb1e900d83a5e2e6e18fabab316515c142e/cleanrl/qdagger_dqn_atari_jax_impalacnn.py#L235-L246
  3. Then, since the pre-trained models do not contain the replay buffer data, we need to populate the replay buffer for the teacher. See A.5 Additional ablations for QDagger in the original paper for more detail https://github.com/vwxyzjn/cleanrl/blob/87a30fb1e900d83a5e2e6e18fabab316515c142e/cleanrl/qdagger_dqn_atari_jax_impalacnn.py#L248-L275
  4. The rest is to perform the offline phase (e.g., "reincarnate steps") https://github.com/vwxyzjn/cleanrl/blob/87a30fb1e900d83a5e2e6e18fabab316515c142e/cleanrl/qdagger_dqn_atari_jax_impalacnn.py#L307-L340
  5. and online phase https://github.com/vwxyzjn/cleanrl/blob/87a30fb1e900d83a5e2e6e18fabab316515c142e/cleanrl/qdagger_dqn_atari_jax_impalacnn.py#L342-L416
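
Both the offline and online phases train the student with a TD loss plus a distillation term toward the teacher. Below is a minimal sketch of that combined objective in plain Python, assuming the temperature-softmax formulation described in the Reincarnating RL paper. The function names, the `temperature` default, and the `distill_coeff` argument are illustrative assumptions and are not taken from the linked script.

```python
import math


def log_softmax(q_values, temperature=1.0):
    """Numerically stable log-softmax of Q-values scaled by a temperature."""
    scaled = [q / temperature for q in q_values]
    peak = max(scaled)
    log_norm = peak + math.log(sum(math.exp(s - peak) for s in scaled))
    return [s - log_norm for s in scaled]


def distillation_loss(teacher_q, student_q, temperature=1.0):
    """Cross-entropy between the teacher's and the student's softmax policies."""
    teacher_probs = [math.exp(lp) for lp in log_softmax(teacher_q, temperature)]
    student_log_probs = log_softmax(student_q, temperature)
    return -sum(p * lp for p, lp in zip(teacher_probs, student_log_probs))


def qdagger_loss(td_loss, teacher_q, student_q, distill_coeff, temperature=1.0):
    """TD loss plus a weighted distillation term (assumed form, single state)."""
    return td_loss + distill_coeff * distillation_loss(teacher_q, student_q, temperature)
```

In a real script the Q-values would be batched tensors and the TD loss would come from the usual DQN update; the sketch only shows how the two terms combine for one state.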

Some further considerations & optimizations:

  1. Atari preprocessing: we have used an older set of preprocessing techniques that doesn't use sticky actions, but the original paper does. I think the exact difference is highlighted here. We have three options to reproduce this work: 1) use the current Atari preprocessing and just reproduce the algorithm; 2) run another set of benchmarks with preprocessing aligned with the original paper, save the models to huggingface, then load these models for our reproduction; or 3) take the trained checkpoints from the original paper and figure out a way to load them, but this is likely extremely ad hoc, and I would not recommend it.
  2. Step number 3 could be sped up by leveraging multiple simulation environments.
  3. In step 5, my implementation is to directly substitute the student's replay buffer with the teacher's. Not exactly sure if this is correct... Not sure if we should build the student's replay buffer from scratch.
  4. In step 5, we could optionally add a threshold beyond which we no longer distill from the teacher policy.
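
For readers unfamiliar with the sticky-action difference mentioned in item 1: with sticky actions, the emulator sometimes repeats the previously executed action instead of the one the agent just requested. A minimal sketch in plain Python follows. The helper name `apply_sticky_action` is hypothetical, and the 0.25 repeat probability is the commonly used value, assumed here rather than stated in this thread.

```python
import random


def apply_sticky_action(previous_action, requested_action, rng, repeat_prob=0.25):
    """Return the action that is actually executed under sticky actions."""
    if rng.random() < repeat_prob:
        return previous_action  # the environment ignores the new request
    return requested_action


rng = random.Random(0)
previous = 0
executed = []
for requested in [1, 2, 3, 1, 0, 2]:
    previous = apply_sticky_action(previous, requested, rng)
    executed.append(previous)
```

Because the executed action can differ from the requested one, a teacher trained without sticky actions is not directly comparable to one trained with them, which is why the options in item 1 matter.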

I know this is throwing a lot at you. Please let me know if you need further clarifications or pair programming :) Thanks for your interest in working on this again.

sdpkjc commented 1 year ago

3. In step 5, my implementation is to directly substitute the student's replay buffer with the teacher's. Not exactly sure if this is correct... Not sure if we should build the student's replay buffer from scratch.

In this part, my understanding is that the teacher buffer and the student buffer should be distinguished. Section 4.1 of the original paper uses separate symbols for the two buffers, D_T and D_S. The original paper's code does the same thing, as can be seen in https://github.com/google-research/reincarnating_rl/blob/a1d402f48a9f8658ca6aa0ddf416ab391745ff2c/reincarnating_rl/reincarnation_dqn_agent.py#LL147C1-L186C35 (QDaggerDQNAgent inherits from ReincarnationDQNAgent).

ReincarnationDQNAgent inherits from dopamine.jax.agents.dqn.dqn_agent.JaxDQNAgent. As you can see from the dopamine repository code, the agent creates a single buffer. https://github.com/google/dopamine/blob/81f695c1525f2774fbaa205cf19d60946b543bc9/dopamine/jax/agents/dqn/dqn_agent.py#L334

vwxyzjn commented 1 year ago

In this part, my understanding is that the teacher buffer and the student buffer should be distinguished.

This is correct. I used the same buffer because the teacher's buffer was not saved with the Hugging Face model. We can instead populate the teacher's buffer ourselves, following "A.5 Additional ablations for QDagger".
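
A rough sketch of keeping the two buffers separate, in plain Python. It assumes the offline phase samples only from the teacher buffer (D_T) and the online phase samples from the student's own buffer (D_S). The `ReplayBuffer` class and the placeholder transitions are hypothetical stand-ins, not CleanRL or Dopamine code.

```python
import random
from collections import deque


class ReplayBuffer:
    """Tiny FIFO replay buffer used only for illustration."""

    def __init__(self, capacity):
        self.storage = deque(maxlen=capacity)

    def add(self, transition):
        self.storage.append(transition)

    def sample(self, batch_size, rng):
        return rng.sample(list(self.storage), batch_size)

    def __len__(self):
        return len(self.storage)


rng = random.Random(0)
teacher_buffer = ReplayBuffer(capacity=1000)  # D_T
student_buffer = ReplayBuffer(capacity=1000)  # D_S

# Populate D_T by rolling out the teacher (placeholder transitions here).
for step in range(100):
    teacher_buffer.add(("teacher", step))

# Offline phase: the student trains on batches drawn from D_T.
offline_batch = teacher_buffer.sample(8, rng)

# Online phase: the student acts in the environment, fills D_S, and samples from it.
for step in range(50):
    student_buffer.add(("student", step))
online_batch = student_buffer.sample(8, rng)
```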

Would you be interested in taking on this PR?

sdpkjc commented 1 year ago

I would be glad to take on this PR. 😄 I plan to refine step 5 first by giving the student agent its own independent buffer, and then compare against the original paper's code to get the weaning of the student agent right.
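
One possible way to wean the student, sketched in plain Python: shrink the distillation weight as the student's recent episodic return approaches the teacher's average return G_T, and drop it to zero once the student matches the teacher. This particular schedule and the name `distill_coefficient` are assumptions for illustration. The thread does not specify the rule, and the sketch only handles a positive teacher return.

```python
def distill_coefficient(student_return, teacher_return):
    """Weight on the distillation term, decaying as the student catches up."""
    if teacher_return <= 0:
        # The ratio below is not meaningful for non-positive teacher returns.
        raise ValueError("this sketch assumes a positive teacher return")
    return max(1.0 - student_return / teacher_return, 0.0)
```

For example, a student averaging 150 against a teacher at 300 would get a weight of 0.5, and any student at or above the teacher's return would get 0.0.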

sdpkjc commented 1 year ago

I observed some strange bugs in the latest version of the original code. The initial commit in the git history seems more correct. I suggest using the files distillation_dqn_agent.py and persistent_dqn_agent.py as references when implementing our code.

sdpkjc commented 1 year ago

TODO: question: do we need to use epsilon greedy here?

Yes, we need to use epsilon-greedy here. The reason is that ReincarnationDQNAgent sets its epsilon_fn to the reincarnation_linearly_decaying_epsilon function.
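
For reference, a minimal sketch of epsilon-greedy action selection with a plain linearly decaying epsilon, in Python. This is the generic linear schedule, not a reproduction of reincarnation_linearly_decaying_epsilon, whose exact behavior is not shown in this thread. The parameter names are illustrative.

```python
import random


def linear_schedule(start_e, end_e, duration, t):
    """Linearly interpolate epsilon from start_e to end_e over `duration` steps."""
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)


def epsilon_greedy(q_values, epsilon, rng):
    """Pick a random action with probability epsilon, else the greedy action."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])
```

A typical call would be `epsilon_greedy(q_values, linear_schedule(1.0, 0.01, 100_000, global_step), rng)`, with the start, end, and duration values chosen per experiment.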

sdpkjc commented 1 year ago

jax_step torch_step

vwxyzjn commented 1 year ago

The results look really good! Great job @sdpkjc. I noticed the learning curves look slightly different... Any ideas? Maybe it is because the teacher model in dqn_atari scores 333.60 +/- 120.61 whereas the one in dqn_atari_jax scores 291.10 +/- 116.43? Also, feel free to test out Pong and BeamRider.

vwxyzjn commented 1 year ago

You should also do a filediff to minimize the lines of different code. E.g., the comment

    # we assume we don't have access to the teacher's replay buffer
    # see Fig. A.19 in Agarwal et al. 2022 for more detail

should be in both variants :)

image
sdpkjc commented 1 year ago

I will be overhauling the two scripts soon, including minimizing their differences, tidying variable names, and confirming default parameter values. After that, I will run experiments in the Breakout, Pong, and BeamRider environments.

I also noticed that the performance of the two teacher models is different, and what's more important is that their avg_episodic_return gap is very large. Moreover, I found that the training loss of the two scripts is quite different, which may also be related to this.

image image
sdpkjc commented 1 year ago

I found that the JAX script didn't update the target_network during the offline phase! I fixed the bug and will be rerunning a set of experiments soon.
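
To illustrate the kind of bug described here, a small Python sketch of a training loop that copies the online parameters into the target network at a fixed interval. The same periodic update needs to run in the offline phase as well as the online one. The dictionary "parameters", the `version` counter, and the hard-copy update are simplifying assumptions, not the actual script.

```python
def train_phase(num_steps, target_network_frequency, q_params, target_params):
    """Run a toy training phase and return how many target updates happened."""
    target_updates = 0
    for step in range(1, num_steps + 1):
        q_params["version"] += 1  # stand-in for one gradient step
        if step % target_network_frequency == 0:
            target_params.update(q_params)  # hard copy of the online parameters
            target_updates += 1
    return target_updates


q_params = {"version": 0}
target_params = {"version": 0}
offline_updates = train_phase(10, 5, q_params, target_params)
```

If the update inside the loop is skipped during the offline phase, `target_params` stays at its initial values for that whole phase, so the TD targets are computed from an untrained network.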

sdpkjc commented 1 year ago

compare

vwxyzjn commented 1 year ago

The results look really good!!! The code looks good too. Feel free to start preparing the docs :)

vwxyzjn commented 1 year ago

jax_step torch_step

Oh I guess maybe this is one last thing. Could you add this to the docs as well? This way the user can see the comparison.

sdpkjc commented 1 year ago

👌

vwxyzjn commented 1 year ago

Btw you don't have to generate plots like these (but now that you have them already, it's perfectly fine to leave them there). We used to generate these plots because we had to do it manually, but now we can just use the openrlbenchmark utility to generate them :)

image
sdpkjc commented 1 year ago

Thanks, I have generated the plots of qdagger vs dqn using openrlbenchmark and added the comparison in our wandb report.

image image