RobertTLange / gymnax

RL Environments in JAX 🌍
Apache License 2.0
585 stars, 54 forks

Fix multimap deprecation for notebooks #31

Closed: DavidSlayback closed this 1 year ago

DavidSlayback commented 1 year ago

Pretty small PR: it just replaces the `jax.tree_*` calls with `jax.tree_util.tree_*` so that the notebooks work.
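As a rough illustration of that kind of replacement, here is a minimal sketch. It assumes only that the affected calls were top-level aliases such as `jax.tree_leaves` and `jax.tree_flatten`, which are spelled `tree_util.tree_leaves` and so on under `jax.tree_util`. The `params` dict is hypothetical example data, not taken from the gymnax notebooks.

```python
import jax.numpy as jnp
from jax import tree_util

# Hypothetical pytree: a dict of parameter arrays.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}

# Previously written as jax.tree_leaves / jax.tree_flatten / jax.tree_unflatten;
# the same operations live in the jax.tree_util namespace.
leaves = tree_util.tree_leaves(params)
flat, treedef = tree_util.tree_flatten(params)
rebuilt = tree_util.tree_unflatten(treedef, flat)

# Total number of scalar entries across all leaves (4 + 2).
num_params = sum(x.size for x in leaves)
```

The `jax.tree_util` spelling is used here because that namespace has been the stable home of these functions across JAX versions.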

RobertTLange commented 1 year ago

Thank you! I am pretty sure, though, that `jax.tree_map` is still supported. I believe only `tree_multimap` was deprecated. Is there any obvious advantage to using `jax.tree_util`? See https://github.com/RobertTLange/gymnax/blob/66828360fe263c057340d014e637960fdc3a5e16/gymnax/environments/environment.py#L43-L45
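The reason `tree_multimap` could be dropped is that `tree_map` already accepts additional pytrees with the same structure and passes their matching leaves to the function together. A minimal sketch, assuming two hypothetical statistics dicts (not from gymnax):

```python
import jax.numpy as jnp
from jax import tree_util

# Two hypothetical pytrees with identical structure.
old = {"mean": jnp.array([0.0, 1.0]), "var": jnp.array([1.0, 1.0])}
new = {"mean": jnp.array([2.0, 3.0]), "var": jnp.array([3.0, 5.0])}

# Formerly: jax.tree_multimap(lambda a, b: 0.5 * (a + b), old, new)
# tree_map takes extra trees and calls the function on corresponding leaves.
avg = tree_util.tree_map(lambda a, b: 0.5 * (a + b), old, new)
```

Each leaf of `avg` is the element-wise average of the corresponding leaves in `old` and `new`.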

DavidSlayback commented 1 year ago

Nope, sorry! I kept seeing deprecation warnings for all the other operations and assumed they applied to ALL tree operations. It makes more sense this way; `tree_map` is by far the most commonly used function.