google-deepmind / mctx

Monte Carlo tree search in JAX
Apache License 2.0

Explicitly use int32 for the argmax output. #70

Closed: copybara-service[bot] closed this pull request 1 year ago.

copybara-service[bot] commented 1 year ago

Explicitly use int32 for the argmax output.

Otherwise, the argmax output would be int64 when jax_enable_x64 is enabled.

Fixes #69.
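A minimal sketch of the dtype behaviour the commit message describes. It assumes the fix amounts to casting the argmax result to int32; the actual mctx diff is not shown here, and the logits array is a made-up example. With 64-bit mode on, `jnp.argmax` returns int64 indices by default. An explicit cast keeps the indices at int32 regardless of the flag.

```python
import jax

# Turn on 64-bit mode before creating any arrays (mirrors jax_enable_x64).
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# Hypothetical logits for a batch of two, with three choices each.
logits = jnp.array([[0.1, 2.0, -1.0], [3.0, 0.5, 0.2]])

# Default behaviour: under x64 the index dtype follows the default int, i.e. int64.
default_idx = jnp.argmax(logits, axis=-1)

# Explicit cast: indices stay int32 whether or not x64 is enabled.
int32_idx = jnp.argmax(logits, axis=-1).astype(jnp.int32)

print(default_idx.dtype)  # int64
print(int32_idx.dtype)    # int32
```

Pinning the dtype at the argmax call keeps index arrays consistent across both precision modes. This matters when those indices are later combined with other int32 arrays or stored in int32 buffers.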