EdanToledo / Stoix

🏛️A research-friendly codebase for fast experimentation of single-agent reinforcement learning in JAX • End-to-End JAX RL
Apache License 2.0

[BUG] Jax - Flax compatibility error #98

Open thomashirtz opened 4 months ago

thomashirtz commented 4 months ago

Describe the bug

Hello!

After building the Docker image, I get the error Cannot import name 'linear_util' from 'jax' when running the examples. This seems to be caused by an incompatibility between flax and jax: https://stackoverflow.com/questions/78210393/cannot-import-name-linear-util-from-jax (With these settings I do get access to my GPU, a 2070 Max-Q.)

I therefore tried to install version 0.4.24 by changing requirements.txt from jax>=0.4.10 to jax>=0.4.24 and line 36 of the Dockerfile to:

RUN if [ "$USE_CUDA" = true ] ; \
    then pip install "jax[cuda11]>=0.4.24" -f "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html" ; \
    fi

However, I then get the following warning and can no longer use my GPU:

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

Do you have any idea how to solve this?

Full traceback:

Traceback (most recent call last):
  File "/opt/project/xposure/stoix/systems/q_learning/ff_ddqn.py", line 6, in <module>
    import flashbax as fbx
  File "/xposure/lib/python3.10/site-packages/flashbax/__init__.py", line 16, in <module>
    from flashbax.buffers import (
  File "/xposure/lib/python3.10/site-packages/flashbax/buffers/__init__.py", line 16, in <module>
    from flashbax.buffers.prioritised_flat_buffer import make_prioritised_flat_buffer
  File "/xposure/lib/python3.10/site-packages/flashbax/buffers/prioritised_flat_buffer.py", line 25, in <module>
    from flashbax.buffers.prioritised_trajectory_buffer import (
  File "/xposure/lib/python3.10/site-packages/flashbax/buffers/prioritised_trajectory_buffer.py", line 39, in <module>
    from flashbax.buffers import sum_tree, trajectory_buffer
  File "/xposure/lib/python3.10/site-packages/flashbax/buffers/sum_tree.py", line 33, in <module>
    from flax.struct import dataclass
  File "/xposure/lib/python3.10/site-packages/flax/__init__.py", line 24, in <module>
    from flax import core
  File "/xposure/lib/python3.10/site-packages/flax/core/__init__.py", line 15, in <module>
    from .axes_scan import broadcast as broadcast
  File "/xposure/lib/python3.10/site-packages/flax/core/axes_scan.py", line 23, in <module>
    from jax.extend import linear_util as lu
ImportError: cannot import name 'linear_util' from 'jax.extend' (/xposure/lib/python3.10/site-packages/jax/extend/__init__.py)
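The failing import above can be checked in isolation, without running the full example. The following is a minimal diagnostic sketch using only the Python standard library; the helper names `installed_version` and `can_import` are hypothetical and not part of Stoix, flax, or jax. It prints the installed versions of the three packages involved and reports whether the module that `flax/core/axes_scan.py` tries to import is available.

```python
import importlib
from importlib import metadata


def installed_version(package: str) -> str:
    """Return the installed version of `package`, or 'not installed'."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return "not installed"


def can_import(module_name: str) -> bool:
    """Return True if `module_name` can be imported in this environment."""
    try:
        importlib.import_module(module_name)
    except ImportError:
        return False
    return True


if __name__ == "__main__":
    for package in ("jax", "jaxlib", "flax"):
        print(f"{package}: {installed_version(package)}")
    # This is the module that the traceback shows flax failing to import.
    print("jax.extend.linear_util importable:", can_import("jax.extend.linear_util"))
```

Running this inside the container, both before and after changing the pinned versions, would show whether the installed jax actually provides the module that the installed flax expects.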

To Reproduce

Steps to reproduce the behavior:

  1. Run make build
  2. Run ff_ddqn.py inside the Docker container

Possible Solution

Change the versions of flax and jax/jaxlib in requirements.txt and the Dockerfile.

Context (Environment)

Linux 24.04 with Docker. This is the output of pip freeze when I run the Docker container with the repo's current settings:

absl-py==2.1.0
antlr4-python3-runtime==4.9.3
arch==7.0.0
arrow==1.3.0
attrs==23.2.0
black==24.4.2
blinker==1.8.2
boto3==1.34.140
botocore==1.34.140
bravado==11.0.3
bravado-core==6.1.1
brax==0.10.5
certifi==2024.7.4
cfgv==3.4.0
charset-normalizer==3.3.2
chex==0.1.86
click==8.1.7
cloudpickle==3.0.0
colorama==0.4.6
colorcet==3.0.0
contextlib2==21.6.0
contourpy==1.2.1
craftax==1.4.3
cycler==0.12.1
decorator==5.1.1
distlib==0.3.8
distrax @ git+https://github.com/google-deepmind/distrax@0e449826b6be7603a56b98dbf64873cae3aa523e
dm-env==1.6
dm-tree==0.1.8
docker-pycreds==0.4.0
dotmap==1.3.30
etils==1.7.0
evosax==0.1.6
Farama-Notifications==0.0.4
filelock==3.15.4
flashbax @ git+https://github.com/instadeepai/flashbax@1c31b526e6374620395633d1699494f104543177
Flask==3.0.3
Flask-Cors==4.0.1
flax==0.8.5
fonttools==4.53.1
fqdn==1.5.1
fsspec==2024.6.1
future==1.0.0
gast==0.6.0
gitdb==4.0.11
GitPython==3.1.43
glfw==2.7.0
grpcio==1.64.1
gym==0.26.2
gym-notices==0.0.8
gymnasium==0.29.1
gymnax==0.0.8
huggingface-hub==0.23.4
hydra-core==1.3.2
id-marl-eval @ git+https://github.com/instadeepai/marl-eval@f97a72350d954e31a70531f55d8fca50db0d25f0
identify==2.5.36
idna==3.7
imageio==2.34.2
imageio-ffmpeg==0.5.1
importlib-metadata==4.13.0
importlib_resources==6.4.0
isoduration==20.11.0
itsdangerous==2.2.0
jax==0.4.13
jaxlib==0.4.13+cuda11.cudnn86
jaxmarl==0.0.2
jaxopt==0.8.3
Jinja2==3.1.4
jmespath==1.0.1
jsonpointer==3.0.0
jsonref==1.1.0
jsonschema==4.22.0
jsonschema-specifications==2023.12.1
jumanji==1.0.0
kiwisolver==1.4.5
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.7.5
mctx==0.0.5
mdurl==0.1.2
ml-dtypes==0.4.0
ml_collections==0.1.1
monotonic==1.6
msgpack==1.0.8
mujoco==3.1.6
mujoco-mjx==3.1.6
mypy-extensions==1.0.0
neptune==1.10.4
nest-asyncio==1.6.0
nodeenv==1.9.1
numpy==1.26.4
nvidia-cublas-cu11==11.11.3.6
nvidia-cuda-cupti-cu11==11.8.87
nvidia-cuda-nvcc-cu11==11.8.89
nvidia-cuda-runtime-cu11==11.8.89
nvidia-cudnn-cu11==9.2.0.82
nvidia-cufft-cu11==10.9.0.58
nvidia-cusolver-cu11==11.4.1.48
nvidia-cusparse-cu11==11.7.5.86
oauthlib==3.2.2
omegaconf==2.3.0
opt-einsum==3.3.0
optax @ git+https://github.com/google-deepmind/optax.git@10cf508f505acd99feac5c231c0f521895bb3a37
orbax-checkpoint==0.5.20
packaging==24.1
pandas==1.4.4
param==2.1.1
pathspec==0.12.1
patsy==0.5.6
pgx==2.0.1
pgx-minatar==0.5.1
pillow==10.4.0
platformdirs==4.2.2
pre-commit==3.7.1
protobuf==3.20.2
psutil==6.0.0
pyct==0.5.0
pygame==2.6.0
Pygments==2.18.0
PyJWT==2.8.0
PyOpenGL==3.1.7
pyparsing==3.1.2
python-dateutil==2.9.0.post0
pytinyrenderer==0.0.14
pytz==2024.1
PyYAML==6.0.1
referencing==0.35.1
requests==2.32.3
requests-oauthlib==2.0.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.7.1
rlax==0.1.6
rliable @ git+https://github.com/google-research/rliable@1171833f6706b6c25bbf042e2cb185a96fcf2ce6
rpds-py==0.18.1
s3transfer==0.10.2
safetensors==0.4.3
scipy==1.14.0
seaborn==0.13.2
sentry-sdk==2.7.1
setproctitle==1.3.3
simplejson==3.19.2
six==1.16.0
smmap==5.0.1
statsmodels==0.14.2
svgwrite==1.4.3
swagger-spec-validator==3.0.4
tdqm==0.0.1
tensorboard-logger==0.1.0
tensorboardX==2.6.2.2
tensorflow-probability==0.24.0
tensorstore==0.1.63
tomli==2.0.1
toolz==0.12.1
tqdm==4.66.4
trimesh==4.4.1
types-python-dateutil==2.9.0.20240316
typing_extensions==4.12.2
uri-template==1.3.0
urllib3==2.2.2
virtualenv==20.26.3
wandb==0.17.4
webcolors==24.6.0
websocket-client==1.8.0
Werkzeug==3.0.3
xminigrid @ git+https://github.com/corl-team/xland-minigrid.git@991f13c7885c24c82302a1ee3a68a24a29801a94
-e git+https://github.com/thomashirtz/xposure@af41ff3c6dc2262f7592831efa95a7da505e3b21#egg=xposure
zipp==3.19.2

EdanToledo commented 4 months ago

Hmm, let me look into this. Unfortunately, I don't currently have access to a GPU machine, so it'll be hard for me to test this. Regardless, this reminds me to raise the jax version in the requirements file. Just make sure that the image you are pulling and the jax version use the same CUDA and cuDNN versions.
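One way to check that alignment is to compare the build tag in the jaxlib version string with the installed NVIDIA wheels. The sketch below is only illustrative: the helper names are hypothetical, the three sample lines are copied from the pip freeze listing above, and it assumes that a local version tag such as `+cuda11.cudnn86` denotes a CUDA 11 / cuDNN 8.6 build, so that the first digit after `cudnn` is the cuDNN major version.

```python
import re


def parse_freeze(text: str) -> dict[str, str]:
    """Map package name -> version for the 'name==version' lines of pip freeze."""
    versions = {}
    for line in text.splitlines():
        name, sep, version = line.strip().partition("==")
        if sep:  # skips 'pkg @ git+...' and '-e ...' lines, which have no '=='
            versions[name.lower()] = version
    return versions


def jaxlib_build_tags(jaxlib_version: str) -> dict[str, str]:
    """Extract CUDA/cuDNN tags from a version such as '0.4.13+cuda11.cudnn86'."""
    match = re.search(r"\+cuda(\d+)\.cudnn(\d+)", jaxlib_version)
    if match is None:
        return {}  # CPU-only wheel, or a tagging scheme this sketch does not know
    return {"cuda": match.group(1), "cudnn": match.group(2)}


# Sample lines taken from the pip freeze output in this thread.
FREEZE = """\
jax==0.4.13
jaxlib==0.4.13+cuda11.cudnn86
nvidia-cudnn-cu11==9.2.0.82
"""

versions = parse_freeze(FREEZE)
tags = jaxlib_build_tags(versions["jaxlib"])
cudnn_major = versions["nvidia-cudnn-cu11"].split(".")[0]

print("jaxlib build tags:", tags)
print("installed cuDNN wheel major version:", cudnn_major)
# Assumption: the first digit of the 'cudnn' tag is the cuDNN major version.
print("cuDNN major versions match:", tags.get("cudnn", "")[:1] == cudnn_major)
```

For the listing above, the jaxlib tag and the `nvidia-cudnn-cu11` wheel report different cuDNN major versions. Whether that difference matters for this particular issue is not established in the thread, so treat the result as a pointer for further checking, not a diagnosis.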

EdanToledo commented 4 months ago

@thomashirtz Did you ever figure out the issue?

thomashirtz commented 4 months ago

No, unfortunately I didn't. I don't have much time to debug this, so I stopped using Docker and switched to a venv.