google-deepmind / mctx

Monte Carlo tree search in JAX
Apache License 2.0

Basic MCTS example #54

Closed: Carbon225 closed this issue 1 year ago.

Carbon225 commented 1 year ago

Hi, I really appreciate your library. I will be using it for my thesis project and need to understand how it works.

For this purpose, I used it to implement classic MCTS with random rollouts in a Jupyter notebook, available at https://github.com/Carbon225/mctx-classic. I want it to be as informative as possible and to explain why every line is written the way it is, so that my teammates can understand it as well. Feel free to add this example to your README or ignore it.

Below I describe the one remaining aspect I don't feel I understand.

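Consider this definition of `recurrent_fn`:

```python
import jax.numpy as jnp
import mctx

# env_step, my_policy_function and my_value_function are defined elsewhere in the notebook (not part of mctx).
def recurrent_fn(params, rng_key, action, embedding):
    env = embedding  # the embedding is the environment state itself
    env, reward, done = env_step(env, action)
    recurrent_fn_output = mctx.RecurrentFnOutput(
        reward=reward,
        # 0 makes a terminal state absorbing; -1 switches to the other player's perspective
        discount=jnp.where(done, 0, -1).astype(jnp.float32),
        prior_logits=my_policy_function(env),
        # a terminal state has no future value
        value=jnp.where(done, 0, my_value_function(env, rng_key)).astype(jnp.float32),
    )
    return recurrent_fn_output, env
```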

I have read that a terminal node is treated as absorbing. My understanding is that this means all rewards and values below this node should be 0, which should be guaranteed by setting `discount` to 0 in the terminal node's `RecurrentFnOutput`. In other mctx examples, however, I have seen people also set `value` to 0 at the terminal node. Which is correct? When should `reward`, `value`, or `discount` be set to 0?

I also believe the `reward` field is never actually used by the search and is only used for training outside the mctx library. Is that correct?

fidlej commented 1 year ago

Thanks for sharing your connect4 code. It is very nicely documented.

And thanks for asking for clarification. The `reward` field is used by the search itself: the rewards along a path are accumulated, with discounting, to estimate the value of each parent node, i.e., `parent_return = reward + discount * child_return`. See https://github.com/deepmind/mctx/blob/aa55375dff40b4ae128680bf6ba2d0874e54fbc3/mctx/_src/search.py#LL270C1-L270C1
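A minimal plain-Python sketch of this rule applied to one root-to-leaf path shows why `reward` matters to the search. It is an illustration only: `path_returns` is a hypothetical helper, the sign convention (reward seen from the perspective of the player who moved, `discount=-1` between plies of a two-player game) is an assumption matching the `recurrent_fn` above, and mctx's real backward pass additionally averages each return into the node's running value estimate using its visit count, which this sketch omits.

```python
from typing import List


def path_returns(rewards: List[float], discounts: List[float], leaf_value: float) -> List[float]:
    """Back up returns along one root-to-leaf path (hypothetical helper).

    rewards[i] and discounts[i] are what recurrent_fn returned when stepping
    from node i to node i + 1; leaf_value is the value estimate at the leaf.
    Returns the estimated return of every node on the path, root first.
    """
    returns = [0.0] * (len(rewards) + 1)
    returns[-1] = leaf_value
    # Walk from the leaf back up to the root.
    for i in reversed(range(len(rewards))):
        # parent_return = reward + discount * child_return
        returns[i] = rewards[i] + discounts[i] * returns[i + 1]
    return returns


# Two-player game (assumed convention): the root player moves, the opponent
# replies, then the root player makes a winning move into a terminal state.
rewards = [0.0, 0.0, 1.0]      # reward as seen by the player who moved
discounts = [-1.0, -1.0, 0.0]  # -1 flips perspective; 0 marks the terminal transition
print(path_returns(rewards, discounts, leaf_value=0.0))  # → [1.0, -1.0, 1.0, 0.0]
```

The only nonzero reward sits on the last transition, yet it determines the return of every node above it, with the sign alternating between the two players.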

There are multiple ways to implement an absorbing state that achieve the same effect. For example, if you do not use `discount=0`, the environment would need to remain in a state with `reward=0` and `value=end_game_value`. Your usage of `reward`, `value`, and `discount` makes sense.
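Whether the terminal value must also be zeroed can be checked with the same backup rule. In this minimal plain-Python sketch, `backup` is a hypothetical single-step helper, and the second encoding is only one possible reading of the alternative described above. It checks the first backup step only; keeping that encoding consistent deeper in the tree is what requires the environment to remain in the terminal state with `reward=0`.

```python
def backup(reward: float, discount: float, child_return: float) -> float:
    """One application of parent_return = reward + discount * child_return (hypothetical helper)."""
    return reward + discount * child_return


# Encoding used in the recurrent_fn above: the outcome is paid as the transition
# reward and discount=0 makes the terminal node absorbing. Whatever value the
# terminal node reports (0, a stale estimate, or anything backed up from below
# it) is multiplied by 0 before it reaches the parent.
absorbing = [backup(1.0, 0.0, terminal_value) for terminal_value in (0.0, 0.7, -3.0)]

# Alternative encoding (assumed reading): keep discount=-1, pay reward=0 on the
# final transition, and let the terminal node report the outcome as its value,
# seen by the player to move there, who has lost (-1).
carried_in_value = backup(0.0, -1.0, -1.0)

print(absorbing, carried_in_value)  # → [1.0, 1.0, 1.0] 1.0
```

With `discount=0`, the parent's return equals the terminal reward whatever the terminal node reports, so also zeroing `value` is redundant for the ancestors' estimates, though harmless.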

Carbon225 commented 1 year ago

Great! I see now.

Thank you so much! I had actually written a lot of questions in this text box, but now I finally understand everything 👍