jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[HPC] Missing libdevice.so.10 #22590

Open PhilipVinc opened 4 months ago

PhilipVinc commented 4 months ago

Description

This is probably more of an OpenXLA issue, feel free to close it if so.

The latest version of XLA/jaxlib ships without libdevice.so.10 and looks for it elsewhere.

I am using jax[cuda12_local]==0.4.30 on an HPC cluster. While CUDA is found with no effort on my part (other than module load cuda/12.2.0), the libdevice.so.10 is not detected, and I have to manually specify the path to it by declaring

XLA_FLAGS=--xla_gpu_cuda_data_dir=/gpfslocalsys/cuda/12.2.0/

Note that the module I am using already declares the following environment paths:

prepend-path    PATH /gpfslocalsys/cuda/12.2.0/bin
prepend-path    PATH /gpfslocalsys/cuda/12.2.0/nvvm/bin
prepend-path    PATH /gpfslocalsys/cuda/12.2.0/samples
prepend-path    LD_LIBRARY_PATH /gpfslocalsys/cuda/12.2.0/targets/x86_64-linux/lib
prepend-path    LD_LIBRARY_PATH /gpfslocalsys/cuda/12.2.0/samples/common/lib/linux/x86_64
prepend-path    LD_LIBRARY_PATH /gpfslocalsys/cuda/12.2.0/lib64
prepend-path    LD_LIBRARY_PATH /gpfslocalsys/cuda/12.2.0/extras/CUPTI/lib64
prepend-path    LD_LIBRARY_PATH /gpfslocalsys/cuda/12.2.0/nvvm/lib64
setenv          CUDA_PATH /gpfslocalsys/cuda/12.2.0
setenv          CUDA_HOME /gpfslocalsys/cuda/12.2.0
setenv          CUDA_ROOT /gpfslocalsys/cuda/12.2.0
setenv          NVHPC_CUDA_HOME /gpfslocalsys/cuda/12.2.0
setenv          CUDA_INSTALL_PATH /gpfslocalsys/cuda/12.2.0
setenv          LIBRARY_PATH /gpfslocalsys/cuda/12.2.0/nvvm/lib64:/gpfslocalsys/cuda/12.2.0/extras/CUPTI/lib64:/gpfslocalsys/cuda/12.2.0/lib64:/gpfslocalsys/cuda/12.2.0/samples/common/lib/linux/x86_64:/gpfslocalsys/cuda/12.2.0/targets/x86_64-linux/lib:/gpfslocalsys/slurm/current/lib/slurm:/gpfslocalsys/slurm/current/lib
prepend-path    LIBRARY_PATH /gpfslocalsys/cuda/12.2.0/lib64/stubs
setenv          C_INCLUDE_PATH /gpfslocalsys/cuda/12.2.0/include
setenv          CPLUS_INCLUDE_PATH /gpfslocalsys/cuda/12.2.0/include

So CUDA_ROOT is correctly defined.

Is it possible that JAX is failing to look in some paths? The libdevice file is located at /gpfslocalsys/cuda/12.2.0/nvvm/libdevice/libdevice.10.bc, and without the XLA flag JAX does not find it.
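To check which of the variables set by the module actually point at a libdevice file, a small diagnostic along these lines may help. It is only a sketch: the function name `find_libdevice` and the list of variables it checks are assumptions for illustration, and it does not reproduce XLA's own search logic. It relies only on the `nvvm/libdevice/libdevice.10.bc` layout reported above.

```python
import os
from pathlib import Path

# Variables set by the module file in this report. Checking them is an
# assumption of this sketch, not something XLA is documented to do.
CANDIDATE_VARS = ("CUDA_ROOT", "CUDA_HOME", "CUDA_PATH")


def find_libdevice(env=None):
    """Return (variable name, path) for the first libdevice.10.bc found, else None."""
    env = os.environ if env is None else env
    for var in CANDIDATE_VARS:
        root = env.get(var)
        if not root:
            continue
        # Layout taken from the path reported in this issue.
        candidate = Path(root) / "nvvm" / "libdevice" / "libdevice.10.bc"
        if candidate.is_file():
            return var, candidate
    return None
```

Calling `find_libdevice()` after `module load cuda/12.2.0` returns the first variable whose root contains the file, or `None` if none of them does.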

System info (python version, jaxlib version, accelerator, etc.)

(filippo-test-0429) [udb21rp@jean-zay2: test2]$ pip freeze | grep jax
jax==0.4.30
jax-cuda12-pjrt==0.4.30
jax-cuda12-plugin==0.4.30
jaxlib==0.4.30
jaxtyping==0.2.28
mpi4jax==0.5.1
numba4jax==0.0.12

PhilipVinc commented 2 months ago

@hawkinsp may I bump this? Maybe nobody noticed the issue because of the summer break, but it's quite annoying on HPC hardware that JAX/XLA cannot find the already-installed libdevice without extra flags.

hawkinsp commented 2 months ago

Sure, I'll take a look.

I guess you're trying to avoid the pip package install, but I suspect if you install the pip package nvidia-cuda-nvcc-cu12 that will work around the problem. I think that package by itself will work. I haven't tried it though.

PhilipVinc commented 2 months ago

Yes. I can’t install the PyPI packages because we are using CUDA-aware MPI on an HPC cluster, which already links to a version of CUDA provided by the platform…

The fix is trivial: we just do export "XLA_FLAGS=--xla_gpu_cuda_data_dir=$CUDA_ROOT"

But it seems weird that JAX does not realize on its own that it has to look there.
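For scripts that cannot rely on the shell export, the same workaround can be applied from Python. This is a minimal sketch: the helper name `with_cuda_data_dir` is hypothetical, and it assumes the flags are set before JAX initializes its backend, which is safest to guarantee by setting them before `import jax`. It keeps any other flags already present in `XLA_FLAGS`.

```python
import os


def with_cuda_data_dir(xla_flags, cuda_root):
    """Return xla_flags with --xla_gpu_cuda_data_dir set to cuda_root."""
    prefix = "--xla_gpu_cuda_data_dir="
    # Drop any earlier value of the flag, keep every other flag untouched.
    kept = [flag for flag in xla_flags.split() if not flag.startswith(prefix)]
    kept.append(prefix + cuda_root)
    return " ".join(kept)


cuda_root = os.environ.get("CUDA_ROOT")
if cuda_root:
    os.environ["XLA_FLAGS"] = with_cuda_data_dir(
        os.environ.get("XLA_FLAGS", ""), cuda_root
    )

# Import jax only after this point so the flag is visible when XLA starts.
```

If `CUDA_ROOT` is unset, the environment is left unchanged.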

hawkinsp commented 2 months ago

nvidia-cuda-nvcc-cu12 can likely be installed in isolation, since iirc it contains pretty much nothing but ptxas (a statically linked binary) and libdevice.10.bc.

hartikainen commented 1 month ago

I ran into a similar issue when installing jax[cuda12] from PyPI in a Bazel-based project. The problem in my case is that the _cuda_path function assumes a very specific structure of the Python site-packages directory, which, for example, Bazel's rules_python pip installation doesn't comply with. I managed to solve the issue by pulling the cuda_nvcc path directly from the nvidia.cuda_nvcc module rather than computing it relative to jaxlib, as is currently done.

I'll open a PR for my solution shortly.
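The idea of resolving the path from the module itself, rather than relative to jaxlib, can be sketched with the standard library's importlib. This is not the code from the PR mentioned above: `find_cuda_nvcc_dir` is a hypothetical helper, and where libdevice sits inside the returned directory is left to the caller because the thread does not state that layout.

```python
import importlib.util
from pathlib import Path


def find_cuda_nvcc_dir(module_name="nvidia.cuda_nvcc"):
    """Return the installed package directory for module_name, or None."""
    try:
        spec = importlib.util.find_spec(module_name)
    except (ImportError, ValueError):
        # The parent package is missing or has no usable spec.
        return None
    if spec is None or not spec.submodule_search_locations:
        return None
    # Works for regular and namespace packages, wherever pip or Bazel put them.
    for location in spec.submodule_search_locations:
        path = Path(location)
        if path.is_dir():
            return path
    return None
```

Because the lookup goes through the import system, it does not depend on jaxlib and the nvidia packages sharing one site-packages directory.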