Starting a thread on this: @richa-verma has expressed interest in helping out with this PR. Welcome, Richa! I will try to put some information below to help you get started, and happy to help further with anything you need.
The main things we are looking for are 1) single file implementations (minimal lines of code), 2) documentation explaining notable implementation details, 3) benchmarking and matching the performance of reference implementations. Please check out our contribution guide for the usual process, and https://github.com/vwxyzjn/cleanrl/pull/331 is a good example of how new algorithms are contributed end-to-end.
With that said, let me share more detail on the current status of this PR.
model loading: As you know, Reincarnating RL relies on prior models for training. Luckily, we already have pre-trained models on huggingface with #292. See the docs for more detail, and the colab notebook has a good demo of how to load the models.
jax vs pytorch: we have both JAX- and PyTorch-trained models for DQN on Atari. Feel free to work with whichever you prefer.
qdagger: I have implemented `qdagger_dqn_atari_jax_impalacnn.py` (which uses JAX) as a proof-of-concept. Its rough flow: load the teacher model, populate the teacher's replay buffer, run an offline distillation phase, then continue training online. The core loss is sketched below.
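The heart of both phases is the QDagger objective: the usual TD loss plus a distillation term that pulls the student's softmax policy toward the teacher's. Here is a minimal sketch of that loss, assuming teacher and student share one `apply_fn` and using illustrative names (`distill_coef`, `temperature`) rather than the script's exact ones:

```python
import jax
import jax.numpy as jnp


def qdagger_loss(params, target_params, teacher_params, apply_fn,
                 obs, actions, rewards, next_obs, dones,
                 gamma=0.99, distill_coef=1.0, temperature=1.0):
    # 1-step TD loss against the student's own target network.
    q_next = apply_fn(target_params, next_obs)                 # (B, num_actions)
    td_target = rewards + gamma * (1.0 - dones) * q_next.max(axis=-1)
    q = apply_fn(params, obs)                                  # (B, num_actions)
    q_taken = jnp.take_along_axis(q, actions[:, None], axis=-1).squeeze(-1)
    td_loss = jnp.mean((q_taken - jax.lax.stop_gradient(td_target)) ** 2)

    # Distillation term: cross-entropy from the teacher's softmax policy
    # (over its Q-values) to the student's.
    teacher_probs = jax.nn.softmax(apply_fn(teacher_params, obs) / temperature)
    student_logp = jax.nn.log_softmax(q / temperature)
    distill_loss = -jnp.mean(jnp.sum(teacher_probs * student_logp, axis=-1))

    return td_loss + distill_coef * distill_loss
```

In the paper, `distill_coef` is additionally decayed ("weaned") as the student's performance approaches the teacher's; see the discussion of weaning below.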
Some further considerations & optimizations:
I know this is throwing a lot at you. Please let me know if you need further clarifications or pair programming :) Thanks for your interest in working on this again.
3. In step 5, my implementation directly substitutes the student's replay buffer with the teacher's. I'm not exactly sure this is correct... perhaps we should build the student's replay buffer from scratch instead.
In this part, my understanding is that the teacher buffer and the student buffer should be distinguished. Section 4.1 of the original paper introduces separate symbols for the two buffers, D_T and D_S, and the original paper's code does the same, as can be seen from https://github.com/google-research/reincarnating_rl/blob/a1d402f48a9f8658ca6aa0ddf416ab391745ff2c/reincarnating_rl/reincarnation_dqn_agent.py#LL147C1-L186C35 (`QDaggerDQNAgent` inherits from `ReincarnationDQNAgent`, and `ReincarnationDQNAgent` inherits from `dopamine.jax.agents.dqn.dqn_agent.JaxDQNAgent`). As you can see from the dopamine repository code, each agent creates a single buffer: https://github.com/google/dopamine/blob/81f695c1525f2774fbaa205cf19d60946b543bc9/dopamine/jax/agents/dqn/dqn_agent.py#L334
> In this part, my understanding is that the teacher buffer and the student buffer should be distinguished.
This is correct. I used the same buffer because the teacher's buffer was not saved with the Hugging Face model. We can instead populate the teacher's buffer ourselves, following "A.5 Additional ablations for QDagger".
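To make the two-buffer setup concrete, here is a minimal sketch using the stable-baselines3 `ReplayBuffer` that CleanRL's DQN scripts already use (the env and buffer size are placeholders):

```python
import gymnasium as gym
from stable_baselines3.common.buffers import ReplayBuffer

envs = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")])  # placeholder env
buffer_size = 100_000  # placeholder size

# D_T: the teacher's buffer. The Hugging Face checkpoint doesn't ship a buffer,
# so we fill this once, e.g. by rolling out the frozen teacher policy.
teacher_rb = ReplayBuffer(
    buffer_size,
    envs.single_observation_space,
    envs.single_action_space,
    device="cpu",
    optimize_memory_usage=True,
    handle_timeout_termination=False,
)

# D_S: the student's buffer, built from scratch during the online phase.
student_rb = ReplayBuffer(
    buffer_size,
    envs.single_observation_space,
    envs.single_action_space,
    device="cpu",
    optimize_memory_usage=True,
    handle_timeout_termination=False,
)
```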
Would you be interested in taking on this PR?
I would be glad to take on this PR. I plan to first fix step 5 by implementing a separate replay buffer for the student agent, and then refine the weaning of the student agent by comparing against the original paper's code.
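For reference, my reading of the paper's weaning: the distillation weight decays as the student's return approaches the teacher's, so QDagger gracefully degrades into plain DQN. A one-line sketch with made-up numbers:

```python
# Running averages of episodic returns (illustrative numbers).
teacher_avg_return = 300.0
student_avg_return = 120.0

# The distillation weight shrinks toward 0 as the student catches up.
distill_coef = max(1.0 - student_avg_return / teacher_avg_return, 0.0)  # -> 0.6
```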
I observed some strange bugs in the latest version of the original code. When I looked at the initial commit in git, it seemed a bit more correct. I suggest using the files `distillation_dqn_agent.py` and `persistent_dqn_agent.py` as a reference for implementing our code.
> TODO: question: do we need to use epsilon greedy here?
Yes, we need to use epsilon-greedy here. The reason is that `ReincarnationDQNAgent` sets its `epsilon_fn` to the `reincarnation_linearly_decaying_epsilon` function.
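For context, CleanRL's DQN scripts implement the plain linear decay like this (`reincarnation_linearly_decaying_epsilon` in the original repo additionally adapts the schedule to the reincarnation setup):

```python
def linear_schedule(start_e: float, end_e: float, duration: int, t: int) -> float:
    # Linearly anneal epsilon from start_e to end_e over `duration` steps.
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)


# e.g., anneal from 1.0 to 0.01 over the first 1M steps
epsilon = linear_schedule(1.0, 0.01, duration=1_000_000, t=250_000)  # -> 0.7525
```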
The results look really good! Great job @sdpkjc. I noticed the learning curves looked slightly different... Any ideas? Maybe it could be explained by the fact that the teacher model in `dqn_atari` scores 333.60 +/- 120.61 whereas `dqn_atari_jax` scores 291.10 +/- 116.43? Also, feel free to test out Pong and BeamRider.
You should also do a file diff to minimize the lines of code that differ between the two variants. E.g., the comments `# we assume we don't have access to the teacher's replay buffer` and `# see Fig. A.19 in Agarwal et al. 2022 for more detail` should be in both variants :)
I will be overhauling the two scripts soon, including minimizing their differences, tidying variable names, and confirming default parameter values. After that, I will run experiments in the Breakout, Pong, and BeamRider environments.
I also noticed that the performance of the two teacher models differs, and what's more important is that the gap in their `avg_episodic_return` is very large. Moreover, I found that the training losses of the two scripts are quite different, which may also be related to this.
I found that the JAX script didn't update the `target_network` during the offline phase! I fixed the bug and will be rerunning a set of experiments soon.
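For the record, here is the shape of the fix, sketched with toy pytrees (in the real script the parameters live in the Flax `TrainState` and the update sits inside the offline distillation loop):

```python
import optax

# Toy stand-ins for the Flax parameter pytrees.
params = {"w": 1.0}
target_params = {"w": 0.0}
tau = 1.0                       # hard update, as in DQN
target_network_frequency = 500  # illustrative value

for global_step in range(2_000):
    # ... one distillation/TD update on `params` would happen here ...

    # The fix: the offline phase must also sync the target network
    # periodically, exactly like the online phase does.
    if global_step % target_network_frequency == 0:
        target_params = optax.incremental_update(params, target_params, step_size=tau)
```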
The results look really good!!! The code looks good too. Feel free to start preparing the docs :)
Oh, I guess this is one last thing: could you add this to the docs as well? That way the user can see the comparison.
Btw, you don't have to generate plots like these (but now that you already have them, it's perfectly fine to leave them there). We used to generate these plots because we had to do it manually, but now we can just use the `openrlbenchmark` utility to generate them :)
Thanks! I have generated the qdagger vs. dqn plots using `openrlbenchmark` and added the comparison to our wandb report.
Description
https://github.com/google-research/reincarnating_rl
Preliminary result
Need more contributors on this.
Types of changes
Checklist:

- `pre-commit run --all-files` passes (required).
- Documentation updated and previewed via `mkdocs serve`.

If you are adding new algorithm variants or your change could result in performance difference, you may need to (re-)run tracked experiments. See https://github.com/vwxyzjn/cleanrl/pull/137 as an example PR.

- Tracked experiments run with the `--capture-video` flag toggled on (required).
- Additional documentation added and previewed via `mkdocs serve`.