Alessandro-Castelli opened this issue 3 days ago
Hello @Alessandro-Castelli,
`pip install jax==0.4.23 jaxlib==0.4.23 optax==0.2.2` will not modify the jax or jaxlib versions.
If you used `pip install optax`, the jax and jaxlib versions are automatically bumped (because of the requirements set in optax 0.2.3). The real culprit is probably not optax but tensorflow, which has upper-bound version requirements that often conflict with other packages. In particular, tensorflow 2.9.0 does not seem to be available from pip (see the results of running `pip index versions tensorflow` on Python 3.9).
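To confirm that a pinned install like the one above really left jax and jaxlib alone, one option is to compare what is installed against the pins. A minimal sketch, assuming the pins from that command are the versions you want; `check_pins` and `base_version` are hypothetical helper names, and the local tag (such as `+cuda11.cudnn86`) is deliberately ignored when comparing.

```python
from importlib import metadata
from typing import Dict

# Pins taken from the pip command above (assumption: these are the targets).
PINS = {"jax": "0.4.23", "jaxlib": "0.4.23", "optax": "0.2.2"}


def base_version(version: str) -> str:
    """Drop a PEP 440 local tag, e.g. '0.4.23+cuda11.cudnn86' -> '0.4.23'."""
    return version.split("+", 1)[0]


def check_pins(pins: Dict[str, str]) -> Dict[str, str]:
    """Return a status per package: 'ok', 'missing', or 'installed X'."""
    report = {}
    for name, wanted in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "missing"
            continue
        if base_version(installed) == wanted:
            report[name] = "ok"
        else:
            report[name] = f"installed {installed}"
    return report


if __name__ == "__main__":
    for pkg, status in check_pins(PINS).items():
        print(f"{pkg}: {status}")
```

Running it before and after `pip install optax` makes an automatic bump (for example to 0.4.30) visible immediately.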
Thank you, @vroulet. Yes, maybe the problem lies in the TensorFlow version. How can I fix it?
Available versions: 2.18.0, 2.17.1, 2.17.0, 2.16.2, 2.16.1, 2.15.1, 2.15.0.post1, 2.15.0, 2.14.1, 2.14.0, 2.13.1, 2.13.0, 2.12.1, 2.12.0, 2.11.1, 2.11.0, 2.10.1, 2.10.0, 2.9.3, 2.9.2, 2.9.1, 2.9.0, 2.8.4, 2.8.3, 2.8.2, 2.8.1, 2.8.0, 2.7.4, 2.7.3, 2.7.2, 2.7.1, 2.7.0, 2.6.5, 2.6.4, 2.6.3, 2.6.2, 2.6.1, 2.6.0, 2.5.3, 2.5.2, 2.5.1, 2.5.0
I tried creating two separate conda environments: one where I use TensorFlow to download the dataset and another where I install PennyLane, JAX, JAXlib, and Optax to train the model, but the error still occurs.
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
cpu
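The bare `cpu` after the warning is presumably the printed result of a backend check. A quick way to see which backend JAX actually selected, and which devices it found, is sketched below using `jax.default_backend()` and `jax.devices()`; the `describe_backend` wrapper is a hypothetical name and simply returns a message if jax cannot be imported in the current environment.

```python
def describe_backend() -> str:
    """Report which backend JAX selected, e.g. 'cpu: [CpuDevice(id=0)]'."""
    try:
        import jax
    except ImportError:
        return "jax is not installed in this environment"
    # default_backend() returns 'cpu', 'gpu', or 'tpu'. With a CPU-only jaxlib
    # it is 'cpu' even if an NVIDIA GPU is physically present.
    return f"{jax.default_backend()}: {jax.devices()}"


if __name__ == "__main__":
    print(describe_backend())
```

Running this in each conda environment shows directly which one has lost GPU support.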
At this point, I think that the problem is optax.
Name: jax
Version: 0.4.23
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/acastelli/miniconda3/envs/af/lib/python3.9/site-packages
Requires: importlib-metadata, ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, optax

Name: jaxlib
Version: 0.4.23+cuda11.cudnn86
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/acastelli/miniconda3/envs/af/lib/python3.9/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, optax

Name: optax
Version: 0.1.5
Summary: A gradient processing and optimisation library in JAX.
Home-page: https://github.com/deepmind/optax
Author: DeepMind
Author-email: optax-dev@google.com
License: Apache 2.0
Location: /home/acastelli/miniconda3/envs/af/lib/python3.9/site-packages
Requires: absl-py, chex, jax, jaxlib, numpy
Required-by:
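The `+cuda11.cudnn86` suffix on the jaxlib version above is a local version tag marking a CUDA build. Re-checking it after each install shows whether a dependency upgrade swapped in a jaxlib wheel without that tag. A minimal sketch; `cuda_tag` and `installed_jaxlib_cuda_tag` are hypothetical helpers, and treating a `cuda` substring in the local tag as "CUDA build" is an assumption based on the version string shown here.

```python
from importlib import metadata
from typing import Optional


def cuda_tag(version: str) -> Optional[str]:
    """Return the local tag if it names a CUDA build, else None.

    '0.4.23+cuda11.cudnn86' -> 'cuda11.cudnn86'; '0.4.30' -> None.
    """
    _, sep, local = version.partition("+")
    if sep and "cuda" in local:
        return local
    return None


def installed_jaxlib_cuda_tag() -> Optional[str]:
    """Return the CUDA tag of the installed jaxlib, or None if absent or CPU-only."""
    try:
        return cuda_tag(metadata.version("jaxlib"))
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    print("jaxlib CUDA tag:", installed_jaxlib_cuda_tag())
```

If this prints a tag before `pip install optax` and `None` afterwards, the install replaced the CUDA-enabled jaxlib.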
> At this point, I think that the problem is optax.
I really don't think so. Just look at the code in optax: it's quite a lightweight library, unrelated to any CUDA/GPU functionality. Some dependencies could have been bumped, but the optax, jax, and jaxlib versions above look fine.
I cannot reproduce the error you're mentioning, as tensorflow 2.9 does not seem to be available to me locally, and in any case I don't have a GPU. The error clearly points to jaxlib, not optax.
@vroulet, I'll explain why I think it's optax. In my initial code, I was using jax and jaxlib 0.4.23, pennylane, and tensorflow 2.9.0, and I had no issues installing those versions. At some point, I needed more powerful optimizers, such as Adam, for some tasks, and that's when I started using optax. Only from that moment on did I encounter the issue:
2024-11-28 17:39:56.095832: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
cpu
Maybe it’s optax that doesn’t get along with tensorflow.
Another thing: when I install optax, it updates jax and jaxlib to version 0.4.30, so maybe that's the problem. I really don't know; I've tried many combinations of versions, but I just can't get it to work. Do you have any suggestions?
I feel your pain; I've been there. Versioning between CUDA, JAX, and TensorFlow is a mess.
I would suggest using separate virtualenvs for JAX-based and TF-based projects, if you can.
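One way to set up the separate environments suggested above is Python's standard-library `venv` module (a built-in alternative to the third-party virtualenv tool). A minimal sketch: the directory names `jax-env` and `tf-env` and the `make_envs` helper are hypothetical, and each environment still needs its own `pip install` afterwards (for example, the pinned jax/jaxlib/optax command in one and tensorflow in the other).

```python
import tempfile
import venv
from pathlib import Path
from typing import Iterable, List


def make_envs(root: Path, names: Iterable[str] = ("jax-env", "tf-env"),
              with_pip: bool = True) -> List[Path]:
    """Create one isolated virtual environment per name under `root`."""
    created = []
    for name in names:
        env_dir = root / name
        # clear=True deletes an existing environment of the same name first.
        venv.create(env_dir, clear=True, with_pip=with_pip)
        created.append(env_dir)
    return created


if __name__ == "__main__":
    # Demo in a throwaway directory; with_pip=False keeps it fast. For real use,
    # pass a persistent directory and keep the default with_pip=True.
    with tempfile.TemporaryDirectory() as tmp:
        for env in make_envs(Path(tmp), with_pip=False):
            print(env.name, "created:", (env / "pyvenv.cfg").exists())
```

Keeping the two stacks in different environments means a jax bump pulled in by optax can never touch the tensorflow install, and vice versa.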
Hello @fabianp, I tried that, but I think the real problem is the versioning between JAX and Optax. I tried many different Optax versions, but none of them solved my problem.
Have you tried installing optax with `--no-deps`, so it doesn't try to modify the other packages?
Yes, but Optax has additional dependencies of its own, so this approach doesn't seem to work for it.
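With `--no-deps`, optax's own requirements (the `pip show` output above lists absl-py, chex, jax, jaxlib, and numpy for optax 0.1.5) have to be installed by hand. One way to see which of them are still missing is to read the requirements optax declares in its installed metadata. A minimal sketch; `requirement_name` and `missing_requirements` are hypothetical helpers, and the parsing is simplified: it extracts only the project name, skips optional extras, and does not evaluate other environment markers.

```python
import re
from importlib import metadata
from typing import List

# Project name at the start of a requirement string, e.g. 'chex>=0.1.86' -> 'chex'.
_NAME = re.compile(r"^\s*([A-Za-z0-9][A-Za-z0-9._-]*)")


def requirement_name(req: str) -> str:
    """Extract the project name from a requirement string."""
    match = _NAME.match(req)
    if match is None:
        raise ValueError(f"cannot parse requirement: {req!r}")
    return match.group(1)


def missing_requirements(dist: str) -> List[str]:
    """List declared (non-extra) requirements of `dist` that are not installed."""
    missing = []
    for req in metadata.requires(dist) or []:
        # Skip optional extras such as 'pytest; extra == "test"'.
        # Other markers (e.g. python_version) are NOT evaluated here.
        if ";" in req and "extra" in req.split(";", 1)[1]:
            continue
        name = requirement_name(req)
        try:
            metadata.version(name)
        except metadata.PackageNotFoundError:
            missing.append(name)
    return missing


if __name__ == "__main__":
    try:
        print("missing for optax:", missing_requirements("optax"))
    except metadata.PackageNotFoundError:
        print("optax is not installed")
```

Anything it reports can then be installed explicitly (and pinned) without letting pip re-resolve jax and jaxlib.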
"I have jax 0.4.23. When I install optax with the command `pip install optax`, I get the error message 'An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu. [CpuDevice(id=0)]'. This error only occurs after I install optax. Which version of optax is compatible with my version of jax and tensorflow 2.9.0?"
Name: jax
Version: 0.4.23
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/acastelli/miniconda3/envs/af/lib/python3.9/site-packages
Requires: importlib-metadata, ml-dtypes, numpy, opt-einsum, scipy
Required-by:

Name: jaxlib
Version: 0.4.23+cuda11.cudnn86
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/acastelli/miniconda3/envs/af/lib/python3.9/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: