google-deepmind / mctx

Monte Carlo tree search in JAX
Apache License 2.0

Stochastic MuZero issues invalid actions and outcomes #88

Closed carlosgmartin closed 7 months ago

carlosgmartin commented 7 months ago

Example:

import jax
import mctx
from jax import numpy as jnp, random

def main():
    n_actions = 7
    n_outcomes = 3
    batch_size = 1

    root = mctx.RootFnOutput(  # type: ignore
        prior_logits=jnp.zeros([batch_size, n_actions]),
        value=jnp.zeros(batch_size),
        embedding=jnp.zeros(batch_size),
    )

    def decision_recurrent_fn(params, key, action, state):
        jax.debug.print("action: {}", action)
        afterstate = state
        output = mctx.DecisionRecurrentFnOutput(  # type: ignore
            chance_logits=jnp.zeros([batch_size, n_outcomes]),
            afterstate_value=jnp.zeros(batch_size),
        )
        return output, afterstate

    def chance_recurrent_fn(params, key, outcome, afterstate):
        jax.debug.print("outcome: {}", outcome)
        state = afterstate
        output = mctx.ChanceRecurrentFnOutput(  # type: ignore
            action_logits=jnp.zeros([batch_size, n_actions]),
            value=jnp.zeros(batch_size),
            reward=jnp.zeros(batch_size),
            discount=jnp.ones(batch_size),
        )
        return output, state

    mctx.stochastic_muzero_policy(
        params={},
        rng_key=random.PRNGKey(0),
        root=root,
        decision_recurrent_fn=decision_recurrent_fn,
        chance_recurrent_fn=chance_recurrent_fn,
        num_simulations=20,
    )

if __name__ == "__main__":
    main()

Output:

action: [0]
action: [1]
outcome: [-6]
action: [7]
outcome: [0]
action: [4]
outcome: [-3]
action: [8]
outcome: [1]
action: [5]
outcome: [-2]
action: [6]
outcome: [-1]
action: [0]
outcome: [-7]
action: [3]
outcome: [-4]
action: [2]
outcome: [-5]
action: [9]
outcome: [2]
action: [5]
outcome: [-2]
action: [7]
outcome: [0]
action: [5]
outcome: [-2]
action: [7]
outcome: [0]
action: [7]
outcome: [0]
action: [7]
outcome: [0]
action: [7]
outcome: [0]
action: [7]
outcome: [0]
action: [2]
outcome: [-5]
action: [2]
outcome: [-5]

The actions range from 0 to 9 (7+3=10 values in total), even though there are only 7 actions (valid indices 0 to 6). The outcomes range from -7 (a negative integer!) to 2 (again 7+3=10 values in total), even though there are only 3 outcomes (valid indices 0 to 2).

This may have something to do with the math inside stochastic_recurrent_fn.
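One pattern in the log above can be checked directly: in every action/outcome pair, the outcome equals the action minus n_actions. The sketch below only re-enters the pairs printed above (the first, unpaired `action: [0]` line is left out) and checks that arithmetic. It is consistent with a single combined index in [0, n_actions + n_outcomes) being passed to both functions, but that reading is an inference from the log, not a statement about the mctx source.

```python
# Values copied from the reproduction script and its printed output.
n_actions = 7
n_outcomes = 3

# (action, outcome) pairs in the order they were printed.
pairs = [
    (1, -6), (7, 0), (4, -3), (8, 1), (5, -2),
    (6, -1), (0, -7), (3, -4), (2, -5), (9, 2),
    (5, -2), (7, 0), (5, -2), (7, 0), (7, 0),
    (7, 0), (7, 0), (7, 0), (2, -5), (2, -5),
]

# Every outcome is the paired action shifted down by n_actions.
offset_holds = all(outcome == action - n_actions for action, outcome in pairs)

# In each pair, exactly one of the two values lies in its valid range.
exactly_one_valid = all(
    (0 <= action < n_actions) != (0 <= outcome < n_outcomes)
    for action, outcome in pairs
)

print(offset_holds, exactly_one_valid)
```

In each pair, then, one value is a legitimate index and the other is the same index seen through the wrong range.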

Version information:

$ python3 --version
Python 3.11.6
$ python3 -c "import mctx; print(mctx.__version__)"
0.0.5
$ python3 -c "import jax; print(jax.__version__)"
0.4.23
$ python3 -c "import jaxlib; print(jaxlib.__version__)"
0.4.23

evanatyourservice commented 7 months ago

I asked about this here. Because the decision and chance fns alternate, the output of whichever one does not apply to the current node is discarded, so the invalid values you are seeing are never used.
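A minimal sketch of the "compute both, keep one" pattern described above. It is an illustration under stated assumptions, not mctx's actual code: `decision_fn`, `chance_fn`, and `combined_step` are hypothetical stand-ins, and the flag `is_decision_node` is assumed to be available per batch element.

```python
from jax import numpy as jnp

n_actions, n_outcomes = 7, 3

def decision_fn(action):
    # Hypothetical stand-in for a decision_recurrent_fn; returns a constant.
    return jnp.full(action.shape, 10.0)

def chance_fn(outcome):
    # Hypothetical stand-in for a chance_recurrent_fn; returns a constant.
    return jnp.full(outcome.shape, 20.0)

def combined_step(index, is_decision_node):
    # `index` lives in one combined range [0, n_actions + n_outcomes).
    # Both functions run on every element, so each one also sees indices
    # that are only meaningful for the other.
    decision_out = decision_fn(index)
    chance_out = chance_fn(index - n_actions)
    # Keep the applicable output per element; the other one is thrown away.
    return jnp.where(is_decision_node, decision_out, chance_out)

index = jnp.array([2, 8])
is_decision_node = jnp.array([True, False])

chance_inputs = index - n_actions  # what chance_fn receives: [-5, 1]
result = combined_step(index, is_decision_node)
print(chance_inputs, result)
```

Here the first element is a valid action (2) but reaches `chance_fn` as -5, and the second is a valid outcome (8 - 7 = 1) but reaches `decision_fn` as 8. The selected result is still correct for both.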

carlosgmartin commented 7 months ago

@evanatyourservice Hmm, doesn't that mean computation is being wasted on the invalid actions and outcomes?

evanatyourservice commented 7 months ago

Yeah, here's another issue I opened about that. There isn't a simple workaround right now because jaxlib hasn't implemented batched cond. Here's an issue related to that in jax (there are a few). It doesn't matter much for mctx on TPU or GPU, because the two branches can mostly be computed in parallel, but the issue is there nonetheless. It becomes a bigger problem when the two branches differ in how much computation they need.

Maybe the MCTS could be rewritten so that only one branch is computed; I haven't really thought it through. If you see a solution, please share. I might look into it, since my nets are sometimes large and the difference in speed and compute would be substantial.
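To illustrate the batched-cond point: the `jax.lax.cond` documentation notes that when `cond` is transformed with `vmap` over a batch of predicates, it is converted to a select, so both branches are evaluated. The sketch below uses two hypothetical branch functions (`cheap_branch`, `expensive_branch`) and only shows that the vmapped `cond` gives the same values as computing both branches and choosing with `jnp.where`.

```python
import jax
from jax import lax, numpy as jnp

def cheap_branch(x):
    # Hypothetical inexpensive branch.
    return x + 1.0

def expensive_branch(x):
    # Hypothetical costly branch (stands in for a large network).
    return jnp.sin(x) * 2.0

def step(pred, x):
    # Unbatched, lax.cond would execute only the chosen branch.
    return lax.cond(pred, cheap_branch, expensive_branch, x)

preds = jnp.array([True, False, True])
xs = jnp.array([0.0, 0.0, 3.0])

# Batched over predicates: one predicate per batch element.
batched = jax.vmap(step)(preds, xs)

# The same values, written explicitly as "compute both, then select".
reference = jnp.where(preds, cheap_branch(xs), expensive_branch(xs))

# The traced program can be inspected with print(jaxpr).
jaxpr = jax.make_jaxpr(jax.vmap(step))(preds, xs)
```

When the two branches cost about the same, this roughly doubles the work per step. When one branch is much more expensive, every element pays for it.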

fidlej commented 7 months ago

Thanks @evanatyourservice for explaining the issue. If you rewrite the MCTS to better support Stochastic MuZero, maybe keep that in your fork. I do not plan big changes to mctx.

carlosgmartin commented 7 months ago

@fidlej According to the JAX docs on out-of-bounds indexing:

Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) will not preserve the semantics of out of bounds indexing. Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of undefined behavior.

If out-of-bounds indices being passed into functions that expect valid actions or outcomes could potentially cause undefined behavior somewhere downstream, would it perhaps be a good idea to clip the indices passed to decision_recurrent_fn and chance_recurrent_fn to their valid ranges?
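A user-side version of the clipping idea could look like the sketch below. `clip_index_arg` and `toy_decision_fn` are hypothetical names introduced here for illustration. The wrapper assumes the `(params, key, index, state)` argument order used in the example at the top of this issue.

```python
from jax import numpy as jnp

def clip_index_arg(fn, num_valid):
    """Wraps fn(params, key, index, state) so index is clipped to [0, num_valid - 1]."""
    def wrapped(params, key, index, state):
        safe_index = jnp.clip(index, 0, num_valid - 1)
        return fn(params, key, safe_index, state)
    return wrapped

def toy_decision_fn(params, key, action, state):
    # Hypothetical recurrent function that indexes a per-action table.
    return params["table"][action], state

params = {"table": jnp.arange(7.0)}  # one entry per valid action, 0..6
safe_fn = clip_index_arg(toy_decision_fn, num_valid=7)

# 9 and -3 are out of range for 7 actions; 4 is valid.
out, _ = safe_fn(params, None, jnp.array([9, -3, 4]), jnp.zeros(3))
print(out)
```

The wrapped functions could then be passed as `decision_recurrent_fn` and `chance_recurrent_fn`. Clipping only makes the inputs well defined. It does not avoid the extra computation on the branch whose output is discarded.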

iamreadyi commented 7 months ago

> Thanks @evanatyourservice for explaining the issue. If you rewrite the MCTS to better support Stochastic MuZero, maybe keep that in your fork. I do not plan big changes to mctx.

Will you close the enhancement issue about Stochastic MuZero?

fidlej commented 7 months ago

Thanks for noticing. Done.