google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

When I install optax, I am no longer able to use the GPU #1144

Open Alessandro-Castelli opened 3 days ago

Alessandro-Castelli commented 3 days ago

I have jax 0.4.23. When I install optax with pip install optax, I get the message "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.", followed by [CpuDevice(id=0)].

This only happens after I install optax. Which version of optax is compatible with my version of jax and with tensorflow 2.9.0?

Name: jax
Version: 0.4.23
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/acastelli/miniconda3/envs/af/lib/python3.9/site-packages
Requires: importlib-metadata, ml-dtypes, numpy, opt-einsum, scipy
Required-by:

Name: jaxlib
Version: 0.4.23+cuda11.cudnn86
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/acastelli/miniconda3/envs/af/lib/python3.9/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by:

vroulet commented 3 days ago

Hello @Alessandro-Castelli,

pip install jax==0.4.23 jaxlib==0.4.23 optax==0.2.2 will not modify the jax or jaxlib versions. If you ran a plain pip install optax, the jax and jaxlib versions are bumped automatically (because of the requirements set in optax 0.2.3). The real culprit is probably not optax but tensorflow, which has maximum-version requirements that often conflict with other packages. In particular, tensorflow 2.9.0 does not seem to be available from pip (see the output of pip index versions tensorflow on Python 3.9).
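One way to see what pip actually left in the environment after an install is to read the versions back from Python. The sketch below is only illustrative: the helper names installed_versions and looks_cuda_enabled are hypothetical, and the "+cuda" check is an assumption based on the jaxlib version string shown in this thread (0.4.23+cuda11.cudnn86); other jaxlib releases may signal CUDA support differently.

```python
from importlib import metadata

# Packages discussed in this thread; adjust as needed.
PACKAGES = ("jax", "jaxlib", "optax", "chex", "tensorflow")


def installed_versions(packages=PACKAGES):
    """Return {package: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def looks_cuda_enabled(jaxlib_version):
    """Heuristic (assumption): CUDA wheels like the one in this thread
    carry a '+cuda...' local version tag, e.g. '0.4.23+cuda11.cudnn86'."""
    return jaxlib_version is not None and "+cuda" in jaxlib_version


if __name__ == "__main__":
    found = installed_versions()
    for name, version in found.items():
        print(f"{name}: {version or 'not installed'}")
    print("jaxlib carries a CUDA tag:", looks_cuda_enabled(found["jaxlib"]))
```

Running this before and after pip install optax would show whether the jaxlib version string lost its CUDA tag in between.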

Alessandro-Castelli commented 3 days ago

Thank you, @vroulet. Yes, maybe the problem lies in the TensorFlow version. How can I fix it?

Available versions: 2.18.0, 2.17.1, 2.17.0, 2.16.2, 2.16.1, 2.15.1, 2.15.0.post1, 2.15.0, 2.14.1, 2.14.0, 2.13.1, 2.13.0, 2.12.1, 2.12.0, 2.11.1, 2.11.0, 2.10.1, 2.10.0, 2.9.3, 2.9.2, 2.9.1, 2.9.0, 2.8.4, 2.8.3, 2.8.2, 2.8.1, 2.8.0, 2.7.4, 2.7.3, 2.7.2, 2.7.1, 2.7.0, 2.6.5, 2.6.4, 2.6.3, 2.6.2, 2.6.1, 2.6.0, 2.5.3, 2.5.2, 2.5.1, 2.5.0

Alessandro-Castelli commented 3 days ago

I tried creating two separate conda environments: one where I use TensorFlow to download the dataset and another where I install PennyLane, JAX, JAXlib, and Optax to train the model, but the error still occurs.

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu. cpu

Alessandro-Castelli commented 3 days ago

At this point, I think that the problem is optax.

Name: jax
Version: 0.4.23
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/acastelli/miniconda3/envs/af/lib/python3.9/site-packages
Requires: importlib-metadata, ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, optax

Name: jaxlib
Version: 0.4.23+cuda11.cudnn86
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/acastelli/miniconda3/envs/af/lib/python3.9/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, optax

Name: optax
Version: 0.1.5
Summary: A gradient processing and optimisation library in JAX.
Home-page: https://github.com/deepmind/optax
Author: DeepMind
Author-email: optax-dev@google.com
License: Apache 2.0
Location: /home/acastelli/miniconda3/envs/af/lib/python3.9/site-packages
Requires: absl-py, chex, jax, jaxlib, numpy
Required-by:

vroulet commented 2 days ago

At this point, I think that the problem is optax.

I really don't think so. Just look at the code in optax: it is a lightweight library with no CUDA or GPU functionality of its own. Dependency versions could have been bumped, but the versions of optax, jax, and jaxlib listed above look fine.

I cannot reproduce the error you mention, as tensorflow 2.9 does not seem to be available to me locally, and in any case I don't have a GPU. The error clearly points to jaxlib, not optax.
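To check what JAX itself reports, independently of optax or tensorflow, a minimal sketch along these lines could be run in the affected environment. The helper name report_jax_backend is hypothetical; jax.devices() and jax.default_backend() are the calls used to query the backend.

```python
def report_jax_backend():
    """Return a one-line summary of the backend JAX selected,
    or a note explaining that jax could not be imported."""
    try:
        import jax
    except ImportError as exc:
        return f"jax is not importable: {exc}"
    # jax.default_backend() is a string such as 'cpu' or 'gpu';
    # jax.devices() lists the devices of that backend.
    return (
        f"jax {jax.__version__} "
        f"backend={jax.default_backend()} "
        f"devices={jax.devices()}"
    )


print(report_jax_backend())
```

If this already reports a cpu backend in an environment where optax has not been imported, the fallback comes from the jax/jaxlib installation rather than from optax code.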

Alessandro-Castelli commented 2 days ago

@vroulet I’ll explain why I think it’s optax. Basically, in my initial code, I was using jax and jaxlib 0.4.23, pennylane, and tensorflow 2.9.0, and I didn’t have any issues installing those versions. At some point, I needed more powerful optimizers like Adam to do some tasks, and that’s when I started using optax. Only from that moment, I encountered the issue:

2024-11-28 17:39:56.095832: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
cpu

Maybe it’s optax that doesn’t get along with tensorflow.

Another thing is that when I install optax, it updates jax and jaxlib to version 0.4.30, so maybe that’s the problem. I really don’t know. I’ve tried many combinations of versions, but I just can’t get it to work. Do you have any suggestions?

fabianp commented 2 days ago

I feel the pain, I've been there - versioning between cuda/jax/tensorflow is a mess.

I would suggest having a different virtualenv for jax-based and TF-based projects if you can.

Alessandro-Castelli commented 2 days ago

Hello @fabianp, I tried to do it, but I think that the real problem is the versioning between JAX and Optax. I tried many different Optax versions, but I didn't resolve my problem.

fabianp commented 2 days ago

Have you tried installing optax with --no-deps so it doesn't try to modify the other packages?

Alessandro-Castelli commented 2 days ago

Yes, but Optax has additional dependencies, and following this approach doesn't seem to work for Optax.
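When --no-deps is used, the remaining dependencies have to be installed by hand. The sketch below shows one possible way to list them while leaving out the packages that should stay pinned. It is an illustration only: the helper names are hypothetical, the choice of jax and jaxlib as pinned packages is an assumption, and the example list is copied from the "Requires:" line of the optax 0.1.5 pip show output earlier in this thread. Note that the same output lists chex under "Required-by" for jax, so chex may need the same --no-deps treatment.

```python
import re
from importlib import metadata

# Packages that pip should not be allowed to change (assumption).
PINNED = {"jax", "jaxlib"}


def requirement_name(requirement):
    """Extract the normalized distribution name from a requirement
    string such as 'chex>=0.1.5' or 'absl-py'."""
    match = re.match(r"[A-Za-z0-9][A-Za-z0-9._-]*", requirement.strip())
    name = match.group(0) if match else ""
    return re.sub(r"[-_.]+", "-", name).lower()


def unpinned_requirements(requirements, pinned=PINNED):
    """Keep only the requirements whose package is not pinned."""
    return [r for r in requirements if requirement_name(r) not in pinned]


def declared_requirements(package):
    """Return the requirement strings an installed package declares,
    or an empty list if it is not installed or declares none."""
    try:
        return metadata.requires(package) or []
    except metadata.PackageNotFoundError:
        return []


# Example input taken from the optax 0.1.5 'Requires:' line in this thread.
example = ["absl-py", "chex", "jax", "jaxlib", "numpy"]
print(unpinned_requirements(example))
```

With optax already installed via --no-deps, unpinned_requirements(declared_requirements("optax")) would give the list to install manually.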