Farama-Foundation / Gymnasium

An API standard for single-agent reinforcement learning environments, with popular reference environments and related utilities (formerly Gym)
https://gymnasium.farama.org
MIT License
6.84k stars 762 forks

[Bug Report] AttributeError: module 'jax.numpy' has no attribute 'DeviceArray' #701

Closed: BillHuang2001 closed this issue 1 year ago

BillHuang2001 commented 1 year ago

Describe the bug

Importing Gymnasium results in the following error:

Traceback (most recent call last):
  ...
  File "*", line 5, in <module>
    import gymnasium as gym
  File "*/venv/lib/python3.10/site-packages/gymnasium/__init__.py", line 13, in <module>
    from gymnasium import envs, spaces, utils, vector, wrappers, error, logger, experimental
  File "*/venv/lib/python3.10/site-packages/gymnasium/experimental/__init__.py", line 4, in <module>
    from gymnasium.experimental import functional, wrappers
  File "*/venv/lib/python3.10/site-packages/gymnasium/experimental/wrappers/__init__.py", line 30, in <module>
    from gymnasium.experimental.wrappers.jax_to_numpy import JaxToNumpyV0
  File "*/venv/lib/python3.10/site-packages/gymnasium/experimental/wrappers/jax_to_numpy.py", line 75, in <module>
    @jax_to_numpy.register(jnp.DeviceArray)
  File "*/venv/lib/python3.10/site-packages/jax/_src/deprecations.py", line 53, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.numpy' has no attribute 'DeviceArray'

JAX has removed jnp.DeviceArray in favor of jax.Array.
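The failing line looks like a functools.singledispatch-style registration. Because it sits at module level (note the "in <module>" frame in the traceback), it is evaluated on import, so simply running import gymnasium touches jnp.DeviceArray. Below is a minimal sketch of a version-tolerant registration. It is not Gymnasium's actual code: jax_to_numpy is taken from the traceback, while _JAX_ARRAY_TYPE and _jax_array_to_numpy are hypothetical names. The sketch assumes jax.Array is the array type on current JAX releases and falls back to jnp.DeviceArray only on older ones.

```python
import functools

import numpy as np

try:
    import jax
    import jax.numpy as jnp
except ImportError:  # lets the sketch run (as a no-op converter) without JAX
    jax = None
    jnp = None


@functools.singledispatch
def jax_to_numpy(value):
    """Fallback for unregistered types: return the value unchanged."""
    return value


if jax is not None:
    # Assumption: newer JAX exposes its array type as jax.Array, and jnp.DeviceArray
    # no longer exists (the traceback shows this for jax 0.4.14). Only releases that
    # predate jax.Array fall back to jnp.DeviceArray.
    _JAX_ARRAY_TYPE = getattr(jax, "Array", None) or getattr(jnp, "DeviceArray")

    @jax_to_numpy.register(_JAX_ARRAY_TYPE)
    def _jax_array_to_numpy(value):
        """Convert a JAX array to a NumPy array."""
        return np.asarray(value)
```

The array type is resolved with getattr, so the module never references a missing attribute directly. Without JAX installed, jax_to_numpy simply returns its input unchanged.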

Related issue https://github.com/huggingface/transformers/issues/25417

Code example

import gymnasium as gym

System info

Gymnasium installed through pip install.
OS: Linux, Distro: NixOS with FHS environment
Information given by pip show:

Name: gymnasium
Version: 0.27.1
Summary: A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym).
Home-page: 
Author: 
Author-email: Farama Foundation <contact@farama.org>
License: MIT License
Location: */venv/lib/python3.10/site-packages
Requires: cloudpickle, gymnasium-notices, jax-jumpy, numpy, typing-extensions
Required-by:

Name: jax
Version: 0.4.14
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: */venv/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, opt-einsum, scipy
Required-by: brax, chex, flax, jaxopt, optax, orbax, orbax-checkpoint

Python 3.10.12
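Instead of pasting several pip show outputs, you can print the relevant versions in one go. This is a minimal sketch built on the standard-library importlib.metadata module. The helper name collect_versions and the default package list are illustrative assumptions; jaxlib, for example, is not mentioned in the report.

```python
import platform
from importlib import metadata


def collect_versions(packages=("gymnasium", "jax", "jaxlib")):
    """Return installed versions of the given packages, plus the Python version.

    Packages that are not installed are reported as None instead of raising.
    """
    versions = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in collect_versions().items():
        print(f"{name}: {version}")
```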

pseudo-rnd-thoughts commented 1 year ago

Gymnasium v0.27 requires jax 0.3; we have fixed this in v0.29, which requires jax 0.4.

Let me know if there is still an issue with Gymnasium v0.29 and jax 0.4.
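The version pairing described in the reply can be expressed as a small check. This is a hypothetical helper, not part of Gymnasium. It encodes only the two releases mentioned in this thread (v0.27 with jax 0.3, v0.29 with jax 0.4). For any other Gymnasium release it returns None rather than guessing.

```python
def _major_minor(version):
    """Parse 'X.Y[.Z...]' into an (X, Y) tuple of ints; later components are ignored."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def expected_jax_series(gymnasium_version):
    """Return the jax (major, minor) series stated in this thread for a Gymnasium
    release, or None if the thread does not cover that release."""
    gym = _major_minor(gymnasium_version)
    if gym == (0, 27):
        return (0, 3)
    if gym == (0, 29):
        return (0, 4)
    return None


def jax_version_matches(gymnasium_version, jax_version):
    """True/False for releases covered by the thread, otherwise None."""
    expected = expected_jax_series(gymnasium_version)
    if expected is None:
        return None
    return _major_minor(jax_version) == expected
```

For the setup in this report, jax_version_matches("0.27.1", "0.4.14") returns False, which is consistent with the failed import.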