google-deepmind / mctx

Monte Carlo tree search in JAX
Apache License 2.0

Incompatible types error with increased global precision #69

Closed: AbhiDu96 closed this issue 11 months ago

AbhiDu96 commented 11 months ago

Hello, thank you for this amazing library. I have been using it in our project for a while, and I have started seeing the following warning:

```
/home/dubey/anaconda3/envs/jaxenv_qiskit/lib/python3.11/site-packages/jax/_src/ops/scatter.py:93: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=int64 to dtype=int32. In future JAX releases this will result in an error.
  warnings.warn("scatter inputs have incompatible types: cannot safely cast "
```

I need higher precision for my experiments, so I enabled 64-bit mode globally:

```python
jax.config.update("jax_enable_x64", True)
```

I have tracked the problem down to the `gumbel_muzero_policy` function and the functions it calls:

  1. `action_selection` module, line 140: explicitly sets the dtype to `jnp.int32`, which I believe causes the warning above.
  2. `search` module, lines 172, 288, 364, 367, and 375: also explicitly set the dtype to `int32`.
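A minimal sketch of the mismatch behind the warning, assuming only JAX and NumPy: under `jax_enable_x64`, `jnp.argmax` returns `int64`, while an array created with an explicit `dtype=jnp.int32` stays `int32`, and int64 to int32 is not a safe cast. The names `scores` and `node_buffer` are hypothetical stand-ins, not mctx internals.

```python
import jax

# Enable 64-bit mode before creating any arrays, as in the issue.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np

# Hypothetical per-action scores; mctx takes an argmax over similar arrays.
scores = jnp.array([0.1, 2.0, -1.0])
action = jnp.argmax(scores)

# Hypothetical buffer with an explicit int32 dtype, mirroring the explicit
# int32 arrays listed above.
node_buffer = jnp.zeros(4, dtype=jnp.int32)

print("argmax dtype:", action.dtype)       # int64 when x64 is enabled
print("buffer dtype:", node_buffer.dtype)  # int32 regardless of the flag

# int64 -> int32 can lose information, so it is not a "safe" cast. Writing
# `action` into `node_buffer` (e.g. node_buffer.at[0].set(action)) is the
# kind of scatter that produces the FutureWarning quoted above.
print("safe cast?", np.can_cast(action.dtype, node_buffer.dtype))  # False
```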

I am not sure how to resolve this, so please let me know if anything here does not make sense.

Update: I tried keeping `int32` everywhere, but as expected that does not work, because enabling the global x64 config changes the default precision of everything.
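A short sketch of what the flag changes, assuming only JAX: with `jax_enable_x64` on, arrays created without an explicit dtype default to 64-bit integer and float types, while an explicit `dtype` is still respected. Pinning some arrays to `int32` therefore leaves every unpinned value at 64 bits.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# Without an explicit dtype, JAX now picks 64-bit defaults.
default_int = jnp.arange(3)
default_float = jnp.ones(3)
# An explicit dtype is still respected.
pinned_int = jnp.arange(3, dtype=jnp.int32)

print(default_int.dtype)    # int64
print(default_float.dtype)  # float64
print(pinned_int.dtype)     # int32
```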

fidlej commented 11 months ago

Thanks. With jax_enable_x64, the argmax output was int64. I changed the code to use the expected int32.
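A minimal sketch of the kind of fix described here, assuming only JAX: cast the `argmax` result to `int32` before writing it into an `int32` array, so the scatter dtypes match whether or not x64 is enabled. `select_and_record` is a hypothetical helper; the actual mctx change is not shown in this thread.

```python
import warnings

import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def select_and_record(scores, buffer, index):
    """Hypothetical helper: stores the argmax action in an int32 buffer."""
    # Cast explicitly so the value matches the buffer dtype in both
    # 32-bit and 64-bit mode.
    action = jnp.argmax(scores).astype(jnp.int32)
    return buffer.at[index].set(action), action


scores = jnp.array([0.1, 2.0, -1.0])
buffer = jnp.full(4, -1, dtype=jnp.int32)

# Record warnings to confirm the scatter no longer complains about dtypes.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    buffer, action = select_and_record(scores, buffer, 0)

scatter_warnings = [w for w in caught if "scatter" in str(w.message)]
print(action.dtype, buffer.dtype, len(scatter_warnings))  # int32 int32 0
```

Casting at the point where the index is produced keeps the buffers' `int32` dtype unchanged, so the same code behaves identically in 32-bit and 64-bit mode.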

The implementation is not tested with jax_enable_x64. Check that your results are sensible.
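One way to act on that advice, sketched under assumptions: run the same computation with float32 and float64 inputs and check that the outputs agree to roughly float32 precision; a large gap points to a dtype problem rather than rounding. `policy_probs` is a hypothetical stand-in (a plain softmax) for whatever part of your search output you want to check; it does not call mctx.

```python
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def policy_probs(logits):
    """Hypothetical stand-in for a computation whose 64-bit output you want to check."""
    return jax.nn.softmax(logits)


logits64 = jnp.array([0.5, 1.5, -0.3], dtype=jnp.float64)
logits32 = logits64.astype(jnp.float32)

probs64 = policy_probs(logits64)
probs32 = policy_probs(logits32)

# The two runs should differ only by float32 rounding error.
max_diff = float(jnp.max(jnp.abs(probs64 - probs32.astype(jnp.float64))))
print("max abs difference:", max_diff)
print("agree:", bool(jnp.allclose(probs64, probs32, atol=1e-6)))
```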

AbhiDu96 commented 11 months ago

Hello @fidlej, thanks for the solution. With the corresponding changes, it works without any warnings or errors, although I still have to check whether the results with `jax_enable_x64` are sensible. I think we can close the issue with this. Thanks again.