google-deepmind / mctx

Monte Carlo tree search in JAX
Apache License 2.0

AlphaZero subtree persistence #86

Closed: lowrollr closed this pull request 7 months ago

lowrollr commented 8 months ago

As requested in https://github.com/google-deepmind/mctx/issues/51, this PR adds the ability to pass a Tree to muzero_policy and gumbel_muzero_policy, so that MCTS can continue from a pre-initialized tree.

The main use case is AlphaZero, where the environment dynamics are known rather than learned, so the work done in a previous MCTS call remains valid and is worth reusing.

I introduce a new public API function, get_subtree, which extracts the subtree rooted at a given child of the root (specified by its index). AlphaZero-style implementations can use it to extract the subtree corresponding to the action that was taken.
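A rough, unbatched sketch of what subtree extraction involves, assuming a flat-array layout similar to mctx's Tree (root at index 0, -1 marking unexpanded children and missing parents). FlatTree, empty_tree and get_subtree_sketch are hypothetical names for illustration, not the code in this PR. Only a few node fields are copied, and the sketch uses plain Python lists with a deque-based breadth-first search; a jit-compatible JAX version would need fixed-size loops (for example jax.lax.while_loop) and a batch dimension instead.

```python
from collections import deque
from dataclasses import dataclass
from typing import Dict, List

UNVISITED = -1  # assumed marker for an unexpanded child slot
NO_PARENT = -1  # assumed marker for "no parent" (root or unused slot)


@dataclass
class FlatTree:
    node_visits: List[int]
    node_values: List[float]
    parents: List[int]
    action_from_parent: List[int]
    children_index: List[List[int]]  # [num_nodes][num_actions]


def empty_tree(num_nodes: int, num_actions: int) -> FlatTree:
    """A tree with every slot unused and every child unexpanded."""
    return FlatTree(
        node_visits=[0] * num_nodes,
        node_values=[0.0] * num_nodes,
        parents=[NO_PARENT] * num_nodes,
        action_from_parent=[NO_PARENT] * num_nodes,
        children_index=[[UNVISITED] * num_actions for _ in range(num_nodes)],
    )


def get_subtree_sketch(tree: FlatTree, action: int) -> FlatTree:
    """Copy the subtree under the root's child `action` into a fresh tree of the
    same capacity, renumbering the kept nodes 0, 1, 2, ... in BFS order."""
    num_nodes, num_actions = len(tree.node_visits), len(tree.children_index[0])
    out = empty_tree(num_nodes, num_actions)
    new_root = tree.children_index[0][action]
    if new_root == UNVISITED:
        return out  # the taken action was never expanded, so nothing is kept
    remap: Dict[int, int] = {new_root: 0}  # old index -> new index
    queue = deque([new_root])
    while queue:
        old = queue.popleft()
        new = remap[old]
        out.node_visits[new] = tree.node_visits[old]
        out.node_values[new] = tree.node_values[old]
        if old != new_root:  # the new root keeps NO_PARENT
            out.parents[new] = remap[tree.parents[old]]
            out.action_from_parent[new] = tree.action_from_parent[old]
        for a, child in enumerate(tree.children_index[old]):
            if child != UNVISITED:
                remap[child] = len(remap)
                out.children_index[new][a] = remap[child]
                queue.append(child)
    return out


# Root 0 has children 1 (action 0) and 2 (action 1); node 1 -> 3, node 2 -> 4.
example = FlatTree(
    node_visits=[5, 2, 2, 1, 1],
    node_values=[0.0, 0.5, -0.2, 0.7, 0.1],
    parents=[NO_PARENT, 0, 0, 1, 2],
    action_from_parent=[NO_PARENT, 0, 1, 1, 0],
    children_index=[[1, 2], [UNVISITED, 3], [4, UNVISITED], [UNVISITED] * 2, [UNVISITED] * 2],
)
sub = get_subtree_sketch(example, action=0)
```

Packing the kept nodes at the front leaves the free slots contiguous at the end, so the next search can append new nodes after them. A complete version would also copy the remaining per-node and per-child arrays (for example embeddings and child statistics).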

I also include a utility function, reset_search_tree, which resets (zeroes out) the search tree. This is useful when an episode terminates and its search tree can be discarded.
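The reset only needs to touch the batch entries whose episode ended. A minimal sketch of that masked reset, assuming a batched tree stored as a dict of [batch][num_nodes] lists; reset_search_tree_sketch, reset_rows and the field names are illustrative stand-ins, not the PR's implementation. With JAX arrays the same idea is usually written as jnp.where(terminated[:, None], fill, field) for a [B, N] field.

```python
from typing import Dict, List, Sequence

NO_PARENT = -1  # assumed marker for unused slots / missing parents


def reset_rows(field: List[list], terminated: Sequence[bool], fill) -> List[list]:
    """Reset the rows of a batched [batch][num_nodes] field whose episode ended."""
    return [
        [fill] * len(row) if done else list(row)
        for row, done in zip(field, terminated)
    ]


def reset_search_tree_sketch(
    tree: Dict[str, List[list]], terminated: Sequence[bool]
) -> Dict[str, List[list]]:
    """Hypothetical masked reset: terminated batch entries get an empty tree,
    the others keep their (already pruned) tree unchanged."""
    empty_values = {"node_visits": 0, "node_values": 0.0, "parents": NO_PARENT}
    return {
        name: reset_rows(field, terminated, empty_values[name])
        for name, field in tree.items()
    }


batched_tree = {
    "node_visits": [[3, 2, 1], [4, 1, 1]],
    "node_values": [[0.2, 0.4, 0.1], [0.9, 0.3, -0.5]],
    "parents": [[-1, 0, 0], [-1, 0, 1]],
}
after_reset = reset_search_tree_sketch(batched_tree, terminated=[False, True])
```

Each field is reset to its own "empty" value (0 for counts and values, -1 for parent pointers) rather than blindly zeroed, so that an emptied tree is indistinguishable from a freshly allocated one.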

Using this feature in an AlphaZero implementation might look something like this (pseudo-code):

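# Continue the search from the persisted tree (tree=None starts a fresh one).
output = mctx.muzero_policy(..., tree=tree)
# Keep only the subtree below the action that was actually taken.
tree = mctx.get_subtree(output.search_tree, output.action)
terminated = env.step(output.action)
# Discard the search tree for episodes that just ended.
tree = mctx.reset_search_tree(tree, terminated)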

When no tree has been initialized, mctx.muzero_policy(..., tree=None) still works and instantiates a new search tree, as before.

I've also decoupled num_simulations from the capacity of the search tree. The capacity is now set by a new optional argument to muzero_policy and gumbel_muzero_policy, max_nodes; if max_nodes is not specified, the capacity defaults to num_simulations, just as before. This matters for AlphaZero because the number of occupied nodes in a persisted tree can grow or shrink from call to call, so extra capacity is useful.
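The capacity arithmetic behind max_nodes can be sketched as follows. This assumes that each simulation expands at most one new node, and that max_nodes counts non-root nodes, mirroring how (as I understand it) a fresh mctx search allocates num_simulations + 1 slots: the root plus one per simulation. resolve_max_nodes and search_fits are hypothetical helpers, not part of mctx or this PR.

```python
from typing import Optional


def resolve_max_nodes(num_simulations: int, max_nodes: Optional[int] = None) -> int:
    """Default described in the PR: without max_nodes, capacity is num_simulations."""
    return num_simulations if max_nodes is None else max_nodes


def search_fits(retained_nodes: int, num_simulations: int, max_nodes: int) -> bool:
    """Worst-case check: every simulation expands at most one new node, so the
    non-root nodes after the search are at most retained + num_simulations."""
    return retained_nodes + num_simulations <= max_nodes


fresh = search_fits(0, 64, resolve_max_nodes(64))            # True
reused_default = search_fits(37, 64, resolve_max_nodes(64))  # False: 101 > 64
reused_padded = search_fits(37, 64, resolve_max_nodes(64, max_nodes=128))  # True
```

With 37 nodes retained from the previous move and 64 simulations, the default capacity of 64 is too small; that is the situation the extra max_nodes headroom is meant to cover.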

I also included tests for get_subtree that run on each of the existing test pytrees. The tests run get_subtree on each of the root's children and compare the result against the source tree. If the test runtime is too long (~60 s total on my machine), I'd be happy to run them on only a subset of the child nodes.
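The comparison such a test performs can be sketched as a lockstep traversal: start from the chosen root child in the source tree and from index 0 in the extracted tree, follow matching actions, and require identical statistics and identical expanded/unexpanded children. subtree_matches and the dict-of-lists layout are assumptions for illustration, not the test code in this PR.

```python
from typing import Dict

UNVISITED = -1  # assumed marker for an unexpanded child


def subtree_matches(source: Dict[str, list], source_root: int, extracted: Dict[str, list]) -> bool:
    """Walk the source subtree (rooted at source_root) and the extracted tree
    (rooted at index 0) in lockstep, checking node statistics and child structure."""
    stack = [(source_root, 0)]
    while stack:
        s, e = stack.pop()
        for field in ("node_visits", "node_values"):
            if source[field][s] != extracted[field][e]:
                return False
        for s_child, e_child in zip(source["children_index"][s], extracted["children_index"][e]):
            if (s_child == UNVISITED) != (e_child == UNVISITED):
                return False  # one side expanded this action, the other did not
            if s_child != UNVISITED:
                stack.append((s_child, e_child))
    return True


source = {
    "node_visits": [5, 2, 2, 1, 1],
    "node_values": [0.0, 0.5, -0.2, 0.7, 0.1],
    "children_index": [[1, 2], [-1, 3], [4, -1], [-1, -1], [-1, -1]],
}
# Hand-written expected extractions, one per root child (action -> tree).
expected = {
    0: {"node_visits": [2, 1], "node_values": [0.5, 0.7], "children_index": [[-1, 1], [-1, -1]]},
    1: {"node_visits": [2, 1], "node_values": [-0.2, 0.1], "children_index": [[1, -1], [-1, -1]]},
}
results = {
    action: subtree_matches(source, source["children_index"][0][action], tree)
    for action, tree in expected.items()
}
```

Because the walk is keyed on actions rather than node indices, it does not depend on how get_subtree renumbers the kept nodes.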

Calls to the public API work as they did before; I did not introduce any new mandatory arguments. Happy to reorganize and rework any of these changes if the maintainers have suggestions.

lowrollr commented 8 months ago

I thought of one concern regarding the Tree property num_simulations. The number of simulations a Tree object supported used to be equal to its capacity, but with this PR that is no longer the case, so the property's name could be misleading (it is now tied only to capacity, i.e. the maximum number of simulations).

fidlej commented 7 months ago

Thanks for trying the get_subtree() and sending the PR. Sorry for my slow response.

I worry that subtree reuse is not compatible with the current gumbel_muzero_policy implementation. That policy assumes that the tree starts empty: to implement sequential halving, the action selection uses a simulation_index. https://github.com/google-deepmind/mctx/blob/d40d32e1a18fb73030762bac33819f95fff9787c/mctx/_src/action_selection.py#L145C3-L145C19
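To see why a pre-populated tree conflicts with sequential halving, here is a plain-Python re-creation of the schedule of considered visit counts, modelled on my reading of get_sequence_of_considered_visits in mctx's seq_halving module, together with the visit-count mask applied at the root. As I read action_selection.py, the root's simulation index is its total child visit count, and only actions whose visit count equals the scheduled value stay eligible. eligible_root_actions is a hypothetical simplification that ignores the Gumbel and Q-value scores used to choose among eligible actions.

```python
import math
from typing import List, Sequence, Tuple


def considered_visits_schedule(max_num_considered_actions: int, num_simulations: int) -> Tuple[int, ...]:
    """For each simulation index, the visit count an action needs to be selectable."""
    if max_num_considered_actions <= 1:
        return tuple(range(num_simulations))
    log2max = int(math.ceil(math.log2(max_num_considered_actions)))
    sequence: List[int] = []
    visits = [0] * max_num_considered_actions
    num_considered = max_num_considered_actions
    while len(sequence) < num_simulations:
        # Spread the budget evenly over the log2(m) phases and the surviving actions.
        num_extra_visits = max(1, int(num_simulations / (log2max * num_considered)))
        for _ in range(num_extra_visits):
            sequence.extend(visits[:num_considered])
            for i in range(num_considered):
                visits[i] += 1
        num_considered = max(2, num_considered // 2)  # halve the candidate set
    return tuple(sequence[:num_simulations])


def eligible_root_actions(root_child_visits: Sequence[int], schedule: Sequence[int]) -> List[int]:
    """Root actions whose visit count matches the schedule at the current index,
    where the index is read off the root's total child visit count."""
    simulation_index = sum(root_child_visits)
    if simulation_index >= len(schedule):
        return []  # leftover visits already exhaust the planned budget
    target = schedule[simulation_index]
    return [a for a, n in enumerate(root_child_visits) if n == target]


schedule = considered_visits_schedule(4, 8)                # (0, 0, 0, 0, 1, 1, 2, 2)
fresh = eligible_root_actions([0, 0, 0, 0], schedule)      # [0, 1, 2, 3]
reused = eligible_root_actions([3, 1, 0, 0], schedule)     # [1]
exhausted = eligible_root_actions([5, 3, 0, 0], schedule)  # []
```

With a fresh tree every considered action starts out eligible at index 0. If the root arrives with visits [3, 1, 0, 0] left over from the previous move, the search starts at index 4 and only action 1 matches the schedule, so the halving phases no longer line up with the visits actually present.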

lowrollr commented 7 months ago

I see -- I'm not aware of a good way to incorporate any existing visit counts into the sequential halving algorithm, especially given that they were generated by the interior action selection algorithm -- perhaps devising a way to do this would be a good research problem but is probably out of scope for this PR.

I will remove the option for subtree reuse from gumbel_muzero_policy and just allow it for muzero_policy. If you'd prefer, I could instead create a new policy alphazero_policy that allows for subtree reuse and is otherwise identical to muzero_policy and restore muzero_policy to the way it was before. I wanted to minimize changes to the public API but this could help disambiguate.

fidlej commented 7 months ago

Thanks for the comment. Are you sure that the implementation works correctly? I left some comments on the code, but I have not checked everything.

lowrollr commented 7 months ago

> Thanks for the comment. Are you sure that the implementation works correctly?

As far as I can tell, all subtrees of the provided test trees are reproduced accurately in the tests I wrote. I also tested the feature in the Connect 4 example notebook linked in the readme and had no issues.

I'd be happy to write some more granular test cases if you'd like.

> I left some comments on the code, but I have not checked everything.

I'm not able to see your comments yet.

fidlej commented 7 months ago

Thanks for the clarifications. You understand the code well.

Would it be OK to keep the functionality unmerged? If people want this alphazero-specific functionality, they can look at your repository.

lowrollr commented 7 months ago

You mention AlphaZero in the readme, so in my opinion subtree reuse should be supported.

I understand wanting to keep the codebase as lightweight and simple as possible. If you have specific concerns, constraints, or parts of the code you'd prefer be left unchanged I'd be happy to work around them to get this ok to merge.

fidlej commented 7 months ago

I want to ensure that mctx will work correctly and I currently do not have time to carefully review the proposed changes. Mctx will probably remain mostly frozen.

lowrollr commented 7 months ago

That is understandable. In that case, I can document the new functionality in my repository and provide a few examples. I would appreciate it if you added a link to my repo under 'Example Projects' in the mctx readme when I am done.

Thank you for taking a look at my code, I admire this repo a lot.

fidlej commented 7 months ago

Thank you. When you are ready with your repo, please ping me and I will add the link.

lowrollr commented 7 months ago

Here's the link: https://github.com/lowrollr/mctx-az

@fidlej