google-deepmind / mctx

Monte Carlo tree search in JAX

Changing which actions are invalid based on state #89

Closed · wcarvalho closed this 7 months ago

wcarvalho commented 7 months ago

Hello, thank you for this great library!

I have a setting where the set of available actions (actually options) changes dynamically. Right now, I am thinking that I can learn when options are available.

I see that the muzero policies (e.g. here) allow for setting invalid_actions, but this mask seems to be fixed for the whole search. If I want to change these functions so that the set of invalid actions is a function of the current state, do you have a good sense of where to do this?

Right now I'm thinking that this belongs in the expand step of MCTS (i.e. here). Do you think I'm on the right track?

If not, can you help me find a better solution?

Thank you!

lowrollr commented 7 months ago

(not a mctx maintainer)

If valid actions are a function of the environment state, and you have access to this state throughout the simulation (AlphaZero), I would just implement this as part of the recurrent_fn passed to muzero_policy. You can mask out the prior logits of invalid actions by setting them to a large negative number, and you can store the current environment state in the embedding field. A few of the examples linked in the readme show how to do this.
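As a rough sketch of that idea (not from the mctx docs; `env_step`, `network`, and `valid_action_mask` are hypothetical stand-ins for your environment and model), a recurrent_fn that recomputes the action mask at every simulated step could look like:

```python
import jax.numpy as jnp
import mctx

def recurrent_fn(params, rng_key, action, embedding):
  # AlphaZero-style: `embedding` carries the true environment state,
  # so each simulated step can recompute the action mask exactly.
  next_state, reward, done = env_step(embedding, action)       # hypothetical env step
  prior_logits, value = network(params, next_state)            # hypothetical policy/value net

  # Recompute validity in the *new* state and push invalid actions'
  # priors to a large negative number so the search avoids them.
  valid = valid_action_mask(next_state)  # hypothetical; bool, shape [B, num_actions]
  prior_logits = jnp.where(valid, prior_logits, -1e9)

  output = mctx.RecurrentFnOutput(
      reward=reward,
      discount=jnp.where(done, 0.0, 1.0),  # stop bootstrapping at terminal states
      prior_logits=prior_logits,
      value=value,
  )
  return output, next_state  # next_state becomes the child node's embedding
```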

If you're using a learned environment model (MuZero), it should learn the dynamics that determine which actions are invalid or available. Since non-root node states in MuZero are all approximate, there is no need to store or calculate invalid actions for them.

You can read more about the difference in Appendix A of the MuZero paper: https://arxiv.org/abs/1911.08265
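For the MuZero case, the root-level mask the question mentions can still be supplied directly to the policy. A minimal sketch (variable names and shapes are assumptions, not prescribed by mctx):

```python
import mctx

# Hypothetical batch of B states with A actions.
root = mctx.RootFnOutput(
    prior_logits=prior_logits,   # [B, A] from your networks
    value=value,                 # [B]
    embedding=root_embedding,    # learned latent state
)

policy_output = mctx.muzero_policy(
    params=params,
    rng_key=rng_key,
    root=root,
    recurrent_fn=recurrent_fn,
    num_simulations=64,
    invalid_actions=invalid_actions,  # [B, A]; ones mark invalid root actions
)
action = policy_output.action
```

This masks only the root node; deeper in the tree, action availability is left to the (learned or exact) dynamics inside recurrent_fn, matching the distinction above.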

wcarvalho commented 7 months ago

This makes sense! Thank you.