Currently, the following deprecation warning is printed, suggesting `jax.tree_util.tree_map()` in place of `jax.tree_util.tree_multimap()`:
```
/usr/local/lib/python3.8/dist-packages/jax/_src/tree_util.py:188: FutureWarning: jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() instead as a drop-in replacement.
  warnings.warn('jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() '
```
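For illustration, here is a minimal sketch of the migration the warning describes. The call site changes only in the function name, because `jax.tree_util.tree_map()` already accepts additional pytrees after the first one. The names `params`, `grads`, and `lr` are hypothetical, standing in for a typical SGD-style update.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters and gradients with matching pytree structure.
params = {"w": jnp.ones((2,)), "b": jnp.zeros(())}
grads = {"w": jnp.full((2,), 0.5), "b": jnp.array(1.0)}
lr = 0.1  # hypothetical learning rate

# Deprecated form:
#   jax.tree_util.tree_multimap(lambda p, g: p - lr * g, params, grads)
# Replacement: tree_map applies the function leaf-wise across all trees passed.
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
```

The first tree determines the structure of the result, and the extra trees are expected to match it leaf for leaf.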