google-deepmind / mctx

Monte Carlo tree search in JAX
Apache License 2.0

Change deprecated jax.tree_util.tree_map to jax.tree.map. Fix argument passed to jax.numpy.finfo call. #95

Closed: carlosgmartin closed this pull request 2 months ago.

carlosgmartin commented 2 months ago
  1. Replaces uses of the deprecated function jax.tree_util.tree_map with jax.tree.map.
  2. Fixes a call to jax.numpy.finfo so that it receives the array's dtype rather than the array itself. Passing the array raises the following warning when the array is a tracer (for example, inside a jax.jit-compiled function):

    FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.
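For the first change, here is a minimal before/after sketch. It assumes a JAX version recent enough to provide the jax.tree namespace. The params pytree is a hypothetical example, not code from mctx. jax.tree.map is a drop-in replacement for the old call: it takes the same function and pytree arguments and returns a tree with the same structure.

```python
import jax
import jax.numpy as jnp

# Hypothetical pytree: a dict of arrays with a nested tuple.
params = {"w": jnp.ones((2, 3)), "b": (jnp.zeros(3), jnp.arange(3.0))}

# Old spelling, which the PR describes as deprecated:
#   doubled = jax.tree_util.tree_map(lambda x: 2 * x, params)
# New spelling via the jax.tree namespace:
doubled = jax.tree.map(lambda x: 2 * x, params)

# Several trees with the same structure are combined leaf by leaf.
summed = jax.tree.map(lambda a, b: a + b, params, doubled)
```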
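The second fix can be sketched as follows. The helper mask_invalid_actions is hypothetical and is not the mctx function the PR touches. Under jax.jit, an array argument is a tracer, and handing that tracer to jnp.finfo leads to the hashing warning quoted above. The dtype, however, is a static property of the traced value, so jnp.finfo(x.dtype) does not need to hash the array.

```python
import jax
import jax.numpy as jnp

@jax.jit
def mask_invalid_actions(logits, invalid_actions):
    """Hypothetical helper: set logits of invalid actions to the dtype's minimum."""
    # Inside jit, `logits` is a tracer. Passing `logits.dtype` (a concrete dtype)
    # to jnp.finfo avoids hashing the tracer itself.
    min_logit = jnp.finfo(logits.dtype).min
    return jnp.where(invalid_actions, min_logit, logits)

logits = jnp.array([0.5, 1.5, -0.2], dtype=jnp.float32)
invalid = jnp.array([False, True, False])
masked = mask_invalid_actions(logits, invalid)
```

Masked entries can no longer win an argmax, so the highest remaining valid logit (index 0 here) is selected.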