Fixes a call to jax.numpy.finfo to pass the dtype of the array rather than the array itself, which otherwise causes the following warning:
FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.
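A minimal sketch of the pattern being fixed, assuming a jitted function that needs the machine epsilon of its input's floating-point type. The function name safe_log is hypothetical and only for illustration. Inside jit, x is a tracer, but x.dtype is an ordinary concrete dtype, so passing x.dtype to jax.numpy.finfo avoids hashing the tracer.

```python
import jax
import jax.numpy as jnp


@jax.jit
def safe_log(x):
    # Hypothetical example function.
    # Pass x.dtype (a concrete, hashable dtype) rather than x itself:
    # under jit, x is a tracer, and jnp.finfo(x) would attempt to hash it.
    eps = jnp.finfo(x.dtype).eps
    # Clamp to eps so log never sees zero.
    return jnp.log(jnp.maximum(x, eps))
```

Because finfo only depends on the dtype, this change does not alter the result; it only stops the tracer from being hashed.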
Also replaces jax.tree_util.tree_map with jax.tree.map.
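A short sketch of the rename, assuming a JAX version that provides the jax.tree module (jax.tree.map). The params dictionary is a hypothetical pytree used only for illustration. Both calls apply the function to every leaf and return a pytree with the same structure.

```python
import jax
import jax.numpy as jnp

# Hypothetical pytree of parameters.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}

# Older spelling.
doubled_old = jax.tree_util.tree_map(lambda p: p * 2, params)

# Newer, shorter spelling with the same leaf-wise behavior.
doubled_new = jax.tree.map(lambda p: p * 2, params)
```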