google-deepmind / mctx

Monte Carlo tree search in JAX
Apache License 2.0

Computing both decision and chance branches of recurrent function in Stochastic MuZero is slow #79

Closed: evanatyourservice closed this issue 10 months ago

evanatyourservice commented 10 months ago

Hello all! Thank you for creating mctx for the community. I added a performance enhancement to my copy of mctx and am wondering if you're interested in adding it to the official repo.

Computing both the decision and chance branches in Stochastic MuZero on every expansion is necessary when the two embeddings have different structures/shapes/dtypes, but is otherwise slow and adds a good bit of overhead (especially at high simulation counts or with large networks). I modified the recurrent fn so that only one branch has to be computed per expansion when the decision and chance embeddings share the same struct/shape/dtype.
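
For context, the stock recurrent fn evaluates both networks for the whole batch and then selects per example. Roughly, the pattern looks like this (a paraphrase, not the exact mctx source; _where_per_example and the stand-in outputs are illustrative):

import functools
import jax
import jax.numpy as jnp

def _where_per_example(pred, x, y):
    # Broadcast a [B] predicate against [B, ...] leaves.
    pred = pred.reshape(pred.shape + (1,) * (x.ndim - 1))
    return jnp.where(pred, x, y)

# Both "branches" are fully computed for every batch element; the select
# only decides which result to keep afterwards.
is_decision_node = jnp.array([True, False])
decision_out = {"prior_logits": jnp.ones((2, 5))}   # stand-in decision output
chance_out = {"prior_logits": jnp.zeros((2, 5))}    # stand-in chance output
output = jax.tree_map(
    functools.partial(_where_per_example, is_decision_node),
    decision_out, chance_out,
)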

Overall these are my changes:

In base.py, another stochastic recurrent state was added that holds only one embedding:

@chex.dataclass(frozen=True)
class StochasticRecurrentStateEfficient:
    embedding: chex.ArrayTree  # [B, ...]
    is_decision_node: chex.Array  # [B]
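
For reference, the existing StochasticRecurrentState holds both embeddings (paraphrased from base.py; its fields appear in the fallback path further down):

@chex.dataclass(frozen=True)
class StochasticRecurrentState:
    state_embedding: chex.ArrayTree  # [B, ...]
    afterstate_embedding: chex.ArrayTree  # [B, ...]
    is_decision_node: chex.Array  # [B]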

The modified version of _make_stochastic_recurrent_fn unrolls the batch with scan so that lax.cond can be used to compute only one branch per element:

def _make_stochastic_recurrent_fn_efficient(
    decision_node_fn: base.DecisionRecurrentFn,
    chance_node_fn: base.ChanceRecurrentFn,
    num_actions: int,
    num_chance_outcomes: int,
) -> base.RecurrentFn:
    """Make Stochastic Recurrent Fn."""

    def stochastic_recurrent_fn(
        params: base.Params,
        rng: chex.PRNGKey,
        action_or_chance: base.Action,  # [B]
        state: base.StochasticRecurrentStateEfficient,
    ) -> Tuple[base.RecurrentFnOutput, base.StochasticRecurrentStateEfficient]:
        def decision_node_branch(action_or_chance, state):
            # Internally we assume that there are `A' = A + C` "actions";
            # action_or_chance can take on values in `{0, 1, ..., A' - 1}`.
            # To interpret it as an action we can leave it as is:
            action = action_or_chance - 0

            # temporary batch dimension
            action = jnp.expand_dims(action, axis=0)
            embedding = jnp.expand_dims(state.embedding, axis=0)

            decision_output, afterstate_embedding = decision_node_fn(
                params, rng, action, embedding
            )

            decision_output = jax.tree_map(
                lambda x: jnp.squeeze(x, axis=0), decision_output
            )
            afterstate_embedding = jax.tree_map(
                lambda x: jnp.squeeze(x, axis=0), afterstate_embedding
            )

            new_state = base.StochasticRecurrentStateEfficient(
                embedding=afterstate_embedding,
                is_decision_node=jnp.logical_not(state.is_decision_node),
            )
            # Outputs from DecisionRecurrentFunction produce chance logits with
            # dim `C`, to respect our internal convention that there are `A' = A + C`
            # "actions" we pad with `A` dummy logits which are ultimately ignored:
            # see `_mask_tree`.
            return (
                base.RecurrentFnOutput(
                    prior_logits=jnp.concatenate(
                        [
                            jnp.full([num_actions], fill_value=-jnp.inf),
                            decision_output.chance_logits,
                        ],
                        axis=-1,
                    ),
                    value=decision_output.afterstate_value,
                    reward=jnp.zeros_like(decision_output.afterstate_value),
                    discount=jnp.ones_like(decision_output.afterstate_value),
                ),
                new_state,
            )

        def chance_node_branch(action_or_chance, state):
            # To interpret it as a chance outcome we subtract num_actions:
            chance_outcome = action_or_chance - num_actions

            # temporary batch dimension
            chance_outcome = jnp.expand_dims(chance_outcome, axis=0)
            embedding = jnp.expand_dims(state.embedding, axis=0)

            chance_output, state_embedding = chance_node_fn(
                params, rng, chance_outcome, embedding
            )

            chance_output = jax.tree_map(
                lambda x: jnp.squeeze(x, axis=0), chance_output
            )
            state_embedding = jax.tree_map(
                lambda x: jnp.squeeze(x, axis=0), state_embedding
            )

            new_state = base.StochasticRecurrentStateEfficient(
                embedding=state_embedding,
                is_decision_node=jnp.logical_not(state.is_decision_node),
            )
            # Outputs from ChanceRecurrentFunction produce action logits with dim `A`,
            # to respect our internal convention that there are `A' = A + C` "actions"
            # we pad with `C` dummy logits which are ultimately ignored: see
            # `_mask_tree`.
            return (
                base.RecurrentFnOutput(
                    prior_logits=jnp.concatenate(
                        [
                            chance_output.action_logits,
                            jnp.full([num_chance_outcomes], fill_value=-jnp.inf),
                        ],
                        axis=-1,
                    ),
                    value=chance_output.value,
                    reward=chance_output.reward,
                    discount=chance_output.discount,
                ),
                new_state,
            )

        def scan_body(_, xs):
            action_or_chance, state = xs
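            # scan strips the leading batch dimension, so is_decision_node is
            # a scalar here and lax.cond stays a real conditional rather than
            # being lowered to a select.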
            output, state = jax.lax.cond(
                state.is_decision_node,
                decision_node_branch,
                chance_node_branch,
                action_or_chance,
                state,
            )
            return None, (output, state)

        _, (output, new_state) = jax.lax.scan(
            scan_body, None, (action_or_chance, state)
        )

        return output, new_state

    return stochastic_recurrent_fn
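
As a self-contained illustration of the pattern (toy names, nothing mctx-specific): scanning over the batch keeps lax.cond a true conditional, so in principle only the taken branch executes per element at run time (though, as noted later in this thread, this may not pay off on GPU):

import jax
import jax.numpy as jnp

def body(carry, x):
    pred, v = x
    # pred is a scalar inside scan, so cond executes only the taken branch.
    y = jax.lax.cond(pred, lambda t: t * 2.0, lambda t: t + 1.0, v)
    return carry, y

preds = jnp.array([True, False, True])
vals = jnp.arange(3, dtype=jnp.float32)
_, ys = jax.lax.scan(body, None, (preds, vals))
print(ys)  # [0. 2. 4.]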

Finally, in the policy function there's a chex check to decide whether the efficient version can be used:

    try:
        chex.assert_trees_all_equal_structs(root.embedding, dummy_afterstate_embedding)
        chex.assert_trees_all_equal_shapes_and_dtypes(
            root.embedding, dummy_afterstate_embedding
        )
    except AssertionError:
        embeddings_same_shape_dtype = False
    else:
        embeddings_same_shape_dtype = True

    if embeddings_same_shape_dtype:
        # 74.05 sec for old version vs 48.37 sec for efficient version
        embedding = base.StochasticRecurrentStateEfficient(
            embedding=root.embedding,
            is_decision_node=jnp.ones([batch_size], dtype=bool),
        )
        make_stochastic_recurrent_fn = _make_stochastic_recurrent_fn_efficient
    else:
        embedding = base.StochasticRecurrentState(
            state_embedding=root.embedding,
            afterstate_embedding=dummy_afterstate_embedding,
            is_decision_node=jnp.ones([batch_size], dtype=bool),
        )
        make_stochastic_recurrent_fn = _make_stochastic_recurrent_fn

    root = root.replace(
        # pad action logits with num_chance_outcomes so dim is A + C
        prior_logits=jnp.concatenate(
            [
                root.prior_logits,
                jnp.full([batch_size, num_chance_outcomes], fill_value=-jnp.inf),
            ],
            axis=-1,
        ),
        # replace embedding with wrapper.
        embedding=embedding,
    )

    # Stochastic MuZero Change: We need to be able to tell if different nodes are
    # decision or chance. This is accomplished by imposing a special structure
    # on the embeddings stored in each node. Each embedding is an instance of
    # StochasticRecurrentState which maintains this information.
    recurrent_fn = make_stochastic_recurrent_fn(
        decision_node_fn=decision_recurrent_fn,
        chance_node_fn=chance_recurrent_fn,
        num_actions=num_actions,
        num_chance_outcomes=num_chance_outcomes,
    )
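
As a tiny numeric check of the `A' = A + C` convention used above (made-up values):

import jax.numpy as jnp

num_actions, num_chance_outcomes = 3, 2        # A = 3, C = 2, so A' = 5
action_logits = jnp.array([[0.5, 0.2, 0.3]])   # root prior over A actions, [B=1, A]
padded = jnp.concatenate(
    [action_logits, jnp.full([1, num_chance_outcomes], -jnp.inf)], axis=-1
)
print(padded.shape)  # (1, 5) == [B, A + C]; the -inf entries are ignored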

When the efficient version is used, it saves a bit of memory by not holding both embeddings, and it is faster, with the performance gap widening as the simulation count or network size grows. Maybe there is a better option than scan for the unroll; I just know vmap can't be used, because under vmap cond is converted back to select and both branches get computed anyway.
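
A quick toy demonstrating that limitation (not mctx code): under vmap, lax.cond with a batched predicate is lowered to select, so both branches are evaluated for every element:

import jax
import jax.numpy as jnp

def branchy(pred, x):
    return jax.lax.cond(pred, lambda v: v * 2.0, lambda v: v + 1.0, x)

# The predicate is batched under vmap, so cond becomes a select and both
# lambdas run for every element before one result is kept.
print(jax.vmap(branchy)(jnp.array([True, False]), jnp.array([1.0, 2.0])))
# [2. 3.]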

Let me know if you're interested in adding this to mctx. I can make a testing colab to compare performance, and I'd be happy to open a pull request :)

evanatyourservice commented 10 months ago

Unfortunately this doesn't seem to translate well to GPU, even when unrolling the batch with a list comprehension instead of scan. I tested it on CPU, and it only seems to help in that scenario. Unless there's something I'm not seeing, there may not be a good way to do this without a batched cond going through XLA :(

fidlej commented 10 months ago

Thanks for sharing your findings. It is true that the existing implementation of Stochastic MuZero is not efficient.

evanatyourservice commented 10 months ago

Not having a batched cond is a little rough sometimes lol

I suppose this could be reopened when that comes around

evanatyourservice commented 3 months ago

@fidlej Sorry if you're busy, but do you think the newish jax.experimental.sparse API could be useful for this problem?

fidlej commented 3 months ago

Thanks for asking. The multiplication of sparse matrices is probably not a good fit for this conditional computation.