google-deepmind / mctx

Monte Carlo tree search in JAX
Apache License 2.0

Issues with Stochastic MuZero #60

Closed: carlosgmartin closed this issue 1 year ago

carlosgmartin commented 1 year ago

I'm having issues with mctx.stochastic_muzero_policy. Here's an example:

```python
import jax
import mctx
from jax import numpy as jnp

num_actions = 4
num_chance_outcomes = 2

def decision_recurrent_fn(params, key, action, state):
    return (
        mctx.DecisionRecurrentFnOutput(
            chance_logits=jnp.full(num_chance_outcomes, 0.0),
            afterstate_value=jnp.array(0.0),
        ),
        state,
    )

def chance_recurrent_fn(params, key, action, afterstate):
    return (
        mctx.ChanceRecurrentFnOutput(
            action_logits=jnp.full(num_actions, 0.0),
            value=jnp.array(0.0),
            # reward=jnp.array(1.),
            # Intended: an extra reward of 100 for action 0.
            reward=1 + (action == 0) * 100,
            discount=jnp.array(0.0),
        ),
        afterstate,
    )

def root_fn(state):
    return mctx.RootFnOutput(
        prior_logits=jnp.full(num_actions, 0.0),
        value=jnp.array(0.0),
        embedding=state,
    )

def main():
    root = root_fn(jnp.full(4, 0.0))
    # Add a batch dimension of size 1 to every field of the root.
    root = jax.tree_util.tree_map(lambda x: x[None], root)

    key = jax.random.PRNGKey(0)

    output = mctx.stochastic_muzero_policy(
        params=jnp.full(20, 0.0),
        rng_key=key,
        root=root,
        # Batch over action and embedding; params and key are shared.
        decision_recurrent_fn=jax.vmap(decision_recurrent_fn, (None, None, 0, 0)),
        chance_recurrent_fn=jax.vmap(chance_recurrent_fn, (None, None, 0, 0)),
        num_simulations=1000,
        num_actions=num_actions,
        num_chance_outcomes=num_chance_outcomes,
    )
    assert (output.search_tree.children_rewards == 0).all()
    print(output.action_weights)  # [[0.007 0.451 0.063 0.479]]

if __name__ == "__main__":
    main()
```

The first issue is that the children_rewards are all 0, even though chance_recurrent_fn always returns a positive reward.

The second issue is that the final weight of the zeroth action (which receives an additional reward of 100) is not higher than the others, even with 1000 simulations.

Any idea what might be causing these issues?

fidlej commented 1 year ago

Thanks for sharing the minimal example. I can clear up one point of confusion: the action passed to chance_recurrent_fn(params, key, action, afterstate) is actually the chance outcome. To give different actions different rewards, modify decision_recurrent_fn to output a different afterstate for each action.

You can take inspiration from the bandit in the tests: https://github.com/deepmind/mctx/blob/bfb7316b96f9e5b04744e8872c1abba9b2dac6b9/mctx/_src/tests/policies_test.py#L42

I will improve the documentation for the chance_recurrent_fn. Sorry for the confusion.

carlosgmartin commented 1 year ago

@fidlej Thanks for your reply. Perhaps the argument could be renamed to outcome, for clarity?

carlosgmartin commented 1 year ago

@fidlej Any idea about the children_rewards issue?

fidlej commented 1 year ago

Note that output.search_tree contains only the actions relevant to the decision nodes. The masking is done here: https://github.com/deepmind/mctx/blob/bfb7316b96f9e5b04744e8872c1abba9b2dac6b9/mctx/_src/policies.py#L366

The zeros in children_rewards then make sense: the reward is zero for the children of the decision nodes.
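The reasoning can be illustrated with a toy array. This is a simplified sketch, not mctx code: it assumes, based on the masking step linked above, that the internal tree keeps num_actions + num_chance_outcomes child slots per node (action edges first, chance-outcome edges last) and that the returned search_tree keeps only the first num_actions slots. The reward values are made up for illustration.

```python
import jax.numpy as jnp

num_actions = 4
num_chance_outcomes = 2

# Hypothetical children_rewards for two nodes, each with A + C child slots.
# Row 0: a decision node. Its action edges lead to afterstates and carry
#        reward 0; its chance slots are unused.
# Row 1: a chance node (afterstate). Its chance-outcome edges carry the reward
#        returned by chance_recurrent_fn; its action slots are unused.
children_rewards = jnp.array([
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
])

# Keeping only the first num_actions slots discards every reward that
# chance_recurrent_fn produced, leaving only zeros.
decision_view = children_rewards[..., :num_actions]
chance_view = children_rewards[..., num_actions:]
print(decision_view)
```

Under these assumptions, the rewards returned by chance_recurrent_fn live only in the sliced-off chance-outcome columns.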