world-modelz / dreamax

A scalable Dreamer implementation in JAX
MIT License

Use `jax.tree_util.tree_map()` instead of `jax.tree_util.tree_multimap()` #15

Open andreaskoepf opened 2 years ago

andreaskoepf commented 2 years ago

The following deprecation warning is currently printed. It recommends using `jax.tree_util.tree_map()` instead of `jax.tree_util.tree_multimap()`:

/usr/local/lib/python3.8/dist-packages/jax/_src/tree_util.py:188: FutureWarning: jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() instead as a drop-in replacement.
  warnings.warn('jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() '
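
As the warning says, the replacement is drop-in: `jax.tree_util.tree_map()` accepts several pytrees of matching structure, so the call sites should only need the function name changed. A minimal sketch of the migration, assuming a typical two-tree call; `params`, `grads`, and `sgd_step` are hypothetical names for illustration and are not taken from the dreamax code:

```python
import jax
import jax.numpy as jnp

# Hypothetical pytrees with identical structure (not from the dreamax code base).
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([0.1, 0.2]), "b": jnp.array(0.05)}


def sgd_step(p, g, lr=0.1):
    """Apply one plain gradient-descent update to a single leaf."""
    return p - lr * g


# Before (deprecated):
#   new_params = jax.tree_util.tree_multimap(sgd_step, params, grads)
# After: tree_map takes the extra trees as additional positional arguments.
new_params = jax.tree_util.tree_map(sgd_step, params, grads)
```

The result keeps the structure of the first tree, so `new_params` here is again a dict with the keys `"w"` and `"b"`.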