google-deepmind / mctx

Monte Carlo tree search in JAX
Apache License 2.0

Automatically determine num_actions and num_chance_outcomes in stochastic_muzero_policy #71

Closed. carlosgmartin closed this issue 11 months ago.

carlosgmartin commented 11 months ago

The num_actions and num_chance_outcomes parameters of stochastic_muzero_policy can be determined automatically from the root, decision_recurrent_fn, and chance_recurrent_fn parameters, via the last dimension of ChanceRecurrentFnOutput.action_logits.shape and DecisionRecurrentFnOutput.chance_logits.shape, respectively. That would reduce the number of parameters that users need to pass in by two.
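A minimal sketch of the shape inference described above, not necessarily how the change was eventually implemented. It uses jax.eval_shape so that the recurrent functions are only traced, never executed. The two NamedTuple classes are hypothetical stand-ins that carry only the fields named in this issue, the two toy recurrent functions and the helper infer_sizes are made up for illustration, and the (params, rng_key, index, state) call signature is an assumption.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


# Hypothetical stand-ins: only the fields mentioned in the issue are modelled.
class DecisionRecurrentFnOutput(NamedTuple):
    chance_logits: jnp.ndarray  # assumed shape [batch, num_chance_outcomes]


class ChanceRecurrentFnOutput(NamedTuple):
    action_logits: jnp.ndarray  # assumed shape [batch, num_actions]


# Toy recurrent functions (assumed signature: params, rng_key, index, state).
def decision_recurrent_fn(params, rng_key, action, state):
    batch = action.shape[0]
    output = DecisionRecurrentFnOutput(chance_logits=jnp.zeros((batch, 5)))
    return output, state  # second element plays the role of the afterstate


def chance_recurrent_fn(params, rng_key, chance_outcome, afterstate):
    batch = chance_outcome.shape[0]
    output = ChanceRecurrentFnOutput(action_logits=jnp.zeros((batch, 3)))
    return output, afterstate


def infer_sizes(decision_fn, chance_fn, params, rng_key, state):
    """Hypothetical helper: read both sizes off the traced output shapes."""
    batch = jax.tree_util.tree_leaves(state)[0].shape[0]
    # Abstract placeholder for a batch of action / chance-outcome indices.
    dummy_index = jax.ShapeDtypeStruct((batch,), jnp.int32)
    # eval_shape traces the functions and returns shapes without computing.
    decision_out, afterstate = jax.eval_shape(
        decision_fn, params, rng_key, dummy_index, state)
    chance_out, _ = jax.eval_shape(
        chance_fn, params, rng_key, dummy_index, afterstate)
    num_actions = chance_out.action_logits.shape[-1]
    num_chance_outcomes = decision_out.chance_logits.shape[-1]
    return num_actions, num_chance_outcomes


num_actions, num_chance_outcomes = infer_sizes(
    decision_recurrent_fn,
    chance_recurrent_fn,
    params=(),
    rng_key=jax.random.PRNGKey(0),
    state=jnp.zeros((4, 16)),
)
```

Because only shapes and dtypes are propagated, this kind of inference adds no runtime cost to the search itself.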

I suggest allowing num_actions and num_chance_outcomes to be None, making them None by default, and determining their values automatically if they're None. This preserves backward compatibility. Thoughts?
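One way the proposed None default could behave is sketched below. The helper resolve_size is hypothetical and not part of mctx. Raising an error when an explicitly passed value disagrees with the logits shape is an extra assumption, not something the proposal specifies.

```python
from typing import Optional, Tuple


def resolve_size(explicit: Optional[int],
                 logits_shape: Tuple[int, ...],
                 name: str) -> int:
    """Hypothetical helper: use the explicit value, or infer it when None."""
    inferred = logits_shape[-1]  # last axis of the logits holds the size
    if explicit is None:
        return inferred
    if explicit != inferred:
        # Assumed behaviour: reject values that contradict the logits shape.
        raise ValueError(
            f"{name}={explicit} does not match logits with {inferred} entries")
    return explicit


# Callers that still pass the value keep working; new callers can omit it.
inferred_actions = resolve_size(None, (8, 3), "num_actions")
explicit_actions = resolve_size(3, (8, 3), "num_actions")
```

Existing call sites that pass both values would keep working unchanged, which is the backward compatibility the suggestion refers to.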

fidlej commented 11 months ago

Thanks for the suggestion. It sounds good. We do not need None defaults to preserve backward compatibility, because stochastic_muzero_policy is still experimental.

Will you send a pull request or should I do that?

fidlej commented 11 months ago

Thanks for the improvements.