Mononofu / mcts_benchmarks

Benchmarks of MCTS implementations

Simple vectorization for batches of bounded-depth MCTSs #1

Open Philip-Bachman opened 4 years ago

Philip-Bachman commented 4 years ago

You can get a nice speed boost for pure Python MCTS by vectorizing the logic for running N tree searches in parallel. The basic idea is to store the data describing the N trees (e.g., parent and child pointers) in pre-allocated numpy/torch arrays and to run the current simulation for each of the N searches in lock step. Simulations can walk the trees in parallel using gather ops based on indices pulled from the arrays of child/parent pointers. Actions are selected via argmax over the appropriate axis of an array storing per-action scores for each tree node. You can handle the variable length of the current simulation across the N trees being grown in parallel by introducing a "null" node that acts as an absorbing state for the tree-walking procedure that finds which node the current simulation will expand in each tree.
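A minimal numpy sketch of the idea, under several assumptions not stated in the comment: node index 0 is the absorbing "null" node and index 1 is the root of every tree; statistics are stored per (node, action); selection uses a PUCT-style score; values are single-agent (no sign flipping or discounting). All names (`BatchedTrees`, `select`, `expand`, `backup`, `NULL`) are hypothetical.

```python
import numpy as np

NULL = 0  # assumed convention: node 0 is the absorbing "null" node in every tree
ROOT = 1  # assumed convention: node 1 is the root of every tree


class BatchedTrees:
    """N search trees stored in pre-allocated arrays (layout is an assumption)."""

    def __init__(self, num_trees, max_nodes, num_actions):
        self.n = num_trees
        # children[b, node, a]: child reached from `node` by action `a`, or NULL if unexpanded.
        self.children = np.full((num_trees, max_nodes, num_actions), NULL, dtype=np.int64)
        self.parent = np.full((num_trees, max_nodes), NULL, dtype=np.int64)
        self.visits = np.zeros((num_trees, max_nodes, num_actions))
        self.value_sum = np.zeros((num_trees, max_nodes, num_actions))
        self.prior = np.full((num_trees, max_nodes, num_actions), 1.0 / num_actions)
        self.num_nodes = np.full(num_trees, 2, dtype=np.int64)  # slots 0 and 1 are taken


def select(trees, max_depth, c_puct=1.25):
    """Walk all N trees in lock step until each walk falls off its tree."""
    idx = np.arange(trees.n)
    node = np.full(trees.n, ROOT, dtype=np.int64)
    leaf_node = np.full(trees.n, NULL, dtype=np.int64)
    leaf_action = np.zeros(trees.n, dtype=np.int64)
    path_nodes = np.full((trees.n, max_depth), NULL, dtype=np.int64)
    path_actions = np.zeros((trees.n, max_depth), dtype=np.int64)
    for d in range(max_depth):
        n_sa = trees.visits[idx, node]  # gather the current node's row in each tree: (N, A)
        q = trees.value_sum[idx, node] / np.maximum(n_sa, 1)
        u = c_puct * trees.prior[idx, node] * np.sqrt(n_sa.sum(-1, keepdims=True) + 1) / (1 + n_sa)
        action = np.argmax(q + u, axis=-1)  # per-tree action choice
        child = trees.children[idx, node, action]
        path_nodes[:, d] = node
        path_actions[:, d] = action
        # A still-active walk whose chosen child is unexpanded has found its leaf.
        found = (node != NULL) & (child == NULL)
        leaf_node[found] = node[found]
        leaf_action[found] = action[found]
        # Finished walks stay on NULL: the null node's children are all NULL.
        node = child
    return leaf_node, leaf_action, path_nodes, path_actions


def expand(trees, leaf_node, leaf_action):
    """Allocate one new node per tree whose walk found a leaf."""
    idx = np.arange(trees.n)
    ok = leaf_node != NULL
    new = trees.num_nodes.copy()
    trees.children[idx[ok], leaf_node[ok], leaf_action[ok]] = new[ok]
    trees.parent[idx[ok], new[ok]] = leaf_node[ok]
    trees.num_nodes[ok] += 1
    return np.where(ok, new, NULL)


def backup(trees, path_nodes, path_actions, value):
    """Add each tree's leaf value to every (node, action) on its path."""
    idx = np.arange(trees.n)
    for d in range(path_nodes.shape[1]):
        node, action = path_nodes[:, d], path_actions[:, d]
        mask = (node != NULL).astype(np.float64)  # skip padded NULL steps
        # Indices are unique per row (one per tree), so fancy-index += is safe here.
        trees.visits[idx, node, action] += mask
        trees.value_sum[idx, node, action] += mask * value


trees = BatchedTrees(num_trees=4, max_nodes=64, num_actions=3)
rng = np.random.default_rng(0)
for _ in range(20):
    leaf, act, path_n, path_a = select(trees, max_depth=32)
    expand(trees, leaf, act)
    value = rng.random(trees.n)  # stand-in for one batched model evaluation (hypothetical)
    backup(trees, path_n, path_a, value)
print(trees.num_nodes)  # one new node per tree per simulation
```

Because `max_depth` here exceeds the number of simulations, every walk is guaranteed to find a leaf. With a tighter depth bound, walks that reach `max_depth` without falling off the tree simply skip expansion. The random `value` stands in for whatever batched evaluation produces leaf values; in a MuZero-style setup that would be a single batched dynamics/prediction call over the N expanded nodes.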

Mononofu commented 4 years ago

Thanks for the suggestion!

I considered this option, but was a bit wary that it might make the core logic hard to read and modify. My intention was to investigate how performant a relatively pseudocode-like MCTS could be in Python, to make it easy to test modifications to the core algorithm.

Did you have good experiences with batched MCTS?

(In reply to Philip Bachman's comment of Mon, 10 Feb 2020, 20:38.)



Philip-Bachman commented 4 years ago

Yeah, batched/vectorized MCTS was helpful for a project I'm currently involved in. We're working to reproduce parts of MuZero, and our initial implementation was bottlenecked on the tree-walking part of the MCTS simulations rather than, as one would expect, on the dynamics model evaluations. Vectorizing the MCTS shifted the bottleneck to the dynamics model evaluations.

I agree that the core tree search logic gets a bit lost in the vectorized version. Still, with good comments, it makes a nice example of non-trivial vectorization.

We started with a pure Python port of the pseudocode you provided for the paper, so thanks for that.