google-deepmind / mctx

Monte Carlo tree search in JAX
Apache License 2.0

`muzero_policy` search vs `gumbel_muzero_policy` search performance #93

Closed: LeonEricsson closed this issue 2 months ago

LeonEricsson commented 2 months ago

I've seen others report that `muzero_policy` is slow, and I've run into this problem myself, so I wanted to add more information. I'm not expecting a fix, but it might help the authors better understand the problem.

`muzero_policy` is drastically slower than `gumbel_muzero_policy`, and the problem stems from `body_fun` in `search.py`. I can't pinpoint the exact line, but execution grinds almost to a halt inside this loop. When running on CPU, I see very strange behavior: the first N calls to `muzero_policy` execute lightning fast, and then everything comes to a complete stop at iteration N+1, taking minutes to finish a search with only 100 simulations. The batch size seems to affect when this stall occurs.

I wish the solution were as simple as switching to `gumbel_muzero_policy`, but I'm working with stochastic nodes, and we've observed significantly worse performance from Gumbel in this setting.

LeonEricsson commented 2 months ago

I've identified the root selection function as the culprit. Swapping MuZero's root selection function for Gumbel MuZero's root selection function yields a ~100x speedup in my case.