Open PhilipVinc opened 4 months ago
@hawkinsp may I bump this? Maybe nobody noticed the issue because it was summer break, but it's quite annoying on HPC hardware that JAX/XLA cannot find the libdevice that is already installed unless extra flags are set.
Sure, I'll take a look.
I guess you're trying to avoid the pip package install, but I suspect if you install the pip package `nvidia-cuda-nvcc-cu12` that will work around the problem. I think that package by itself will work. I haven't tried it though.
Yes. I can't install the PyPI packages because we are using CUDA-aware MPI on an HPC cluster, which already links against a version of CUDA provided by the platform…
The fix is trivial: we just do `export "XLA_FLAGS=--xla_gpu_cuda_data_dir=$CUDA_ROOT"`.
But it seems weird that JAX does not realize on its own that it has to look there.
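For anyone who prefers to apply this workaround from Python rather than from the shell, here is a minimal sketch. It assumes `CUDA_ROOT` is set by the cluster module and that libdevice sits under `nvvm/libdevice/libdevice.10.bc` inside it, as in the path reported in this issue. The helper names (`has_libdevice`, `add_cuda_data_dir`, `configure_xla_from_env`) are made up for illustration and are not part of JAX or XLA.

```python
import os
from pathlib import Path

# Location of libdevice relative to the CUDA root, matching the path in this issue.
LIBDEVICE_RELPATH = Path("nvvm") / "libdevice" / "libdevice.10.bc"
FLAG = "--xla_gpu_cuda_data_dir"


def has_libdevice(cuda_root):
    """Return True if libdevice.10.bc exists under the given CUDA root."""
    return (Path(cuda_root) / LIBDEVICE_RELPATH).is_file()


def add_cuda_data_dir(xla_flags, cuda_root):
    """Append --xla_gpu_cuda_data_dir to an XLA_FLAGS string unless it is already set."""
    parts = xla_flags.split()
    if any(p.startswith(FLAG + "=") for p in parts):
        return xla_flags  # respect a value the user already chose
    parts.append(f"{FLAG}={cuda_root}")
    return " ".join(parts)


def configure_xla_from_env(environ=os.environ):
    """Set XLA_FLAGS from CUDA_ROOT if libdevice is found there; return True on success."""
    cuda_root = environ.get("CUDA_ROOT")
    if cuda_root and has_libdevice(cuda_root):
        environ["XLA_FLAGS"] = add_cuda_data_dir(environ.get("XLA_FLAGS", ""), cuda_root)
        return True
    return False


if __name__ == "__main__":
    configure_xla_from_env()
    # `import jax` would go here, after XLA_FLAGS has been set.
```

Setting the variable before `import jax` is the safe choice, since the flags need to be in place before the GPU backend is initialized.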
`nvidia-cuda-nvcc-cu12` can likely be installed in isolation, since iirc it contains pretty much nothing but `ptxas` (a statically linked binary) and `libdevice.10.bc`.
I ran into a similar issue when installing `jax[cuda12]` from PyPI in a Bazel-based project. The problem in my case is that the `_cuda_path` function assumes a very specific structure of the Python site-packages, which, for example, Bazel's `rules_python` pip installation doesn't comply with. I managed to solve the issue by pulling the `cuda_nvcc` path directly from the `nvidia.cuda_nvcc` module and not relative to `jaxlib`, as is currently done.
I'll open a PR for my solution shortly.
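To illustrate the idea (this is not the code from the PR), here is a rough sketch of resolving a package's directory through the import system instead of assuming a site-packages layout. `package_dir` and `find_cuda_nvcc_dir` are hypothetical helper names. The only assumption taken from the comment above is that the pip package is importable as `nvidia.cuda_nvcc`.

```python
import importlib.util
from pathlib import Path


def package_dir(module_name):
    """Return the directory of an importable package, or None if it cannot be found."""
    try:
        spec = importlib.util.find_spec(module_name)
    except (ImportError, ValueError):
        # Raised when a parent package is missing or a module has no usable spec.
        return None
    if spec is None:
        return None
    # Works for both regular and namespace packages.
    if spec.submodule_search_locations:
        return Path(next(iter(spec.submodule_search_locations)))
    # Plain module: fall back to the directory of its file, if it has one.
    if spec.origin and spec.origin not in ("built-in", "frozen"):
        return Path(spec.origin).parent
    return None


def find_cuda_nvcc_dir():
    """Locate the pip-installed CUDA NVCC package wherever the installer put it."""
    return package_dir("nvidia.cuda_nvcc")


if __name__ == "__main__":
    print(find_cuda_nvcc_dir())
```

Because the lookup goes through the import machinery, it does not depend on `nvidia` and `jaxlib` being sibling directories, which is the assumption that breaks under `rules_python`.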
Description
This is probably more of an OpenXLA issue, feel free to close it if so.
The latest version of XLA/jaxlib ships without `libdevice.10.bc` and looks for it elsewhere.
I am using `jax[cuda12_local]==0.4.30` on an HPC cluster. While CUDA is found with no effort on my part (other than `module load cuda/12.2.0`), `libdevice.10.bc` is not detected, and I have to manually specify the path to it through `XLA_FLAGS`.
Note that the module I am using already declares the CUDA environment paths, so `CUDA_ROOT` is correctly defined.
Is it possible that JAX is failing to look into some paths? libdevice is located at `/gpfslocalsys/cuda/12.2.0/nvvm/libdevice/libdevice.10.bc`, and without the XLA flag JAX does not find it.
System info (python version, jaxlib version, accelerator, etc.)