google-deepmind / mctx

Monte Carlo tree search in JAX

Irregular action and chance outcome outputs within search with Stochastic MuZero #37

Closed evanatyourservice closed 1 year ago

evanatyourservice commented 1 year ago

Hello! Thank you for recently adding an implementation of Stochastic MuZero. I was testing it, but the chance outcome and action outputs within the search seem irregular. I made this test, which prints the action and chance outcome from within the dynamics functions:

import numpy as np

import jax
import jax.numpy as jnp
import mctx
from mctx import DecisionRecurrentFnOutput, ChanceRecurrentFnOutput, RecurrentFnOutput

num_actions = 3
num_chance_outcomes = 5

def afterstate_pred(afterstate_embedding):
    chance_logits = jnp.zeros([1, num_chance_outcomes]).at[0, 0].set(1.0)
    afterstate_value = jnp.zeros([1])
    return chance_logits, afterstate_value

def pred(embedding):
    policy_logits = jnp.zeros([1, num_actions]).at[0, 0].set(1.0)
    value = jnp.zeros([1])
    return policy_logits, value

def afterstate_dynamics(action, embedding):
    print(f"action: {action}")
    return embedding

def dynamics(chance_outcome, afterstate_embedding):
    print(f"chance_outcome: {chance_outcome}")
    return afterstate_embedding, jnp.zeros([1])

def decision_recurrent_fn(params, rng_key, action, embedding):
    afterstate_embedding = afterstate_dynamics(action, embedding)
    chance_logits, afterstate_value = afterstate_pred(afterstate_embedding)
    decision_recurrent_fn_output = DecisionRecurrentFnOutput(
        chance_logits=chance_logits,
        afterstate_value=afterstate_value,
    )
    return decision_recurrent_fn_output, afterstate_embedding

def chance_recurrent_fn(params, rng_key, chance_outcome, embedding):
    embedding, reward = dynamics(chance_outcome, embedding)
    policy_logits, value = pred(embedding)
    recurrent_fn_output = ChanceRecurrentFnOutput(
        action_logits=policy_logits,
        value=value,
        reward=reward,
        discount=jnp.full_like(reward, 0.99),
    )
    return recurrent_fn_output, embedding

def recurrent_fn(params, rng_key, action, embedding):
    embedding, reward = dynamics(action, embedding)
    policy_logits, value = pred(embedding)
    recurrent_fn_output = RecurrentFnOutput(
        reward=reward,
        discount=jnp.full_like(reward, 0.99),
        prior_logits=policy_logits,
        value=value,
    )
    return recurrent_fn_output, embedding

def stochastic_muzero_policy():
    """Tests that SMZ is equivalent to MZ with a dummy chance function."""
    root = mctx.RootFnOutput(
        prior_logits=jnp.array(
            [
                [-1.0, 0.0, 2.0],
            ]
        ),
        value=jnp.array([0.0]),
        embedding=jnp.zeros([1, 4]),
    )

    num_simulations = 10

    """policy_output = mctx.muzero_policy(
        params=(),
        rng_key=jax.random.PRNGKey(0),
        root=root,
        recurrent_fn=recurrent_fn,
        num_simulations=num_simulations,
        dirichlet_fraction=0.0,
    )"""

    stochastic_policy_output = mctx.stochastic_muzero_policy(
        params=(),
        rng_key=jax.random.PRNGKey(0),
        root=root,
        decision_recurrent_fn=decision_recurrent_fn,
        chance_recurrent_fn=chance_recurrent_fn,
        num_simulations=2 * num_simulations,
        num_actions=num_actions,
        num_chance_outcomes=num_chance_outcomes,
        dirichlet_fraction=0.0,
    )

    """np.testing.assert_array_equal(stochastic_policy_output.action, policy_output.action)

    np.testing.assert_allclose(
        stochastic_policy_output.action_weights, policy_output.action_weights
    )"""

if __name__ == "__main__":
    with jax.disable_jit():
        stochastic_muzero_policy()

but I get outputs like this:

action: [0]
action: [2]
chance_outcome: [-1]
action: [3]
chance_outcome: [0]
action: [0]
chance_outcome: [-3]
action: [4]
chance_outcome: [1]
action: [5]
chance_outcome: [2]
action: [6]
chance_outcome: [3]
action: [7]
chance_outcome: [4]
action: [1]
chance_outcome: [-2]
action: [3]
chance_outcome: [0]
action: [1]
chance_outcome: [-2]
action: [2]
chance_outcome: [-1]
action: [0]
chance_outcome: [-3]
action: [0]
chance_outcome: [-3]
action: [0]
chance_outcome: [-3]
action: [0]
chance_outcome: [-3]
action: [3]
chance_outcome: [0]
action: [0]
chance_outcome: [-3]
action: [4]
chance_outcome: [1]
action: [5]
chance_outcome: [2]
action: [3]
chance_outcome: [0]

I haven't delved into the mctx policy code yet to see where the problem might arise, but I wanted to start by opening an issue. Do you know what might be causing this behavior?

fidlej commented 1 year ago

The output from print() is misleading. Even with jit disabled, the code still prints outputs that are later discarded. You can see that the output is selected by a jnp.where: https://github.com/deepmind/mctx/blob/main/mctx/_src/policies.py#L482
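For illustration, here is a minimal standalone sketch (not mctx code; the branch functions are made up) of why the discarded branch still prints:

import jax.numpy as jnp

def decision_branch(x):
    print("decision branch ran")  # prints even when its result is discarded
    return x + 1.0

def chance_branch(x):
    print("chance branch ran")  # also prints on every call
    return x - 1.0

x = jnp.zeros([1])
is_decision_node = jnp.array([True])
# Both branches are evaluated; jnp.where only selects which values end up in
# the output, it does not prevent the other branch from running.
out = jnp.where(is_decision_node, decision_branch(x), chance_branch(x))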

evanatyourservice commented 1 year ago

Oh I see, sorry thanks for the clarification!

evanatyourservice commented 1 year ago

@fidlej Hi Ivo, I have another issue where I'm getting a shape mismatch error within mctx. My network outputs within the recurrent fn are (shapes shown; all values are traced arrays):

output_if_decision_node: RecurrentFnOutput(reward=float32[4], discount=float32[4], prior_logits=float32[4,35], value=float32[4])
output_if_chance_node: RecurrentFnOutput(reward=float32[4], discount=float32[4], prior_logits=float32[4,35], value=float32[4])

but the output of the mctx stochastic_recurrent_fn after _broadcast_where (line 482 of policies.py) is:

RecurrentFnOutput(reward=float32[4,1,1,4], discount=float32[4,1,1,4], prior_logits=float32[4,1,4,35], value=float32[4,1,1,4])

So _broadcast_where adds dimensions, and this trips the chex assert_shape on line 225 of search.py: chex.assert_shape(step.prior_logits, [batch_size, tree.num_actions]).
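For reference, here is a small standalone check (my own guess at the mechanics, not the actual mctx source) showing that a mask expanded with too many trailing axes broadcasts into exactly the shapes above:

import jax.numpy as jnp

is_decision_node = jnp.ones([4], dtype=bool)   # [batch_size]
reward = jnp.zeros([4])                        # float32[4] in the report
prior_logits = jnp.zeros([4, 35])              # float32[4,35] in the report

# Reshaping the mask to [4, 1, 1, 1] (three trailing axes, more than any leaf
# needs) and broadcasting against the leaves reproduces the reported shapes.
mask = is_decision_node.reshape([4, 1, 1, 1])
print(jnp.where(mask, reward, reward).shape)              # (4, 1, 1, 4)
print(jnp.where(mask, prior_logits, prior_logits).shape)  # (4, 1, 4, 35)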

evanatyourservice commented 1 year ago

Maybe the end of stochastic_recurrent_fn could be:

    def _broadcast_where(decision_leaf, chance_leaf, is_decision_node):
      return jax.tree_util.tree_map(
          lambda x, y: jnp.where(is_decision_node, x, y),
          decision_leaf,
          chance_leaf
      )

    output = jax.vmap(_broadcast_where)(
        output_if_decision_node,
        output_if_chance_node,
        state.is_decision_node)

    return output, new_state

Not sure whether it needs to handle more than one batch dim, though.

fidlej commented 1 year ago

Your code with jax.vmap makes sense. The existing code with expanded_is_decision should also work. I do not see a problem there. Do you have a simple example to reproduce the problem?

evanatyourservice commented 1 year ago

This reproduces the error:

import numpy as np

import jax
import jax.numpy as jnp
import mctx
from mctx import DecisionRecurrentFnOutput, ChanceRecurrentFnOutput, RecurrentFnOutput

batch_size = 4
num_actions = 3
num_chance_outcomes = 5

def afterstate_pred(afterstate_embedding):
    chance_logits = jnp.zeros([batch_size, num_chance_outcomes]).at[:, 0].set(1.0)
    afterstate_value = jnp.zeros([batch_size])
    return chance_logits, afterstate_value

def pred(embedding):
    policy_logits = jnp.zeros([batch_size, num_actions]).at[:, 0].set(1.0)
    value = jnp.zeros([batch_size])
    return policy_logits, value

def afterstate_dynamics(action, embedding):
    return embedding

def dynamics(chance_outcome, afterstate_embedding):
    return afterstate_embedding, jnp.zeros([batch_size])

def decision_recurrent_fn(params, rng_key, action, embedding):
    afterstate_embedding = afterstate_dynamics(action, embedding)
    chance_logits, afterstate_value = afterstate_pred(afterstate_embedding)
    decision_recurrent_fn_output = DecisionRecurrentFnOutput(
        chance_logits=chance_logits,
        afterstate_value=afterstate_value,
    )
    return decision_recurrent_fn_output, afterstate_embedding

def chance_recurrent_fn(params, rng_key, chance_outcome, embedding):
    embedding, reward = dynamics(chance_outcome, embedding)
    policy_logits, value = pred(embedding)
    recurrent_fn_output = ChanceRecurrentFnOutput(
        action_logits=policy_logits,
        value=value,
        reward=reward,
        discount=jnp.full_like(reward, 0.99),
    )
    return recurrent_fn_output, embedding

def stochastic_muzero_policy():
    root = mctx.RootFnOutput(
        prior_logits=jax.random.normal(
            jax.random.PRNGKey(0), [batch_size, num_actions]
        ),
        value=jnp.zeros([batch_size]),
        embedding=jnp.zeros([batch_size, 7]),
    )

    num_simulations = 10

    stochastic_policy_output = mctx.stochastic_muzero_policy(
        params=(),
        rng_key=jax.random.PRNGKey(0),
        root=root,
        decision_recurrent_fn=decision_recurrent_fn,
        chance_recurrent_fn=chance_recurrent_fn,
        num_simulations=num_simulations,
        num_actions=num_actions,
        num_chance_outcomes=num_chance_outcomes,
        dirichlet_fraction=0.0,
    )

if __name__ == "__main__":
    stochastic_muzero_policy()

I think it might be line 747: should that be len(decision_leaf.shape) instead of len(decision_leaf)?
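For what it's worth, a hedged sketch of the rank-based expansion I would expect there (variable names are illustrative, not a copy of the mctx source):

import jax.numpy as jnp

def _broadcast_where(decision_leaf, chance_leaf, is_decision_node):
    # len(decision_leaf) is the size of axis 0 (the batch size), not the rank,
    # so using it here adds batch_size - 1 trailing axes to the mask. The
    # number of missing axes should come from the leaf's rank instead.
    extra_dims = len(decision_leaf.shape) - len(is_decision_node.shape)
    expanded_is_decision = jnp.reshape(
        is_decision_node, is_decision_node.shape + (1,) * extra_dims)
    return jnp.where(expanded_is_decision, decision_leaf, chance_leaf)

With a [batch_size] mask and a [batch_size, num_actions] leaf this reshapes the mask to [batch_size, 1], which broadcasts correctly and keeps the chex.assert_shape in search.py happy.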

fidlej commented 1 year ago

Ah, excellent. Thanks for finding the len(decision_leaf) problem. I will prepare a fix.

evanatyourservice commented 1 year ago

Thank you!

evanatyourservice commented 1 year ago

Any thoughts on releasing 0.0.3 now that Stochastic MuZero has been added?