Traceback (most recent call last):
  ...
  File "*", line 5, in <module>
    import gymnasium as gym
  File "*/venv/lib/python3.10/site-packages/gymnasium/__init__.py", line 13, in <module>
    from gymnasium import envs, spaces, utils, vector, wrappers, error, logger, experimental
  File "*/venv/lib/python3.10/site-packages/gymnasium/experimental/__init__.py", line 4, in <module>
    from gymnasium.experimental import functional, wrappers
  File "*/venv/lib/python3.10/site-packages/gymnasium/experimental/wrappers/__init__.py", line 30, in <module>
    from gymnasium.experimental.wrappers.jax_to_numpy import JaxToNumpyV0
  File "*/venv/lib/python3.10/site-packages/gymnasium/experimental/wrappers/jax_to_numpy.py", line 75, in <module>
    @jax_to_numpy.register(jnp.DeviceArray)
  File "*/venv/lib/python3.10/site-packages/jax/_src/deprecations.py", line 53, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.numpy' has no attribute 'DeviceArray'
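The error surfaces during import rather than when the wrapper is used because a decorator's argument is evaluated as soon as the decorated function is defined, and jax.numpy appears to resolve removed names through a module-level __getattr__ (the jax/_src/deprecations.py frame in the traceback) that raises AttributeError. The stdlib-only sketch below reproduces that mechanism under simplified assumptions; fake_jnp, to_numpy and define_conversion are hypothetical stand-ins, not JAX or Gymnasium code.

```python
import functools
import types

# Hypothetical stand-in for jax.numpy after DeviceArray was removed.
fake_jnp = types.ModuleType("fake_jnp")


def _removed_name(name):
    # Module-level __getattr__ (PEP 562): only called when normal attribute
    # lookup fails; the message mirrors the one in the traceback.
    raise AttributeError(f"module {fake_jnp.__name__!r} has no attribute {name!r}")


fake_jnp.__getattr__ = _removed_name


@functools.singledispatch
def to_numpy(value):
    raise TypeError(f"no conversion registered for {type(value).__name__}")


def define_conversion():
    # The decorator argument fake_jnp.DeviceArray is looked up when the
    # function is defined, before it is ever called. In the real wrapper
    # module this happens at top level, so the failure shows up on import.
    @to_numpy.register(fake_jnp.DeviceArray)
    def _convert(value):
        return value


try:
    define_conversion()
except AttributeError as exc:
    print(exc)  # module 'fake_jnp' has no attribute 'DeviceArray'
```

Because the registration never completes, only the generic fallback remains registered on to_numpy.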
Describe the bug
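Importing Gymnasium fails with the AttributeError shown in the traceback above. JAX has removed jax.numpy.DeviceArray in favor of jax.Array, but gymnasium/experimental/wrappers/jax_to_numpy.py still registers a conversion for jnp.DeviceArray at import time (line 75), so even a plain import gymnasium raises.
Related issue: https://github.com/huggingface/transformers/issues/25417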
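One way downstream code could avoid this failure is to register the conversion against jax.Array and fall back to jnp.DeviceArray only on JAX versions that predate jax.Array. The sketch below is a hypothetical shim: resolve_jax_array_type and this jax_to_numpy are illustrative names, not Gymnasium's API, and this is not necessarily how Gymnasium itself fixes the issue. The JAX-specific registration is guarded so the module still imports when JAX or NumPy is not installed.

```python
import functools


def resolve_jax_array_type(jax_module, jnp_module):
    """Pick the class to register array conversions against.

    Prefers jax.Array (the unified array type); falls back to the legacy
    jax.numpy.DeviceArray only when jax.Array does not exist.
    """
    array_type = getattr(jax_module, "Array", None)
    if array_type is not None:
        return array_type
    # No default here: if neither name exists, the AttributeError propagates.
    return getattr(jnp_module, "DeviceArray")


@functools.singledispatch
def jax_to_numpy(value):
    """Fallback for values with no registered conversion (hypothetical shim)."""
    raise TypeError(f"no conversion registered for {type(value).__name__}")


try:
    import jax
    import jax.numpy as jnp
    import numpy as np
except ImportError:
    pass  # JAX/NumPy not installed: nothing to register
else:
    @jax_to_numpy.register(resolve_jax_array_type(jax, jnp))
    def _jax_array_to_numpy(value):
        # np.asarray converts the JAX array to a host NumPy array.
        return np.asarray(value)
```

Looking up jax.Array first means current JAX releases never touch the removed name, so the deprecation __getattr__ is not triggered.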
Code example
System info
Gymnasium installed through pip install
OS: Linux, Distro: NixOS with FHS environment
Information given by pip show
Python 3.10.12
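The pip show output itself is not included here. As a hedged, stdlib-only sketch, the versions of the packages involved in this traceback could be collected as follows; collect_versions is a hypothetical helper, and the package list is an assumption based on the modules that appear in the traceback.

```python
import platform
from importlib import metadata


def collect_versions(packages=("gymnasium", "jax", "jaxlib", "numpy")):
    """Return the Python version plus installed versions of the given packages."""
    versions = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


if __name__ == "__main__":
    for name, version in collect_versions().items():
        print(f"{name}: {version}")
```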