google-deepmind / mctx

Monte Carlo tree search in JAX
Apache License 2.0

`_compute_mixed_value` divide by zero problem #22

Closed uduse closed 2 years ago

uduse commented 2 years ago

https://github.com/deepmind/mctx/blob/209b28a62b8819d5736e08b33a496488da7a4807/mctx/_src/qtransforms.py#L191

sum_probs can be zero when none of the children has been visited, so the division on this line can divide by zero.

fidlej commented 2 years ago

Thanks for mentioning it. Do you see an error? The weighted_q contains another visit_counts > 0 check, so the weighted_q should be correct.

uduse commented 2 years ago

This happened when I used qtransform_completed_by_mix_value with muzero_policy. Running the policy does not raise an error, but the search always invests all simulations in the same action (the model is fine; I verified this with the default Q transform). I tried to debug by disabling jax.jit, so that all values are computed concretely instead of being traced. The expression with the divide by zero is then also evaluated concretely, so it always raises an error. At the very least, this is unfriendly to debugging.

fidlej commented 2 years ago

Thanks for the info. I have now created pull request #23 to avoid the NaN.

BTW, do not use qtransform_completed_by_mix_value with the muzero_policy. The muzero_policy assumes that the Q-values lie in [0, 1]. Use the gumbel_muzero_policy instead.
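
For readers following the divide-by-zero discussion above, here is a minimal JAX sketch of the pattern. It is an illustration, not the mctx source: the function names `weighted_q_unguarded` and `weighted_q_guarded` are hypothetical, and the guarded variant is only one plausible way to avoid the non-finite intermediate, not necessarily what pull request #23 does. It shows that an outer `visit_counts > 0` mask keeps the final result finite, while the inner division still produces inf or NaN when no child has been visited.

```python
import jax.numpy as jnp


def weighted_q_unguarded(qvalues, visit_counts, prior_probs):
    """Hypothetical illustration: the outer mask hides the bad division."""
    visited = visit_counts > 0
    # Zero when no child has been visited.
    sum_probs = jnp.sum(jnp.where(visited, prior_probs, 0.0), axis=-1)
    # The division is evaluated for every entry, so it yields inf/NaN
    # when sum_probs == 0, even though the mask discards those entries.
    ratio = prior_probs * qvalues / sum_probs
    return jnp.sum(jnp.where(visited, ratio, 0.0), axis=-1)


def weighted_q_guarded(qvalues, visit_counts, prior_probs):
    """Hypothetical illustration: replace the denominator before dividing."""
    visited = visit_counts > 0
    sum_probs = jnp.sum(jnp.where(visited, prior_probs, 0.0), axis=-1)
    # Use a harmless denominator wherever the result will be masked out.
    safe_sum_probs = jnp.where(visited, sum_probs, 1.0)
    ratio = prior_probs * qvalues / safe_sum_probs
    return jnp.sum(jnp.where(visited, ratio, 0.0), axis=-1)


# No child visited: the intermediate ratio is non-finite in the
# unguarded version, although both final results are finite.
qvalues = jnp.array([1.0, -1.0])
prior_probs = jnp.array([0.5, 0.5])
no_visits = jnp.array([0, 0])

unguarded_ratio = prior_probs * qvalues / jnp.sum(
    jnp.where(no_visits > 0, prior_probs, 0.0))
print(jnp.isfinite(unguarded_ratio))
print(weighted_q_unguarded(qvalues, no_visits, prior_probs))
print(weighted_q_guarded(qvalues, no_visits, prior_probs))
```

Both variants return the same final value, which matches the observation that the result itself is correct. The difference only shows up in the intermediates, which is what NaN-checking debug tools and eager (non-jitted) execution tend to surface.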