When I try to run the example:

```python
import jax
import gymnax
rng = jax.random.PRNGKey(0)
rng, key_reset, key_act, key_step = jax.random.split(rng, 4)
# Instantiate the environment & its settings.
env, env_params = gymnax.make("Pendulum-v1")
# Reset the environment.
obs, state = env.reset(key_reset, env_params)
# Sample a random action.
action = env.action_space(env_params).sample(key_act)
# Perform the step transition.
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)
```
I get the following error:

```
UnfilteredStackTrace: AttributeError: module 'jax' has no attribute 'tree_multimap'
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
AttributeError                            Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/gymnax/environments/environment.py in step(self, key, state, action, params)
     41         obs_re, state_re = self.reset_env(key_reset, params)
     42         # Auto-reset environment based on termination
---> 43         state = jax.tree_multimap(
     44             lambda x, y: jax.lax.select(done, x, y), state_re, state_st
     45         )
AttributeError: module 'jax' has no attribute 'tree_multimap'
```
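For reference, the auto-reset call at line 43 of gymnax's `environment.py` can be expressed with `jax.tree_util.tree_map`, which accepts several pytrees of matching structure and is the replacement the JAX deprecation notes point to for the removed `jax.tree_multimap`. This is a minimal sketch outside gymnax: the `auto_reset_select` helper and the two-leaf state dicts are hypothetical stand-ins for gymnax's `state_re`/`state_st`, not its actual types.

```python
import jax
import jax.numpy as jnp


def auto_reset_select(done, state_re, state_st):
    # Hypothetical helper: for every leaf, take the reset value when `done`
    # is True, otherwise the stepped value. tree_map walks all given pytrees
    # in lockstep, which is what jax.tree_multimap used to do.
    return jax.tree_util.tree_map(
        lambda x, y: jax.lax.select(done, x, y), state_re, state_st
    )


# Hypothetical Pendulum-like state, used only for illustration.
state_re = {"theta": jnp.array(0.0), "theta_dot": jnp.array(0.0)}
state_st = {"theta": jnp.array(1.5), "theta_dot": jnp.array(-0.3)}

# Episode finished: leaves come from the reset state.
after_done = auto_reset_select(jnp.array(True), state_re, state_st)
# Episode continues: leaves come from the stepped state.
after_step = auto_reset_select(jnp.array(False), state_re, state_st)
```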
I don't think `jax.tree_multimap` is valid in any JAX version. Isn't it `jax.tree_util.tree_multimap`? I might be wrong, though; I'm still very new to JAX.
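If patching gymnax or pinning an older JAX isn't convenient, one possible stopgap (an assumption on my part, not an official fix from either project) is to alias the old name to `jax.tree_util.tree_map` before calling `env.step`. It only works because gymnax looks up `jax.tree_multimap` on the `jax` module at call time; moving to a gymnax release that no longer uses the old name, if one is available, is the cleaner route.

```python
import jax
import jax.numpy as jnp

# Stopgap shim: newer JAX releases no longer provide jax.tree_multimap,
# but jax.tree_util.tree_map accepts multiple pytrees and behaves the same way.
if not hasattr(jax, "tree_multimap"):
    jax.tree_multimap = jax.tree_util.tree_map

# Quick check that the aliased name now maps over two trees at once.
summed = jax.tree_multimap(
    lambda x, y: x + y, {"a": jnp.array(1.0)}, {"a": jnp.array(2.0)}
)
```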