google-deepmind / mctx

Monte Carlo tree search in JAX
Apache License 2.0

Understanding RootFnOutput #33

Closed: kefirski closed this issue 1 year ago

kefirski commented 1 year ago

Hello!

I'm trying to use the mctx library to train a MuZero agent, and I'm having trouble understanding RootFnOutput.

I have two methods, .root and .recurrent:

def root(
    self, env_state: Float[Array, "b w1 h1 c1"], train: bool = True
) -> mctx.RootFnOutput:
    initial_state = self.representation(env_state, train)
    policy_logits, value = self.prediction(initial_state, train)
    return mctx.RootFnOutput(
        prior_logits=policy_logits,
        value=value,
        embedding=initial_state,
    )

def recurrent(
    self,
    rng_key,
    action: Int[Array, "b"],
    state: Float[Array, "b w h c"],
    train: bool = True,
) -> Tuple[mctx.RecurrentFnOutput, Float[Array, "b w h c"]]:
    policy_logits, value = self.prediction(state, train)
    new_state, reward = self.dynamics(action, state, train)

    return (
        mctx.RecurrentFnOutput(
            reward=reward,
            discount=jnp.full(reward.shape, self.config.discount),
            prior_logits=policy_logits,
            value=value,
        ),
        new_state,
    )

which I want to use in a call to mctx.muzero_policy:

initial_state = jax.random.normal(provider(), (1, 84, 84, 6))
root = model.apply(params, initial_state, train=False, method=model.root)

res = mctx.muzero_policy(
    params,
    provider(),
    root=root,
    recurrent_fn=partial(model.apply, method=model.recurrent, train=False),
    num_simulations=4,
)

print(res.search_tree.raw_values)
print(res.search_tree.node_values)

"""
[[0.80960083 0.80960083 2.0920973  2.2743711  2.7579222 ]]
[[3.5309532 3.0515175 3.0072536 2.7471337 2.7579222]]
"""

As a result, the first two elements of raw_values are equal. Could you help me understand whether this is intended behaviour or whether I'm missing something? It also feels odd that mctx.RootFnOutput has value and prior_logits fields, while its docstring says it holds the output of a representation network:

@chex.dataclass(frozen=True)
class RootFnOutput:
  """The output of a representation network.

  prior_logits: `[B, num_actions]` the logits produced by a policy network.
  value: `[B]` an approximate value of the current state.
  embedding: `[B, ...]` the inputs to the next `recurrent_fn` call.
  """
  prior_logits: chex.Array
  value: chex.Array
  embedding: RecurrentState

fidlej commented 1 year ago

Thanks for asking. Your root() function looks good, but your recurrent() function needs a change: it should return the policy_logits and value for the state reached after the state-action transition, i.e., self.prediction(new_state, train) rather than self.prediction(state, train).

BTW, you may obtain better results by using a different prediction network for the root and the non-root nodes.

kefirski commented 1 year ago

Thanks for answering! Now I see it in the description of RecurrentFnOutput.

> BTW, you may obtain better results by using a different prediction network for the root and the non-root nodes.

Thousands of thanks for this tip!

carlosgmartin commented 1 year ago

> The recurrent() function should return the policy_logits and value after the state-action transition. I.e., self.prediction(new_state, train).

@fidlej Out of curiosity, why is the field called prior_logits, even though it holds the logits after the transition?

fidlej commented 1 year ago

One motivation for the prior_logits name: the prior_logits are produced by the policy network and define a probability distribution before the search runs or any Q-values are seen. After the search ends, the selected action comes from a different distribution.

carlosgmartin commented 1 year ago

@fidlej Sorry if this is a silly question, but why does gumbel_muzero_policy's recurrent_fn take an rng_key at all? In the code I've seen, people either del this key inside the function or ignore it. Is Gumbel MuZero supposed to handle a stochastic recurrent_fn correctly, or only a deterministic one?

fidlej commented 1 year ago

That is a good question. Without chance nodes, MCTS is suitable mainly for deterministic environments. For a stochastic environment, I recommend a strong baseline: use Gumbel MuZero with max_depth=1. The search will then rely only on the one-step Q-values q(s, a), which remain valid in stochastic environments.

The rng_key may still be helpful for testing performance on a stochastic environment, or you can use it to generate a stochastic reward on a bandit.