Remove warning messages caused by using a deprecated JAX function.
/home/runner/work/optuna-examples/optuna-examples/haiku/haiku_simple.py:105: DeprecationWarning: jax.tree_leaves is deprecated: use jax.tree.leaves (jax v0.4.25 or newer) or jax.tree_util.tree_leaves (any JAX version).
Description of the changes
Following the warning message, this change replaces the deprecated jax.tree_leaves with jax.tree_util.tree_leaves, which is available in every JAX version.
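The change is a one-line rename of the call. The sketch below illustrates it on a hypothetical parameter pytree shaped like a Haiku params dict. The `params` values and the parameter-count use are assumptions for illustration only and are not taken from haiku_simple.py.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter pytree shaped like a Haiku params dict
# (module name -> parameter name -> array); NOT taken from haiku_simple.py.
params = {
    "linear": {"w": jnp.zeros((4, 3)), "b": jnp.zeros((3,))},
    "linear_1": {"w": jnp.zeros((3, 2)), "b": jnp.zeros((2,))},
}

# Deprecated spelling that emits the DeprecationWarning:
#     leaves = jax.tree_leaves(params)
# Replacement that works in any JAX version:
leaves = jax.tree_util.tree_leaves(params)

# Illustrative use: count the total number of parameters across all leaves.
num_params = sum(leaf.size for leaf in leaves)
print(num_params)  # 4*3 + 3 + 3*2 + 2 = 23
```

As the warning notes, jax.tree.leaves is an equivalent spelling on JAX v0.4.25 or newer. jax.tree_util.tree_leaves is used here because it also works on older JAX releases.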
Motivation
Remove the DeprecationWarning emitted at haiku/haiku_simple.py:105 by the deprecated jax.tree_leaves function.