google-deepmind / open_spiel

OpenSpiel is a collection of environments and algorithms for research in general reinforcement learning and search/planning in games.
Apache License 2.0

MuZero implementation using OpenSpiel #595

Closed. uduse closed this issue 1 year ago.

uduse commented 3 years ago

First of all, I want to thank the developers for this awesome project! It's simple and clean, yet powerful. I really enjoyed playing with it.

I'm currently studying at the University of Alberta under the supervision of Prof. Martin Mueller. My primary research focus is learning/planning with a model, and general game playing. MuZero was a big step in this direction, and I would like to implement an open-source version of it as a foundation for my project. I'm aware of other open-source implementations but I would like to have a more efficient and robust implementation. (This is partially why I opened #592, since utilizing cloud computation power well is a must-have.) There's also #135, but unfortunately there was no follow-up.

My plan is to implement MuZero in a separate repo first using both the C++ and Python interfaces of OpenSpiel. I'll use C++ for search (MuZero flavored MCTS) and Python for everything else. I'll use JAX for neural-net-related things (I heard that's what you are using at DeepMind for MuZero). If the project works, we can integrate the project into OpenSpiel at some point in the future.

Here are some questions of mine:

  1. What are the caveats of writing a MuZero flavored MCTS in C++ similar to OpenSpiel's own MCTS?
  2. I presume since the interfaces of the games are quite unified, the algorithm should work on all the games without too much tweaking (work in terms of running without errors, but not necessarily performing well on the task). Is this correct?
  3. What do you do to display/visualize information/metrics? Right now I'm just reading console logs.

It would be great if the developers could help me with these questions, and any other tips regarding the project would be greatly appreciated! 😃

lanctot commented 3 years ago

Hi @uduse,

This sounds great :)

I'm aware of other open-source implementations but I would like to have a more efficient and robust implementation.

Do you know about muzero-general? They also recently added support for OpenSpiel games: https://github.com/werner-duvaud/muzero-general/commit/23a1f6910e97d78475ccd29576cdd107c5afefd2.

My plan is to implement MuZero in a separate repo first using both the C++ and Python interfaces of OpenSpiel. I'll use C++ for search (MuZero flavored MCTS) and Python for everything else. I'll use JAX for neural-net-related things (I heard that's what you are using at DeepMind for MuZero). If the project works, we can integrate the project into OpenSpiel at some point in the future.

Cool! Am I assuming correctly that you want to have the search/inference in C++ due to performance? (i.e. rather than a Python impl like muzero-general?) We don't have any examples of mixing C++ and JAX in the library yet (not even simple ones), but they'd be more than welcome!

1. What are the caveats of writing a MuZero flavored MCTS in C++ similar to OpenSpiel's own MCTS?

Sorry, I don't really understand... can you elaborate? Is there something that you're concerned about in particular?

2. I presume since the interfaces of the games are quite unified, the algorithm should work on all the games without too much tweaking (work in terms of running without errors, but not necessarily performing well on the task). Is this correct?

I don't immediately see why you wouldn't be able to keep it general (muzero-general has, AFAIK?)

3. What do you do to display/visualize information/metrics? Right now I'm just reading console logs.

@tewalds wrote an analysis tool described at the bottom of https://github.com/deepmind/open_spiel/blob/master/docs/alpha_zero.md. I think @christianjans has used it as well.

uduse commented 3 years ago

Do you know about muzero-general? They also recently added support for OpenSpiel games: werner-duvaud/muzero-general@23a1f69.

Yes, and my project is heavily inspired by that. My project will be different in two main ways: (1) I'll use JAX, and (2) I'll try to make the algorithm efficient and scalable.

Cool! Am I assuming correctly that you want to have the search/inference in C++ due to performance?

Yes, similar to the MCTS in KataGo, I will try to implement a multi-threaded C++ version for MuZero.

Sorry, I don't really understand... can you elaborate? Is there something that you're concerned about in particular?

I was wondering if there's anything off the top of your head that you would like to tell me. Other than that, my only concern is getting JAX to work with C++. If I can't easily get C++ to run inference on the network, I'll have to use a separate Python inference worker to handle batched inference requests from the C++ threads. That said, before doing any of that, I need to benchmark muzero-general to see where the Python self-play bottleneck is. Maybe it will be sufficiently fast if I implement a Python multi-threaded (GIL warning ⛔ ) MCTS.
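
To make that concrete, here's a rough sketch of what I mean by a separate inference worker (hypothetical names, not OpenSpiel or project code): search workers submit observations, and the worker groups whatever has queued up into a batch before calling the network once.

```python
# Sketch of a Python inference worker that batches requests from many search
# workers (e.g. C++ search threads calling into Python). Names are hypothetical.
import queue

import numpy as np


class InferenceWorker:

    def __init__(self, evaluate_fn, max_batch_size=32, wait_time=0.01):
        # evaluate_fn: a batched forward pass, e.g. a jax.jit-compiled apply
        # function mapping [B, ...] observations to (values, policies).
        self._evaluate_fn = evaluate_fn
        self._max_batch_size = max_batch_size
        self._wait_time = wait_time
        self._requests = queue.Queue()  # Holds (observation, reply_queue) pairs.

    def submit(self, observation):
        """Called from a search worker; blocks until its result is ready."""
        reply = queue.Queue(maxsize=1)
        self._requests.put((observation, reply))
        return reply.get()

    def run(self):
        """Runs forever on a dedicated thread, batching whatever has queued up."""
        while True:
            batch = [self._requests.get()]  # Block for at least one request.
            while len(batch) < self._max_batch_size:
                try:
                    batch.append(self._requests.get(timeout=self._wait_time))
                except queue.Empty:
                    break
            observations = np.stack([obs for obs, _ in batch])
            values, policies = self._evaluate_fn(observations)  # One batched call.
            for (_, reply), value, policy in zip(batch, values, policies):
                reply.put((value, policy))
```

The run() loop would live on its own thread or process, and the C++ side would only need a thin binding around submit().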

The analysis tool seems great! I will try to utilize it 👍

findmyway commented 3 years ago

FYI, a GSoC student will also work on MuZero in Julia and use the OpenSpiel wrapper this summer.

uduse commented 3 years ago

FYI, a GSoC student will also work on MuZero in Julia and use the OpenSpiel wrapper this summer.

Thanks for the information! Do you know this person's contact information? I cold-emailed a person with the same name but I likely got it wrong...

findmyway commented 3 years ago

Try michal.lukomski21 via Gmail. @jonathan-laurent is the primary mentor.

jhtschultz commented 3 years ago

Newly released MuZero implementation that might be of interest: https://github.com/google-research/google-research/tree/master/muzero

xhevahir commented 3 years ago

Here's the GSoC page for the Julia MuZero project: https://michelangelo21.github.io/gsoc/2021/08/23/gsoc-2021.html

NightMachinery commented 2 years ago

I am also interested in implementing MCTS and a semi-grad TD algorithm, to learn the material better. But figuring out the batching is a problem for me. Is there a JAX-based implementation of MCTS (or any kind of tree search that uses a neural network to estimate some feature of the nodes) I can see as a reference?

In general, is the idiomatic approach to use multi-threading and run several JAX JITed functions in parallel, or to create a suitable batch that is fed into a single JAX function?

lanctot commented 2 years ago

Hi @NightMachinary,

You can find MCTS that uses neural nets in our AlphaZero implementations: https://github.com/deepmind/open_spiel/blob/master/docs/alpha_zero.md

None are JAX-based. It would be great to have a JAX one too! Would make a nice contribution.

uduse commented 2 years ago

I am also interested in implementing MCTS and a semi-grad TD algorithm, to learn the material better. But figuring out the batching is a problem for me. Is there a JAX-based implementation of MCTS (or any kind of tree search that uses a neural network to estimate some feature of the nodes) I can see as a reference?

In general, is the idiomatic approach to use multi-threading and run several JAX JITed functions in parallel, or to create a suitable batch that is fed into a single JAX function?

My current implementation runs multiple copies of asynchronous MCTS (using asyncio) and batches their queries together for NN inference. Each individual MCTS still performs like a single-threaded MCTS, but combining many of them yields throughput similar to a multi-threaded MCTS.
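
Roughly, the pattern looks like this (a simplified sketch rather than my actual code; evaluate_fn and the search internals are placeholders): each search is a coroutine, and a shared batcher flushes all of the queries submitted in the same event-loop round with a single NN call.

```python
# Simplified sketch of batching NN queries from many concurrent asyncio-based
# MCTS instances. evaluate_fn and the search internals are placeholders.
import asyncio

import numpy as np


class AsyncBatcher:

    def __init__(self, evaluate_fn):
        self._evaluate_fn = evaluate_fn  # Batched forward pass, e.g. jax.jit'ed.
        self._pending = []               # Holds (observation, future) pairs.

    async def evaluate(self, observation):
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        if not self._pending:
            # Flush only after every currently runnable coroutine has had a
            # chance to add its own query to this batch.
            loop.call_soon(self._flush)
        self._pending.append((observation, future))
        return await future

    def _flush(self):
        batch, self._pending = self._pending, []
        observations = np.stack([obs for obs, _ in batch])
        values, policies = self._evaluate_fn(observations)  # One batched call.
        for (_, future), value, policy in zip(batch, values, policies):
            future.set_result((value, policy))


async def run_one_search(batcher, observation, num_simulations):
    # Each search is still sequential internally; throughput comes from running
    # many of these coroutines against the same batcher.
    for _ in range(num_simulations):
        value, policy = await batcher.evaluate(observation)
        # ... ordinary MCTS select / expand / backup would go here ...


async def run_many_searches(evaluate_fn, root_observations, num_simulations=100):
    batcher = AsyncBatcher(evaluate_fn)
    await asyncio.gather(*[
        run_one_search(batcher, obs, num_simulations) for obs in root_observations
    ])
```

The key property is that an individual search never sees the batching; it just awaits its own evaluation, which is why each one still behaves like a plain single-threaded MCTS.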

NightMachinery commented 2 years ago

@uduse

Can you link your implementation?

From what I understand, the key points of your design are that

1) MCTS only depends on the value/policy functions defined before it is run (which are not updated during the run), and its output is just the selection count for each node, which we can trivially sum for multiple runs.

2) JAX uses async dispatch, so using it with asyncio will result in an effectively parallel execution.

uduse commented 2 years ago

@NightMachinary

I just made my repo public; see here for my async MCTS.

It might not be what you want. My implementation focuses on increasing the throughput of multiple MCTSs but the latency of each individual MCTS is not reduced.

  1. MCTS only depends on the value/policy functions defined before it is run (which are not updated during the run), and its output is just the selection count for each node, which we can trivially sum for multiple runs.

Each MCTS outputs its own selection counts; multiple runs yield multiple selection counts, which means multiple data points.

  2. JAX uses async dispatch, so using it with asyncio will result in an effectively parallel execution.

I don't think JAX's async dispatch has anything to do with concurrent MCTS.

Batching is indeed your core problem, but it's not that related to how JAX works. No matter what NN library you use, your batching logic will live somewhere separate from the NN inference. This means the design of your MCTS only affects how you batch multiple NN inference requests together. Once you have those batches, you can use any NN library to do efficient inference. The NN in my project is JAX-based; however, I can swap it with a PyTorch-based one without changing anything in my MCTS.

I have an implementation of the batching layer here. It's kinda verbose right now and I'm planning to simplify it. Also see the test here for example usage.
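
To illustrate the swappable-backend point (hypothetical names, not the code linked above): the batching layer and the MCTS only need a callable that maps a batch of observations to (values, policies), so the NN library behind it is interchangeable.

```python
# Sketch of backend-agnostic evaluators: the batching layer / MCTS only ever
# sees a callable obs_batch -> (values, policies). Names are hypothetical.
import numpy as np


def make_jax_evaluator(apply_fn, params):
    # apply_fn would typically be a jax.jit-compiled network apply function.
    def evaluate(obs_batch):
        values, policies = apply_fn(params, obs_batch)
        return np.asarray(values), np.asarray(policies)
    return evaluate


def make_torch_evaluator(model):
    import torch

    def evaluate(obs_batch):
        with torch.no_grad():
            values, policies = model(torch.as_tensor(obs_batch, dtype=torch.float32))
        return values.cpu().numpy(), policies.cpu().numpy()
    return evaluate


# Either evaluator plugs into the same batching layer / search without changes.
```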

cmarlin commented 2 years ago

Hi. I wrote an MCTS using TensorFlow tensors only, with batching (and XLA): https://gitlab.com/dia.group/tf-muzero. So it's possible to have the GPU do the MCTS work (the whole MCTS in one GPU function call)! It could be great to rewrite it in JAX. Difficulties:

uduse commented 2 years ago

@cmarlin There's already a JAX implementation here. They addressed the issues you mentioned above using a couple of tricks. It would be very interesting to build fully JAX-compatible agents in the future.

cmarlin commented 2 years ago

@cmarlin There's already a JAX implementation here. They addressed the issues you mentioned above using a couple of tricks. It would be very interesting to build fully JAX-compatible agents in the future.

Thanks for your link, I didn't know it!

lanctot commented 2 years ago

Tagging @tuero based on a recent reddit comment.

Also, I guess this is rather outdated. Is anybody working on an implementation they plan to contribute? Just wondering if we should still keep this issue open for discussion?

@uduse @cmarlin @NightMachinery

tuero commented 2 years ago

Hi @lanctot, a few months ago I did a full C++ implementation, more so as a learning exercise (and for possible extensions). I don't really have any future plans for the codebase, as I've moved on to slightly different methods for my research.

While the code is for the most part complete, I would have to make some changes for it to fit nicely with this repo, which might not be trivial, but I am more than happy to do so. I'm a bit busy at the moment, so this wouldn't be something for the near future, but it can be made into a separate issue/PR once it's reasonably close.

lanctot commented 2 years ago

@tuero As always, your contributions are super welcome :) But I completely understand the time investment and trade-offs... especially as a grad student!

I'll leave this thread open for now, then, in case people still want to use it for discussion.

uduse commented 2 years ago

@lanctot I'm still working on the project. The project became more complicated than I envisioned, and it would be extremely difficult to port it as part of OpenSpiel. That said, I am still using OpenSpiel games to train it, and it would be a nice example of using OpenSpiel in a complicated DRL system.

lanctot commented 2 years ago

Ok, cool! That makes complete sense for an algorithm of this complexity, and especially since you are using it for research. Glad you could still use the game implementations. Please feel free to share any results if you want :)

lanctot commented 1 year ago

Closing due to inactivity. Please re-open if you want to continue the discussion.