Closed (kefirski closed this 2 years ago)
Thanks for asking. Your root() function looks good. Your recurrent() function should be different. The recurrent() function should return the policy_logits and value after the state-action transition. I.e., self.prediction(new_state, train).
BTW, you may obtain better results by using a different prediction network for the root and the non-root nodes.
Thanks for answering! Now I see it in the description of RecurrentFnOutput.
> BTW, you may obtain better results by using a different prediction network for the root and the non-root nodes.

A thousand thanks for this tip!
> The recurrent() function should return the policy_logits and value after the state-action transition. I.e., self.prediction(new_state, train).

@fidlej Out of curiosity, why is the field called prior_logits, even though it's the logits after the transition?
One motivation for the prior_logits name: The prior_logits are produced from the policy network. The prior_logits define a probability distribution before doing search or seeing the Q-values. After the end of the search, the selected action is from a different distribution.
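A small numeric illustration of the "prior" idea, as a sketch only. The numbers are made up, and `Q_SCALE` is a hypothetical constant standing in for whatever Q-value transform the search applies. The point is just that the distribution defined by prior_logits can differ from the one used once Q-values are known.

```python
import jax
import jax.numpy as jnp

# Made-up example values for a single state with three actions.
prior_logits = jnp.array([2.0, 0.0, -1.0])
q_values = jnp.array([0.1, 0.9, 0.2])
Q_SCALE = 5.0  # Hypothetical scaling of the Q-values, for illustration only.

# Distribution before search: depends only on the policy network output.
prior = jax.nn.softmax(prior_logits)

# A distribution after seeing Q-values: logits shifted by the scaled Q-values.
after_search = jax.nn.softmax(prior_logits + Q_SCALE * q_values)

print("prior favours action", int(jnp.argmax(prior)))
print("after search favours action", int(jnp.argmax(after_search)))
```

Here the prior prefers action 0, while the Q-values shift the preference to action 1.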
@fidlej Sorry if this is a silly question, but why does gumbel_muzero_policy's recurrent_fn take an rng_key at all? In the code I've seen, people del this key inside the function or ignore it. Is Gumbel MuZero supposed to correctly handle cases where recurrent_fn is stochastic? Or only deterministic?
That is a good question. Without chance nodes, MCTS is suitable mainly for deterministic environments. If you have a stochastic environment, I recommend a strong baseline: Use Gumbel MuZero with max_depth=1. The search will then use Q-values q(s, a). The one-step Q-values are valid on stochastic environments.
The rng_key may still be helpful for testing the performance on a stochastic environment. Or you can use the rng_key to generate a stochastic reward on a bandit.
Hello!
I'm trying to use the mctx library to train a MuZero agent and am having some trouble understanding RootFnOutput. I have two methods, .root and .recurrent, which I want to use to make a call to .muzero_policy. As a result, for the raw values I obtain an array whose first two elements are equal. Could you please help me understand whether this is intended behaviour or whether I'm missing something? It feels weird that mctx.RootFnOutput should contain value and prior_logits fields, while the documentation says that it contains the output of a representation network: