RobertTLange / gymnax

RL Environments in JAX 🌍
Apache License 2.0

AttributeError: module 'jax' has no attribute 'tree_multimap' #27

Closed rohan-mehta-1024 closed 2 years ago

rohan-mehta-1024 commented 2 years ago

When I try to run the example:

import jax
import gymnax

rng = jax.random.PRNGKey(0)
rng, key_reset, key_act, key_step = jax.random.split(rng, 4)

# Instantiate the environment & its settings.
env, env_params = gymnax.make("Pendulum-v1")

# Reset the environment.
obs, state = env.reset(key_reset, env_params)

# Sample a random action.
action = env.action_space(env_params).sample(key_act)

# Perform the step transition.
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)

I get the following error:

UnfilteredStackTrace: AttributeError: module 'jax' has no attribute 'tree_multimap'

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

AttributeError                            Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/gymnax/environments/environment.py in step(self, key, state, action, params)
     41         obs_re, state_re = self.reset_env(key_reset, params)
     42         # Auto-reset environment based on termination
---> 43         state = jax.tree_multimap(
     44             lambda x, y: jax.lax.select(done, x, y), state_re, state_st
     45         )

AttributeError: module 'jax' has no attribute 'tree_multimap'

I don't think jax.tree_multimap is valid in any JAX version? Isn't it jax.tree_util.tree_multimap? I might be wrong though, as I'm still very new to JAX.

RobertTLange commented 2 years ago

Thank you for bringing this up. It should be fixed in the latest release. JAX did support jax.tree_multimap, but it was deprecated and has since been removed, which is why the call now raises AttributeError. Closing this.
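
For anyone hitting this on an older gymnax install, here is a minimal sketch of how the failing auto-reset line can be written without tree_multimap. It relies on jax.tree_util.tree_map accepting several pytrees at once. The auto_reset helper and the dict-based states are hypothetical stand-ins for illustration, not gymnax's actual code or state class, and this is not necessarily how the fix was made in the release.

```python
import jax
import jax.numpy as jnp


def auto_reset(done, state_re, state_st):
    """Pick the reset state where `done` is True, else the stepped state.

    Hypothetical helper: mirrors the line from the traceback, but uses
    jax.tree_util.tree_map, which maps a function over multiple pytrees.
    """
    return jax.tree_util.tree_map(
        lambda x, y: jax.lax.select(done, x, y), state_re, state_st
    )


# Toy pytrees standing in for the environment state (assumed structure).
state_st = {"theta": jnp.array(1.0), "time": jnp.array(5)}  # after the step
state_re = {"theta": jnp.array(0.0), "time": jnp.array(0)}  # after a reset

# Episode finished: every leaf comes from the reset state.
state_done = auto_reset(jnp.array(True), state_re, state_st)

# Episode still running: every leaf comes from the stepped state.
state_running = auto_reset(jnp.array(False), state_re, state_st)
```

Both pytrees need the same structure, and matching leaves need the same shape and dtype, since jax.lax.select requires its two operands to agree.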