adamjstewart opened this issue 1 month ago.
Hi @adamjstewart, the GCC compiler is not officially supported by JAX. I recommend using Clang. You can pass the Clang path via the --clang_path option.
If you absolutely need to use GCC, we have experimental support that can be enabled like this:
--bazel_options=--action_env=CUDA_NVCC="1" --bazel_options=--@local_config_cuda//:cuda_compiler=nvcc
I tried adding these flags but I still see the exact same error:
gcc: error: unrecognized command-line option '--cuda-path=external/cuda_nvcc'
Would you please paste the full stack trace here? I'd like to make sure that the CUDA_NVCC value is recognized by Bazel.
Here you go:
Hmm, one more suggestion: try this:
--bazel_options=--action_env=TF_NVCC_CLANG="1" --bazel_options=--@local_config_cuda//:cuda_compiler=nvcc
The reason your build fails is that GCC is unable to compile the CUDA dependencies; they have to be compiled with the NVCC compiler.
Still the same issue:
gcc: error: unrecognized command-line option ‘--cuda-path=external/cuda_nvcc’
This is what I've tried:
python3.10 build/build.py --enable_cuda --use_clang=false --bazel_options=--repo_env=CC="/dt9/usr/bin/gcc" --bazel_options=--repo_env=TF_SYSROOT="/dt9" --bazel_options=--action_env=CUDA_NVCC="1" --bazel_options=--@local_config_cuda//:cuda_compiler=nvcc
The subcommand I got:
SUBCOMMAND: # //jaxlib:cpu_feature_guard.so [action 'Compiling jaxlib/cpu_feature_guard.c', configuration: 988f5a730e2bd9c88c71efcc9c7f0d36ad2ec3c5f71c922aabaf7614ff994b0f, execution platform: @local_execution_config_platform//:platform]
(cd /home/ybaturina/.cache/bazel/_bazel_ybaturina/ead9107e8e47a1c42911a02736d63d03/execroot/__main__ && \
exec env - \
CUDA_NVCC=1 \
PATH=/home/kbuilder/.local/bin:/usr/local/bin/python3.10:/home/ybaturina/.local/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin \
PWD=/proc/self/cwd \
external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc -MD -MF bazel-out/k8-opt/bin/jaxlib/_objs/cpu_feature_guard.so/cpu_feature_guard.pic.d '-frandom-seed=bazel-out/k8-opt/bin/jaxlib/_objs/cpu_feature_guard.so/cpu_feature_guard.pic.o' '-DBAZEL_CURRENT_REPOSITORY=""' -iquote . -iquote bazel-out/k8-opt/bin -iquote external/python_x86_64-unknown-linux-gnu -iquote bazel-out/k8-opt/bin/external/python_x86_64-unknown-linux-gnu -isystem external/python_x86_64-unknown-linux-gnu/include -isystem bazel-out/k8-opt/bin/external/python_x86_64-unknown-linux-gnu/include -isystem external/python_x86_64-unknown-linux-gnu/include/python3.10 -isystem bazel-out/k8-opt/bin/external/python_x86_64-unknown-linux-gnu/include/python3.10 -isystem external/python_x86_64-unknown-linux-gnu/include/python3.10m -isystem bazel-out/k8-opt/bin/external/python_x86_64-unknown-linux-gnu/include/python3.10m -Wno-builtin-macro-redefined '-D__DATE__="redacted"' '-D__TIMESTAMP__="redacted"' '-D__TIME__="redacted"' -fPIC -U_FORTIFY_SOURCE '-D_FORTIFY_SOURCE=1' -fstack-protector -Wall -fno-omit-frame-pointer -no-canonical-prefixes -DNDEBUG -g0 -O2 -ffunction-sections -fdata-sections '-fvisibility=hidden' -Wno-sign-compare -Wno-unknown-warning-option -Wno-stringop-truncation -Wno-array-parameter '-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.' -mavx -fno-strict-aliasing -fexceptions '-fvisibility=hidden' '--sysroot=/dt9' -c jaxlib/cpu_feature_guard.c -o bazel-out/k8-opt/bin/jaxlib/_objs/cpu_feature_guard.so/cpu_feature_guard.pic.o)
In my case, the --cuda-path option was not passed to the NVCC compiler.
I suspect that something in the environment variables on your machine is interfering with the subcommand configuration. Since JAX doesn't officially support compiling with GCC, I strongly recommend using Clang.
I hit the same --cuda-path issue with GCC as well.
Alternatively, I tried to build it with Clang and a local CUDA, CUDNN, and NCCL, but ran into other issues:
In file included from external/xla/xla/tsl/cuda/cudnn_stub.cc:16:
In file included from external/com_google_absl/absl/container/flat_hash_map.h:38:
In file included from external/com_google_absl/absl/algorithm/container.h:43:
In file included from /usr/bin/../lib64/gcc/x86_64-pc-linux-gnu/14.2.1/../../../../include/c++/14.2.1/algorithm:61:
In file included from /usr/bin/../lib64/gcc/x86_64-pc-linux-gnu/14.2.1/../../../../include/c++/14.2.1/bits/stl_algo.h:71:
/usr/bin/../lib64/gcc/x86_64-pc-linux-gnu/14.2.1/../../../../include/c++/14.2.1/cstdlib:79:15: fatal error: 'stdlib.h' file not found
79 | #include_next <stdlib.h>
| ^~~~~~~~~~
1 error generated.
Specifically, I run bazel directly as follows:
build/bazel-6.5.0-linux-x86_64 run --verbose_failures=true \
--repo_env=LOCAL_CUDA_PATH=/opt/cuda \
--repo_env=LOCAL_CUDNN_PATH=/usr \
--repo_env=LOCAL_NCCL_PATH=/usr \
//jaxlib/tools:build_wheel -- \
--output_path=$PWD/dist --cpu=x86_64 \
--jaxlib_git_hash=78ade74d695407306461718a6d73cfed89b4d972
I also add the following .bazelrc.user to the repository root:
# .bazelrc.user
build --strategy=Genrule=standalone
build --action_env CLANG_COMPILER_PATH="/usr/bin/clang-18"
build --repo_env CC="/usr/bin/clang-18"
build --repo_env BAZEL_COMPILER="/usr/bin/clang-18"
build --copt=-Wno-error=unused-command-line-argument
build --copt=-Wno-gnu-offsetof-extensions
build --config=avx_posix
build --config=mkl_open_source_only
build --config=cuda
build --config=nvcc_clang
build --action_env=CLANG_CUDA_COMPILER_PATH=/usr/bin/clang-18
build --repo_env HERMETIC_PYTHON_VERSION="3.12"
Dependency versions follow.
$ pacman -Qs '(cuda|cudnn|clang)'
local/clang 18.1.8-4
C language family frontend for LLVM
local/compiler-rt 18.1.8-1
Compiler runtime libraries for clang
local/cuda 12.6.2-2
NVIDIA's GPU programming toolkit
local/cudnn 9.2.1.18-1
NVIDIA CUDA Deep Neural Network library
This looks like a problem with the GCC installation. If you run clang -v, you'll see something like this:
Selected GCC installation: /usr/bin/../lib/gcc/x86_64-linux-gnu/14
Looking at the error above, I suggest running this command:
find /usr/bin/../include -name "stdlib.h"
If the file is not found for GCC 14, that means you need to install the missing headers, e.g. by running sudo apt install g++-14.
I can reproduce the issue for jaxlib 0.4.32, 0.4.33, and 0.4.34 with clang-14 and clang-18 (which depend on gcc and gcc-libs 14.2.1+r134+gab884fffe3fc-1). Also, the cuda package depends on gcc-13 (this is Arch).
$ /usr/lib/llvm14/bin/clang-14 -v
clang version 14.0.6
Target: x86_64-pc-linux-gnu
Thread model: posix
InstalledDir: /usr/lib/llvm14/bin
Found candidate GCC installation: /usr/lib/gcc/x86_64-pc-linux-gnu/13.3.0
Found candidate GCC installation: /usr/lib/gcc/x86_64-pc-linux-gnu/14.2.1
Found candidate GCC installation: /usr/lib64/gcc/x86_64-pc-linux-gnu/13.3.0
Found candidate GCC installation: /usr/lib64/gcc/x86_64-pc-linux-gnu/14.2.1
Selected GCC installation: /usr/lib64/gcc/x86_64-pc-linux-gnu/14.2.1
Candidate multilib: .;@m64
Candidate multilib: 32;@m32
Selected multilib: .;@m64
I also appended the -v option to the failing crosstool_wrapper_driver_is_not_gcc command. It prints the system search list, which includes the directory containing stdlib.h:
$ (cd ... && .../crosstool_wrapper_driver_is_not_gcc ... -v)
Selected GCC installation: /usr/lib64/gcc/x86_64-pc-linux-gnu/14.2.1
...
/usr/lib64/gcc/x86_64-pc-linux-gnu/14.2.1/../../../../include/c++/14.2.1
/usr/lib64/gcc/x86_64-pc-linux-gnu/14.2.1/../../../../include/c++/14.2.1/x86_64-pc-linux-gnu
/usr/lib64/gcc/x86_64-pc-linux-gnu/14.2.1/../../../../include/c++/14.2.1/backward
/usr/lib/llvm14/lib/clang/14.0.6/include
/usr/local/include
/usr/include
End of search list.
external/xla/xla/tsl/cuda/cupti_stub.cc:16:10: fatal error: 'third_party/gpus/cuda/extras/CUPTI/include/cupti.h' file not found
#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h"
^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1 error generated.
$ ls -l /usr/lib64/gcc/x86_64-pc-linux-gnu/14.2.1/../../../../include/c++/14.2.1/stdlib.h
-rw-r--r-- 1 root root 2.3K Sep 10 13:07 /usr/lib64/gcc/x86_64-pc-linux-gnu/14.2.1/../../../../include/c++/14.2.1/stdlib.h
What is odd is that now third_party/gpus/cuda/extras/CUPTI/include/cupti.h is not found. 🤯
There is indeed no third_party/gpus directory. The command below did not find third_party/gpus/cuda/extras/CUPTI either. 🤯
find -L bazel-jax-jax-v0.4.34 -name 'CUPTI'
UPD: Is this an upstream (XLA) issue?
Would you please check whether your local CUDA installation has the CUPTI headers? Specifically, the following headers should be present: https://github.com/openxla/xla/blob/main/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl#L21-L58
Also, please check that the structure of the local CUDA/CUDNN/NCCL directories is exactly as described here.
Sure. I checked, and CUPTI is where it should be (i.e. /opt/cuda/extras/CUPTI; the CUDA root is /opt/cuda). I also checked the include directories for CUPTI, and they look correct:
-Ibazel-out/k8-opt/bin/external/cuda_cupti/_virtual_includes/headers
-iquote external/cuda_cupti
-isystem bazel-out/k8-opt/bin/external/cuda_cupti/include
However, all of these directories are empty. I compared them with the corresponding directories for other CUDA libraries (e.g. cufft), and those are not empty. Then I manually symlinked the include directory into several places, like this:
mkdir -p bazel-out/k8-opt/bin/external/cuda_cupti/_virtual_includes/headers/third_party/gpus/cuda/extras/CUPTI
ln -s /opt/cuda/extras/CUPTI/include \
bazel-out/k8-opt/bin/external/cuda_cupti/_virtual_includes/headers/third_party/gpus/cuda/extras/CUPTI
ln -s /opt/cuda/extras/CUPTI/include \
bazel-out/k8-opt/bin/external/cuda_cupti/include
and ran build ... @xla//xla/tsl/cuda:cupti_stub. Compilation still fails:
external/xla/xla/tsl/cuda/BUILD.bazel:240:11: Compiling xla/tsl/cuda/cupti_stub.cc failed: undeclared inclusion(s) in rule '@xla//xla/tsl/cuda:cupti_stub':
this rule is missing dependency declarations for the following files included by 'xla/tsl/cuda/cupti_stub.cc':
'bazel-out/k8-opt/bin/external/cuda_cupti/_virtual_includes/headers/third_party/gpus/cuda/extras/CUPTI/include/cupti.h'
'bazel-out/k8-opt/bin/external/cuda_cupti/include/cupti_result.h'
'bazel-out/k8-opt/bin/external/cuda_cupti/include/cupti_version.h'
...
It seems that Bazel neither copies nor recreates the header library for cupti, although it does so for cufft and the others.
Is this trailing slash important? The other BUILD.tpl files don't have it. https://github.com/openxla/xla/blob/3740d0854106f32a89687484b05fd8947c89ef91/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl#L60
UPD: Manually editing cuda_cupti.BUILD.tpl does not help. 😔
The issue is that /opt/cuda/extras/CUPTI is not an acceptable location (see here).
This is what the CUDA folder should look like:
<LOCAL_CUDA_PATH>/
include/
bin/
lib/
nvvm/
So all headers should be located in <LOCAL_CUDA_PATH>/include, and all libraries should be in <LOCAL_CUDA_PATH>/lib.
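A quick way to sanity-check a local CUDA tree against this layout is a small script like the one below. It is a minimal sketch, not part of the JAX or XLA tooling: the function name check_cuda_layout is hypothetical, and using cupti.h as the probe for the CUPTI headers is an assumption.

```python
import os
import sys

# Top-level directories expected under <LOCAL_CUDA_PATH> (from the layout above).
EXPECTED_SUBDIRS = ("include", "bin", "lib", "nvvm")


def check_cuda_layout(cuda_path):
    """Hypothetical helper: return a list of layout problems (empty if none)."""
    problems = []
    for sub in EXPECTED_SUBDIRS:
        if not os.path.isdir(os.path.join(cuda_path, sub)):
            problems.append(f"missing directory: {sub}/")
    # Assumption: cupti.h is used as the probe; it must be in <LOCAL_CUDA_PATH>/include.
    if not os.path.isfile(os.path.join(cuda_path, "include", "cupti.h")):
        legacy = os.path.join(cuda_path, "extras", "CUPTI", "include", "cupti.h")
        if os.path.isfile(legacy):
            # Present, but only in the extras/CUPTI location that is not accepted.
            problems.append("cupti.h found only under extras/CUPTI/include")
        else:
            problems.append("cupti.h not found")
    return problems


if __name__ == "__main__":
    path = sys.argv[1] if len(sys.argv) > 1 else "/opt/cuda"
    issues = check_cuda_layout(path)
    for issue in issues:
        print(issue)
    if not issues:
        print("layout looks OK")
```

For example, save it under any name and run it as python3 check_cuda_layout.py /opt/cuda. An empty problem list only means the probed paths exist; it does not verify every header listed in the BUILD templates.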
Also, please note that using a local CUDA installation is not the recommended approach for building from source.
I have already tried that. I copied everything from extras/CUPTI to . (i.e. the CUDA root), but it doesn't help. Moreover, include_prefix = "third_party/gpus/cuda/extras/CUPTI/include" in cuda_cupti.BUILD.tpl differs from the prefixes in the other cuda_*.BUILD.tpl files.
include_prefix corresponds to the include path prefix used in the source files, e.g. this one.
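To illustrate how such a prefix maps a header file onto the path used in an #include directive, here is a simplified model of Bazel's cc_library strip_include_prefix and include_prefix attributes. It is a sketch, not Bazel's implementation: it ignores package-relative versus absolute prefixes, and strip_include_prefix = "include" for the CUPTI headers is an assumption; only the include_prefix value comes from cuda_cupti.BUILD.tpl as quoted above.

```python
import posixpath


def virtual_include_path(header, strip_include_prefix="", include_prefix=""):
    """Simplified model: the path under which `header` can be #include'd."""
    path = posixpath.normpath(header)
    if strip_include_prefix:
        prefix = posixpath.normpath(strip_include_prefix)
        if not path.startswith(prefix + "/"):
            raise ValueError(f"{header!r} is not under {strip_include_prefix!r}")
        # Drop the stripped prefix and the slash that follows it.
        path = path[len(prefix) + 1:]
    if include_prefix:
        # The new prefix is prepended only after stripping.
        path = posixpath.join(posixpath.normpath(include_prefix), path)
    return path


if __name__ == "__main__":
    print(virtual_include_path(
        "include/cupti.h",
        strip_include_prefix="include",  # assumption, see lead-in
        include_prefix="third_party/gpus/cuda/extras/CUPTI/include",
    ))
```

With these inputs the function returns third_party/gpus/cuda/extras/CUPTI/include/cupti.h, which is exactly the path cupti_stub.cc includes; Bazel materializes such paths under _virtual_includes/headers, as seen in the error messages earlier in the thread.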
As far as I understand, you use the command below:
build/bazel-6.5.0-linux-x86_64 run --verbose_failures=true \
--repo_env=LOCAL_CUDA_PATH=/opt/cuda \
--repo_env=LOCAL_CUDNN_PATH=/usr \
--repo_env=LOCAL_NCCL_PATH=/usr \
//jaxlib/tools:build_wheel -- \
--output_path=$PWD/dist --cpu=x86_64 \
--jaxlib_git_hash=78ade74d695407306461718a6d73cfed89b4d972
Would you confirm that all CUDA headers are located in /opt/cuda/include, and all NCCL/CUDNN headers are in /usr/include? If so, please clean the Bazel cache via bazel clean --expunge and run the command again. If it fails, I would appreciate it if you post the full log here.
> Would you confirm that all CUDA headers are located in /opt/cuda/include, and all NCCL/CUDNN headers are in /usr/include?
Absolutely.
> If so, please clean Bazel cache via bazel clean --expunge and run the command again. If it fails, I would appreciate it if you post the full log here.
Link.
Since the @xla//xla/tsl/cuda:cudnn_stub target fails first this time, due to the missing <stdlib.h>, I built @xla//xla/tsl/cuda:cupti_stub separately; it fails too, because of the missing "third_party/gpus/cuda/extras/CUPTI/include/cupti.h" (logs):
build/bazel-6.5.0-linux-x86_64 build --verbose_failures=true \
--repo_env=LOCAL_CUDA_PATH=/opt/cuda \
--repo_env=LOCAL_CUDNN_PATH=/usr \
--repo_env=LOCAL_NCCL_PATH=/usr \
@xla//xla/tsl/cuda:cupti_stub
Can you please check this folder?
/home/bershatsky/.cache/bazel/_bazel_bershatsky/3be6d6eea05ac1cf650a152f41829d38/external/cuda_cupti
Does it have an include symlink pointing to /opt/cuda/include?
Please don't build @xla//xla/tsl/cuda:cupti_stub; try @cuda_cupti//:headers instead. This is the dependency used in Bazel tests for CUDA.
> Does it have symlink include pointing to /opt/cuda/include?
Yes, it has include and the others. I checked the BUILD file in this directory: all headers are commented out (link).
> Please don't build @xla//xla/tsl/cuda:cupti_stub, try @cuda_cupti//:headers instead - this is the dependency used in Bazel tests for CUDA.
The @cuda_cupti//:headers target builds successfully. But isn't @xla//xla/tsl/cuda:cupti_stub a dependency of //jaxlib/tools:build_wheel?
$ build/bazel-6.5.0-linux-x86_64 query \
--repo_env=LOCAL_CUDA_PATH=/opt/cuda \
--repo_env=LOCAL_CUDNN_PATH=/usr \
--repo_env=LOCAL_NCCL_PATH=/usr \
"deps(kind(rule, deps(//jaxlib/tools:build_wheel)))" | grep cupti
...
@xla//xla/tsl/cuda:cupti_stub
...
I also noticed one important thing: you execute bazel run and bazel build without passing --config=cuda.
I put all auxiliary options in .bazelrc.user; I believe that this is equivalent:
build --strategy=Genrule=standalone
build --action_env CLANG_COMPILER_PATH="/usr/lib/llvm14/bin/clang-14"
build --repo_env CC="/usr/lib/llvm14/bin/clang-14"
build --repo_env BAZEL_COMPILER="/usr/lib/llvm14/bin/clang-14"
build --copt=-Wno-error=unused-command-line-argument
build --copt=-Wno-gnu-offsetof-extensions
build --config=avx_posix
build --config=mkl_open_source_only
build --config=cuda
build --config=cuda_nvcc
build --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm14/bin/clang-14"
build --repo_env HERMETIC_PYTHON_VERSION="3.12"
The headers are commented out in two cases:
1) The Bazel command didn't receive the --config=cuda option.
2) The CUDA repository rule was unable to find the CUPTI libraries in /opt/cuda/lib, assumed that the CUPTI redistribution is absent, and therefore commented out the headers.
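Case 2 can be pictured with the rough sketch below: probe <LOCAL_CUDA_PATH>/lib for a CUPTI shared library and, if none is found, emit the headers rule with its hdrs line commented out. This is not the actual XLA repository rule; the function names, the libcupti.so* pattern and the generated BUILD text are hypothetical and only mimic the behavior described above.

```python
import glob
import os


def find_cupti_libraries(cuda_path):
    """Hypothetical probe: libcupti.so* files directly in <cuda_path>/lib."""
    return sorted(glob.glob(os.path.join(cuda_path, "lib", "libcupti.so*")))


def render_headers_rule(cuda_path):
    """Render a toy BUILD snippet; hdrs is commented out if no library exists."""
    prefix = "" if find_cupti_libraries(cuda_path) else "# "
    return (
        "cc_library(\n"
        '    name = "headers",\n'
        f'    {prefix}hdrs = glob(["include/cupti*.h"]),\n'
        ")\n"
    )


if __name__ == "__main__":
    print(render_headers_rule("/opt/cuda"), end="")
```

Under this model, a tree whose CUPTI libraries live only under extras/CUPTI gets the commented-out rule even though the headers exist on disk, which matches the empty cuda_cupti include directories reported earlier.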
bazel query doesn't take Bazel build options into account (including those provided by --config=cuda). To find the true dependencies, you can use bazel cquery:
Here are my results:
bazel cquery --repo_env=LOCAL_CUDA_PATH="/home/ybaturina/cuda" --repo_env=LOCAL_CUDNN_PATH="/home/ybaturina/cudnn" --repo_env=LOCAL_NCCL_PATH="/home/ybaturina/Downloads/dists/nvidia/nccl" --repo_env=HERMETIC_PYTHON_VERSION=3.10 --config=cuda 'somepath(//jaxlib/tools:build_wheel, @xla//xla/tsl/cuda:cupti_stub)'
- returns nothing
bazel cquery --repo_env=LOCAL_CUDA_PATH="/home/ybaturina/cuda" --repo_env=LOCAL_CUDNN_PATH="/home/ybaturina/cudnn" --repo_env=LOCAL_NCCL_PATH="/home/ybaturina/Downloads/dists/nvidia/nccl" --repo_env=HERMETIC_PYTHON_VERSION=3.10 --config=cuda 'somepath(//jaxlib/tools:build_wheel, @cuda_cupti//:headers)'
- returns the result below:
INFO: Found 2 targets...
//jaxlib/tools:build_wheel (e356dec)
//jaxlib/cuda:cuda_gpu_support (e356dec)
//jaxlib/mosaic/gpu:mosaic_gpu (e356dec)
//jaxlib/mosaic/gpu:_mosaic_gpu_ext (e356dec)
//jaxlib/mosaic/gpu:_mosaic_gpu_ext.so (e356dec)
//jaxlib/cuda:cuda_vendor (e356dec)
@xla//xla/tsl/cuda:cupti (e356dec)
@cuda_cupti//:cupti (e356dec)
@cuda_cupti//:cupti_shared_library (e356dec)
@cuda_cupti//:headers (e356dec)
It seems that the missing-header error is caused by the #include_next GNU extension combined with the ordering of the -isystem search paths in Bazel(?). The actual search path order is:
...
#include <...> search starts here:
...
/usr/include/c++/14.2.1
/usr/include/c++/14.2.1/x86_64-pc-linux-gnu
/usr/include/c++/14.2.1/backward
/usr/lib/llvm14/lib/clang/14.0.6/include
/usr/local/include
End of search list.
The #include_next directive passes the inclusion of stdlib.h on to the next match in the search list, which should be /usr/include/stdlib.h. But /usr/include is not in the search list. These are the stdlib.h files available on the system:
$ find /usr -iname stdlib.h
/usr/include/bits/stdlib.h
/usr/include/c++/14.2.1/stdlib.h
/usr/include/c++/14.2.1/tr1/stdlib.h
/usr/include/stdlib.h
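The lookup failure can be reproduced with a small model of #include_next, under which the search resumes in the directory after the one where the current header was found. This is a simplified sketch (real compilers also distinguish quote, -I, -isystem and built-in directories); the directory list is the one printed above, and the file set contains only cstdlib plus the stdlib.h files reported by find.

```python
import posixpath

# Search list printed by the compiler above (note: /usr/include is absent).
SEARCH_DIRS = [
    "/usr/include/c++/14.2.1",
    "/usr/include/c++/14.2.1/x86_64-pc-linux-gnu",
    "/usr/include/c++/14.2.1/backward",
    "/usr/lib/llvm14/lib/clang/14.0.6/include",
    "/usr/local/include",
]

# cstdlib plus the stdlib.h files reported by `find /usr -iname stdlib.h`.
EXISTING_FILES = {
    "/usr/include/c++/14.2.1/cstdlib",
    "/usr/include/bits/stdlib.h",
    "/usr/include/c++/14.2.1/stdlib.h",
    "/usr/include/c++/14.2.1/tr1/stdlib.h",
    "/usr/include/stdlib.h",
}


def find_header(name, search_dirs, existing_files, start=0):
    """Return the first search_dirs[i]/name (i >= start) that exists, else None."""
    for directory in search_dirs[start:]:
        candidate = posixpath.join(directory, name)
        if candidate in existing_files:
            return candidate
    return None


def include_next(name, search_dirs, existing_files, current_dir):
    """Resolve `#include_next <name>` from a header found in current_dir."""
    start = search_dirs.index(current_dir) + 1  # skip current_dir and earlier dirs
    return find_header(name, search_dirs, existing_files, start)


if __name__ == "__main__":
    # cstdlib lives in /usr/include/c++/14.2.1 and does #include_next <stdlib.h>.
    cur = "/usr/include/c++/14.2.1"
    print(include_next("stdlib.h", SEARCH_DIRS, EXISTING_FILES, cur))
    print(include_next("stdlib.h", SEARCH_DIRS + ["/usr/include"], EXISTING_FILES, cur))
```

With the printed search list the lookup returns None, mirroring the "'stdlib.h' file not found" error; appending /usr/include to the end of the list makes it resolve to /usr/include/stdlib.h. Where /usr/include sits relative to the C++ directories matters, which is why simply adding it with -isystem may not help.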
I have no idea how to fix this easily. Adding --cxxopt=-isystem/usr/include to the build options does not help. It also seems that Bazel sorts the search paths alphabetically.
When I run the following command on a pretty much default Arch Linux box,
JAXLIB_RELEASE=$pkgver python build/build.py \
--bazel_startup_options="--output_user_root=$srcdir/bazel" \
--bazel_options='--action_env=JAXLIB_RELEASE' \
--enable_cuda \
--target_cpu_features=release
I don't have a problem locating the stdlib header, but I have another problem related to it:
ERROR: /home/lie/.cache/pikaur/build/python-jaxlib-cuda/src/jax-jaxlib-v0.4.32/jaxlib/cuda/BUILD:75:13: Compiling jaxlib/gpu/make_batch_pointers.cu.cc failed: (Exit 2): crosstool_wrapper_driver_is_not_gcc failed: error executing command (from target //jaxlib/cuda:cuda_make_batch_pointers)
(cd /home/lie/.cache/pikaur/build/python-jaxlib-cuda/src/bazel/8a8ca1fd42886b1189093a9473f4da62/execroot/__main__ && \
exec env - \
CLANG_COMPILER_PATH=/usr/bin/clang-18 \
CLANG_CUDA_COMPILER_PATH=/usr/bin/clang-18 \
CUDA_TOOLKIT_PATH=/opt/cuda \
GCC_HOST_COMPILER_PATH=/usr/bin/gcc-13 \
JAXLIB_RELEASE=0.4.32 \
NCCL_INSTALL_PATH=/usr \
PATH=/usr/local/sbin:/usr/local/bin:/usr/bin \
PWD=/proc/self/cwd \
TF_CUDA_COMPUTE_CAPABILITIES=sm_70,sm_75,sm_80,sm_86,sm_89,sm_90,compute_90 \
TF_CUDA_PATHS=/opt/cuda,/usr/lib,/usr \
TF_NVCC_CLANG=1 \
external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc -MD -MF bazel-out/k8-opt/bin/jaxlib/cuda/_objs/cuda_make_batch_pointers/make_batch_pointers.cu.pic.d '-frandom-seed=bazel-out/k8-opt/bin/jaxlib/cuda/_objs/cuda_make_batch_pointers/make_batch_pointers.cu.pic.o' '-DEIGEN_MAX_ALIGN_BYTES=64' -DEIGEN_ALLOW_UNALIGNED_SCALARS '-DEIGEN_USE_AVX512_GEMM_KERNELS=0' '-DJAX_GPU_CUDA=1' '-DBAZEL_CURRENT_REPOSITORY=""' -iquote . -iquote bazel-out/k8-opt/bin -iquote external/local_config_cuda -iquote bazel-out/k8-opt/bin/external/local_config_cuda -iquote external/cuda_cudart -iquote bazel-out/k8-opt/bin/external/cuda_cudart -iquote external/cuda_cublas -iquote bazel-out/k8-opt/bin/external/cuda_cublas -iquote external/cuda_cccl -iquote bazel-out/k8-opt/bin/external/cuda_cccl -iquote external/cuda_nvtx -iquote bazel-out/k8-opt/bin/external/cuda_nvtx -iquote external/cuda_nvcc -iquote bazel-out/k8-opt/bin/external/cuda_nvcc -iquote external/cuda_cusolver -iquote bazel-out/k8-opt/bin/external/cuda_cusolver -iquote external/cuda_cufft -iquote bazel-out/k8-opt/bin/external/cuda_cufft -iquote external/cuda_cusparse -iquote bazel-out/k8-opt/bin/external/cuda_cusparse -iquote external/cuda_curand -iquote bazel-out/k8-opt/bin/external/cuda_curand -iquote external/cuda_cupti -iquote bazel-out/k8-opt/bin/external/cuda_cupti -iquote external/cuda_nvml -iquote bazel-out/k8-opt/bin/external/cuda_nvml -iquote external/cuda_nvjitlink -iquote bazel-out/k8-opt/bin/external/cuda_nvjitlink -iquote external/cuda_cudnn -iquote bazel-out/k8-opt/bin/external/cuda_cudnn -iquote external/xla -iquote bazel-out/k8-opt/bin/external/xla -iquote external/tsl -iquote bazel-out/k8-opt/bin/external/tsl -iquote external/eigen_archive -iquote bazel-out/k8-opt/bin/external/eigen_archive -iquote external/ml_dtypes -iquote bazel-out/k8-opt/bin/external/ml_dtypes -iquote external/com_google_absl -iquote bazel-out/k8-opt/bin/external/com_google_absl -iquote external/nsync -iquote 
bazel-out/k8-opt/bin/external/nsync -iquote external/double_conversion -iquote bazel-out/k8-opt/bin/external/double_conversion -iquote external/com_google_protobuf -iquote bazel-out/k8-opt/bin/external/com_google_protobuf -iquote external/zlib -iquote bazel-out/k8-opt/bin/external/zlib -iquote external/local_config_rocm -iquote bazel-out/k8-opt/bin/external/local_config_rocm -iquote external/local_config_tensorrt -iquote bazel-out/k8-opt/bin/external/local_config_tensorrt -iquote external/nccl_archive -iquote bazel-out/k8-opt/bin/external/nccl_archive -Ibazel-out/k8-opt/bin/external/local_config_cuda/cuda/_virtual_includes/cuda_headers -Ibazel-out/k8-opt/bin/external/cuda_cudart/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_cublas/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_cccl/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_nvtx/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_nvcc/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_cusolver/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_cufft/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_cusparse/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_curand/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_cupti/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_nvml/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_nvjitlink/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/cuda_cudnn/_virtual_includes/headers -Ibazel-out/k8-opt/bin/external/ml_dtypes/_virtual_includes/float8 -Ibazel-out/k8-opt/bin/external/ml_dtypes/_virtual_includes/intn -Ibazel-out/k8-opt/bin/external/local_config_tensorrt/_virtual_includes/tensorrt_headers -Ibazel-out/k8-opt/bin/external/nccl_archive/_virtual_includes/nccl_config -isystem external/local_config_cuda/cuda -isystem bazel-out/k8-opt/bin/external/local_config_cuda/cuda -isystem external/cuda_cudart/include -isystem 
bazel-out/k8-opt/bin/external/cuda_cudart/include -isystem external/cuda_cublas/include -isystem bazel-out/k8-opt/bin/external/cuda_cublas/include -isystem external/cuda_cccl/include -isystem bazel-out/k8-opt/bin/external/cuda_cccl/include -isystem external/cuda_nvtx/include -isystem bazel-out/k8-opt/bin/external/cuda_nvtx/include -isystem external/cuda_nvcc/include -isystem bazel-out/k8-opt/bin/external/cuda_nvcc/include -isystem external/cuda_cusolver/include -isystem bazel-out/k8-opt/bin/external/cuda_cusolver/include -isystem external/cuda_cufft/include -isystem bazel-out/k8-opt/bin/external/cuda_cufft/include -isystem external/cuda_cusparse/include -isystem bazel-out/k8-opt/bin/external/cuda_cusparse/include -isystem external/cuda_curand/include -isystem bazel-out/k8-opt/bin/external/cuda_curand/include -isystem external/cuda_cupti/include -isystem bazel-out/k8-opt/bin/external/cuda_cupti/include -isystem external/cuda_nvml/include -isystem bazel-out/k8-opt/bin/external/cuda_nvml/include -isystem external/cuda_nvjitlink/include -isystem bazel-out/k8-opt/bin/external/cuda_nvjitlink/include -isystem external/cuda_cudnn/include -isystem bazel-out/k8-opt/bin/external/cuda_cudnn/include -isystem external/eigen_archive -isystem bazel-out/k8-opt/bin/external/eigen_archive -isystem external/eigen_archive/mkl_include -isystem bazel-out/k8-opt/bin/external/eigen_archive/mkl_include -isystem external/ml_dtypes -isystem bazel-out/k8-opt/bin/external/ml_dtypes -isystem external/ml_dtypes/ml_dtypes -isystem bazel-out/k8-opt/bin/external/ml_dtypes/ml_dtypes -isystem external/nsync/public -isystem bazel-out/k8-opt/bin/external/nsync/public -isystem external/com_google_protobuf/src -isystem bazel-out/k8-opt/bin/external/com_google_protobuf/src -isystem external/zlib -isystem bazel-out/k8-opt/bin/external/zlib -isystem external/local_config_rocm/rocm -isystem bazel-out/k8-opt/bin/external/local_config_rocm/rocm -isystem external/local_config_rocm/rocm/rocm/include -isystem 
bazel-out/k8-opt/bin/external/local_config_rocm/rocm/rocm/include -isystem external/local_config_rocm/rocm/rocm/include/rocrand -isystem bazel-out/k8-opt/bin/external/local_config_rocm/rocm/rocm/include/rocrand -isystem external/local_config_rocm/rocm/rocm/include/roctracer -isystem bazel-out/k8-opt/bin/external/local_config_rocm/rocm/rocm/include/roctracer -fmerge-all-constants -Wno-builtin-macro-redefined '-D__DATE__="redacted"' '-D__TIMESTAMP__="redacted"' '-D__TIME__="redacted"' -fPIC -U_FORTIFY_SOURCE '-D_FORTIFY_SOURCE=1' -fstack-protector -Wall -fno-omit-frame-pointer -no-canonical-prefixes -DNDEBUG -g0 -O2 -ffunction-sections -fdata-sections '--cuda-path=external/cuda_nvcc' '-fvisibility=hidden' -Wno-sign-compare -Wno-unknown-warning-option -Wno-stringop-truncation -Wno-array-parameter '-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.' '-Wno-error=unused-command-line-argument' -Wno-gnu-offsetof-extensions -mavx -Wno-gnu-offsetof-extensions -Qunused-arguments '-std=c++17' -x cuda '-DGOOGLE_CUDA=1' '--no-cuda-include-ptx=all' '--cuda-gpu-arch=sm_50' '--cuda-gpu-arch=sm_60' '--cuda-gpu-arch=sm_70' '--cuda-gpu-arch=sm_80' '--cuda-include-ptx=sm_90' '--cuda-gpu-arch=sm_90' '-Xcuda-fatbinary=--compress-all' '-nvcc_options=expt-relaxed-constexpr' -c jaxlib/gpu/make_batch_pointers.cu.cc -o bazel-out/k8-opt/bin/jaxlib/cuda/_objs/cuda_make_batch_pointers/make_batch_pointers.cu.pic.o)
# Configuration: 40116f3bac97303e7dcbac2b0176d8c2300ec77420ba81e4743a9a16e63d74ec
# Execution platform: @local_execution_config_platform//:platform
/home/lie/.cache/pikaur/build/python-jaxlib-cuda/src/bazel/8a8ca1fd42886b1189093a9473f4da62/execroot/__main__/external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc:225: SyntaxWarning: invalid escape sequence '\.'
re.search('\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)]
/usr/include/bits/stdlib.h(37): error: linkage specification is incompatible with previous "realpath" (declared at line 940 of /usr/include/stdlib.h)
realpath (const char *__restrict __name, char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __resolved) noexcept (true)
^
/usr/include/bits/stdlib.h(72): error: linkage specification is incompatible with previous "ptsname_r" (declared at line 1134 of /usr/include/stdlib.h)
ptsname_r (int __fd, char * const __attribute__ ((__pass_object_size__ (1 > 1))) __buf, size_t __buflen) noexcept (true)
^
/usr/include/bits/stdlib.h(91): error: linkage specification is incompatible with previous "wctomb" (declared at line 1069 of /usr/include/stdlib.h)
wctomb (char * const __attribute__ ((__pass_object_size__ (1 > 1))) __s, wchar_t __wchar) noexcept (true)
^
/usr/include/bits/stdlib.h(129): error: linkage specification is incompatible with previous "mbstowcs" (declared at line 1073 of /usr/include/stdlib.h)
mbstowcs (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dst, const char *__restrict __src, size_t __len) noexcept (true)
^
/usr/include/bits/stdlib.h(159): error: linkage specification is incompatible with previous "wcstombs" (declared at line 1077 of /usr/include/stdlib.h)
wcstombs (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dst, const wchar_t *__restrict __src, size_t __len) noexcept (true)
^
/usr/include/bits/string_fortified.h(77): error: linkage specification is incompatible with previous "strcpy" (declared at line 141 of /usr/include/string.h)
strcpy (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const char *__restrict __src) noexcept (true)
^
/usr/include/bits/string_fortified.h(86): error: linkage specification is incompatible with previous "stpcpy" (declared at line 491 of /usr/include/string.h)
stpcpy (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const char *__restrict __src) noexcept (true)
^
/usr/include/bits/string_fortified.h(96): error: linkage specification is incompatible with previous "strncpy" (declared at line 144 of /usr/include/string.h)
strncpy (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const char *__restrict __src, size_t __len) noexcept (true)
^
/usr/include/bits/string_fortified.h(107): error: linkage specification is incompatible with previous "stpncpy" (declared at line 499 of /usr/include/string.h)
stpncpy (char * const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const char *__src, size_t __n) noexcept (true)
^
/usr/include/bits/string_fortified.h(136): error: linkage specification is incompatible with previous "strcat" (declared at line 149 of /usr/include/string.h)
strcat (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const char *__restrict __src) noexcept (true)
^
/usr/include/bits/string_fortified.h(145): error: linkage specification is incompatible with previous "strncat" (declared at line 152 of /usr/include/string.h)
strncat (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const char *__restrict __src, size_t __len) noexcept (true)
^
/usr/include/bits/string_fortified.h(161): error: linkage specification is incompatible with previous "strlcpy" (declared at line 506 of /usr/include/string.h)
strlcpy (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const char *__restrict __src, size_t __n) noexcept (true)
^
/usr/include/bits/string_fortified.h(179): error: linkage specification is incompatible with previous "strlcat" (declared at line 512 of /usr/include/string.h)
strlcat (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const char *__restrict __src, size_t __n) noexcept (true)
^
/usr/bin/../lib64/gcc/x86_64-pc-linux-gnu/14.2.1/../../../../include/c++/14.2.1/tuple(2962): error: type name is not allowed
static_assert(!__reference_constructs_from_temporary(_Tp, _Elt));
^
/usr/bin/../lib64/gcc/x86_64-pc-linux-gnu/14.2.1/../../../../include/c++/14.2.1/tuple(2962): error: type name is not allowed
static_assert(!__reference_constructs_from_temporary(_Tp, _Elt));
^
/usr/bin/../lib64/gcc/x86_64-pc-linux-gnu/14.2.1/../../../../include/c++/14.2.1/tuple(2962): error: identifier "__reference_constructs_from_temporary" is undefined
static_assert(!__reference_constructs_from_temporary(_Tp, _Elt));
^
/usr/include/bits/wchar2.h(24): error: linkage specification is incompatible with previous "wmemcpy" (declared at line 287 of /usr/include/wchar.h)
wmemcpy (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __s1, const wchar_t *__restrict __s2, size_t __n) noexcept (true)
^
/usr/include/bits/wchar2.h(36): error: linkage specification is incompatible with previous "wmemmove" (declared at line 292 of /usr/include/wchar.h)
wmemmove (wchar_t * const __attribute__ ((__pass_object_size__ (1 > 1))) __s1, const wchar_t *__s2, size_t __n) noexcept (true)
^
/usr/include/bits/wchar2.h(49): error: linkage specification is incompatible with previous "wmempcpy" (declared at line 301 of /usr/include/wchar.h)
wmempcpy (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __s1, const wchar_t *__restrict __s2, size_t __n) noexcept (true)
^
/usr/include/bits/wchar2.h(62): error: linkage specification is incompatible with previous "wmemset" (declared at line 296 of /usr/include/wchar.h)
wmemset (wchar_t * const __attribute__ ((__pass_object_size__ (1 > 1))) __s, wchar_t __c, size_t __n) noexcept (true)
^
/usr/include/bits/wchar2.h(74): error: linkage specification is incompatible with previous "wcscpy" (declared at line 98 of /usr/include/wchar.h)
wcscpy (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const wchar_t *__restrict __src) noexcept (true)
^
/usr/include/bits/wchar2.h(84): error: linkage specification is incompatible with previous "wcpcpy" (declared at line 689 of /usr/include/wchar.h)
wcpcpy (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const wchar_t *__restrict __src) noexcept (true)
^
/usr/include/bits/wchar2.h(94): error: linkage specification is incompatible with previous "wcsncpy" (declared at line 103 of /usr/include/wchar.h)
wcsncpy (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const wchar_t *__restrict __src, size_t __n) noexcept (true)
^
/usr/include/bits/wchar2.h(106): error: linkage specification is incompatible with previous "wcpncpy" (declared at line 694 of /usr/include/wchar.h)
wcpncpy (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const wchar_t *__restrict __src, size_t __n) noexcept (true)
^
/usr/include/bits/wchar2.h(118): error: linkage specification is incompatible with previous "wcscat" (declared at line 121 of /usr/include/wchar.h)
wcscat (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const wchar_t *__restrict __src) noexcept (true)
^
/usr/include/bits/wchar2.h(128): error: linkage specification is incompatible with previous "wcsncat" (declared at line 125 of /usr/include/wchar.h)
wcsncat (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const wchar_t *__restrict __src, size_t __n) noexcept (true)
^
/usr/include/bits/wchar2.h(139): error: linkage specification is incompatible with previous "wcslcpy" (declared at line 109 of /usr/include/wchar.h)
wcslcpy (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const wchar_t *__restrict __src, size_t __n) noexcept (true)
^
/usr/include/bits/wchar2.h(155): error: linkage specification is incompatible with previous "wcslcat" (declared at line 115 of /usr/include/wchar.h)
wcslcat (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dest, const wchar_t *__restrict __src, size_t __n) noexcept (true)
^
/usr/include/bits/wchar2.h(254): error: linkage specification is incompatible with previous "fgetws" (declared at line 964 of /usr/include/wchar.h)
fgetws (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __s, int __n,
^
/usr/include/bits/wchar2.h(272): error: linkage specification is incompatible with previous "fgetws_unlocked" (declared at line 1026 of /usr/include/wchar.h)
fgetws_unlocked (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __s,
^
/usr/include/bits/wchar2.h(291): error: linkage specification is incompatible with previous "wcrtomb" (declared at line 326 of /usr/include/wchar.h)
wcrtomb (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __s, wchar_t __wchar, mbstate_t *__restrict __ps) noexcept (true)
^
/usr/include/bits/wchar2.h(308): error: linkage specification is incompatible with previous "mbsrtowcs" (declared at line 362 of /usr/include/wchar.h)
mbsrtowcs (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dst, const char **__restrict __src, size_t __len, mbstate_t *__restrict __ps) noexcept (true)
^
/usr/include/bits/wchar2.h(321): error: linkage specification is incompatible with previous "wcsrtombs" (declared at line 368 of /usr/include/wchar.h)
wcsrtombs (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dst, const wchar_t **__restrict __src, size_t __len, mbstate_t *__restrict __ps) noexcept (true)
^
/usr/include/bits/wchar2.h(336): error: linkage specification is incompatible with previous "mbsnrtowcs" (declared at line 376 of /usr/include/wchar.h)
mbsnrtowcs (wchar_t * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dst, const char **__restrict __src, size_t __nmc, size_t __len, mbstate_t *__restrict __ps) noexcept (true)
^
/usr/include/bits/wchar2.h(349): error: linkage specification is incompatible with previous "wcsnrtombs" (declared at line 382 of /usr/include/wchar.h)
wcsnrtombs (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __dst, const wchar_t **__restrict __src, size_t __nwc, size_t __len, mbstate_t *__restrict __ps) noexcept (true)
^
/usr/include/bits/unistd.h(26): error: linkage specification is incompatible with previous "read" (declared at line 371 of /usr/include/unistd.h)
read (int __fd, void * const __attribute__ ((__pass_object_size__ (0))) __buf, size_t __nbytes)
^
/usr/include/bits/unistd.h(40): error: linkage specification is incompatible with previous "pread" (declared at line 389 of /usr/include/unistd.h)
pread (int __fd, void * const __attribute__ ((__pass_object_size__ (0))) __buf,
^
/usr/include/bits/unistd.h(66): error: linkage specification is incompatible with previous "pread64" (declared at line 422 of /usr/include/unistd.h)
pread64 (int __fd, void * const __attribute__ ((__pass_object_size__ (0))) __buf,
^
/usr/include/bits/unistd.h(81): error: linkage specification is incompatible with previous "readlink" (declared at line 838 of /usr/include/unistd.h)
readlink (const char *__restrict __path, char * __restrict const __attribute__ ((__pass_object_size__ (0))) __buf, size_t __len) noexcept (true)
^
/usr/include/bits/unistd.h(97): error: linkage specification is incompatible with previous "readlinkat" (declared at line 851 of /usr/include/unistd.h)
readlinkat (int __fd, const char *__restrict __path, char * __restrict const __attribute__ ((__pass_object_size__ (0))) __buf, size_t __len) noexcept (true)
^
/usr/include/bits/unistd.h(111): error: linkage specification is incompatible with previous "getcwd" (declared at line 531 of /usr/include/unistd.h)
getcwd (char * const __attribute__ ((__pass_object_size__ (1 > 1))) __buf, size_t __size) noexcept (true)
^
/usr/include/bits/unistd.h(124): error: linkage specification is incompatible with previous "getwd" (declared at line 545 of /usr/include/unistd.h)
getwd (char * const __attribute__ ((__pass_object_size__ (1 > 1))) __buf) noexcept (true)
^
/usr/include/bits/unistd.h(133): error: linkage specification is incompatible with previous "confstr" (declared at line 644 of /usr/include/unistd.h)
confstr (int __name, char * const __attribute__ ((__pass_object_size__ (1 > 1))) __buf, size_t __len) noexcept (true)
^
/usr/include/bits/unistd.h(146): error: linkage specification is incompatible with previous "getgroups" (declared at line 711 of /usr/include/unistd.h)
getgroups (int __size, __gid_t * const __attribute__ ((__pass_object_size__ (1 > 1))) __list) noexcept (true)
^
/usr/include/bits/unistd.h(160): error: linkage specification is incompatible with previous "ttyname_r" (declared at line 803 of /usr/include/unistd.h)
ttyname_r (int __fd, char * const __attribute__ ((__pass_object_size__ (1 > 1))) __buf, size_t __buflen) noexcept (true)
^
/usr/include/bits/unistd.h(175): error: linkage specification is incompatible with previous "getlogin_r" (declared at line 889 of /usr/include/unistd.h)
getlogin_r (char * const __attribute__ ((__pass_object_size__ (1 > 1))) __buf, size_t __buflen)
^
/usr/include/bits/unistd.h(189): error: linkage specification is incompatible with previous "gethostname" (declared at line 911 of /usr/include/unistd.h)
gethostname (char * const __attribute__ ((__pass_object_size__ (1 > 1))) __buf, size_t __buflen) noexcept (true)
^
/usr/include/bits/unistd.h(204): error: linkage specification is incompatible with previous "getdomainname" (declared at line 930 of /usr/include/unistd.h)
getdomainname (char * const __attribute__ ((__pass_object_size__ (1 > 1))) __buf, size_t __buflen) noexcept (true)
^
/usr/include/bits/stdio2.h(55): error: linkage specification is incompatible with previous "vsprintf" (declared at line 380 of /usr/include/stdio.h)
vsprintf (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __s, const char *__restrict __fmt, __gnuc_va_list __ap) noexcept (true)
^
/usr/include/bits/stdio2.h(93): error: linkage specification is incompatible with previous "vsnprintf" (declared at line 389 of /usr/include/stdio.h)
vsnprintf (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __s, size_t __n, const char *__restrict __fmt, __gnuc_va_list __ap) noexcept (true)
^
/usr/include/bits/stdio2.h(305): error: linkage specification is incompatible with previous "fgets" (declared at line 654 of /usr/include/stdio.h)
fgets (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __s, int __n,
^
/usr/include/bits/stdio2.h(322): error: linkage specification is incompatible with previous "fread" (declared at line 728 of /usr/include/stdio.h)
fread (void * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __ptr,
^
/usr/include/bits/stdio2.h(342): error: linkage specification is incompatible with previous "fgets_unlocked" (declared at line 677 of /usr/include/stdio.h)
fgets_unlocked (char * __restrict const __attribute__ ((__pass_object_size__ (1 > 1))) __s,
^
/usr/include/bits/stdio2.h(362): error: linkage specification is incompatible with previous "fread_unlocked" (declared at line 756 of /usr/include/stdio.h)
fread_unlocked (void * __restrict const __attribute__ ((__pass_object_size__ (0))) __ptr,
^
54 errors detected in the compilation of "jaxlib/gpu/make_batch_pointers.cu.cc".
Target //jaxlib/tools:build_wheel failed to build
Can you try this command please?
python build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12 --clang_path=<absolute clang compiler path> --use_cuda_nvcc=false
Please note that clang_path should be a real path, not a symlink.
We are planning to update the instructions on how to build JAX from source.
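Because clang_path must point at the real binary rather than a symlink, it can help to resolve the compiler's path before passing it to build.py. This is a minimal sketch using only the Python standard library. The executable name clang is an assumption, so substitute whatever your system installs (for example clang-18).

```python
import os
import shutil
from typing import Optional


def resolve_clang_path(name: str = "clang") -> Optional[str]:
    """Return the absolute, symlink-free path of a compiler found on PATH.

    `name` is an assumption: use whatever your distribution installs
    (for example "clang-18").
    """
    found = shutil.which(name)  # None if the executable is not on PATH
    if found is None:
        return None
    # os.path.realpath follows every symlink in the chain, so the result
    # is the real file rather than a link to it.
    return os.path.realpath(found)


if __name__ == "__main__":
    path = resolve_clang_path("clang")
    print(path if path else "clang not found on PATH")
```

The printed path is what would go into --clang_path=<path>.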
Then I have the same problem as daskol:
clang-18: error: cannot find CUDA installation; provide its path via '--cuda-path', or pass '-nocudainc' to build without CUDA includes
Target //jaxlib/tools:build_wheel failed to build
If I add:
#--bazel_options='--repo_env=LOCAL_CUDA_PATH=/opt/cuda' \
I get a message about a missing "third_party/gpus/cuda/extras/CUPTI/include/cupti.h".
I've also tried several hacks, like:
mkdir -p 'third_party/gpus/'
ln -s /opt/cuda/ 'third_party/gpus/cuda'
and rebuilding after cleaning the Bazel cache, but I still get one or the other of the two error messages above.
May I ask you to describe your use case? Why is it necessary to use the local CUDA path in your scenario? Also, would you attach the full log of running the build script, please?
I thought it was necessary to build against the locally installed CUDA version to get matching API compatibility.
Regarding the build log: it normally shows only the last command that errored. Do I need to configure Bazel to show the full log instead?
Here's the log from one of the runs:
Thank you for the explanation.
There is no need to build with LOCAL_CUDA_PATH; that option was added for the NVIDIA team only, since they might use unpublished CUDA redistributions.
To enable the full logs, add the -s option to the bazel command.
I looked at the Clang command that produces the error in this log. Indeed, there is no --cuda-path parameter in the Clang options. cuda-path is set via crosstool/BUILD. Can you check this file on your machine, please?
/home/lie/.cache/bazel/_bazel_lie/ca5827d19938fa4a99a005b1b7bd341c/external/local_config_cuda/crosstool/BUILD
Does it have cuda_path defined?
And one more question: is /usr/sbin/clang-18 a symlink or a real path?
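The check requested above can also be scripted. This is a minimal sketch that extracts the cuda_path value from the generated crosstool/BUILD. It assumes the attribute is written on a single line as cuda_path = "...", which matches the grep output in this thread; the file's location under Bazel's output base differs per machine, and the script name in the usage comment is hypothetical.

```python
import re
import sys
from pathlib import Path
from typing import Optional

# Matches lines such as:  cuda_path = "external/cuda_nvcc",
_CUDA_PATH_RE = re.compile(r'^\s*cuda_path\s*=\s*"([^"]*)"', re.MULTILINE)


def read_cuda_path(build_text: str) -> Optional[str]:
    """Return the cuda_path value from crosstool/BUILD text, or None if absent."""
    match = _CUDA_PATH_RE.search(build_text)
    return match.group(1) if match else None


def check_build_file(path: Path) -> str:
    value = read_cuda_path(path.read_text())
    if value is None:
        return "cuda_path is not defined"
    if value == "":
        # In this thread, an empty value went together with Clang
        # receiving no --cuda-path flag.
        return "cuda_path is empty"
    return f"cuda_path = {value!r}"


if __name__ == "__main__":
    # Hypothetical usage:
    #   python check_cuda_path.py <output_base>/external/local_config_cuda/crosstool/BUILD
    for arg in sys.argv[1:]:
        print(check_build_file(Path(arg)))
```

A plain grep, as used later in the thread, gives the same information; the script only makes the empty-value case explicit.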
I'll have to re-run because I already flushed the build/cache directories.
So, before re-running, I wonder: what's the correct way to pass cuda_path through to the underlying Clang if LOCAL_CUDA_PATH is not the right option?
The purpose of LOCAL_CUDA_PATH is different: this environment variable controls whether Bazel downloads the redistributions from NVIDIA URLs or uses local paths as the sources of the redistributions.
The cuda-path parameter should be added automatically by the toolchain (see here); that is why I'm trying to figure out why it's not happening in your scenario. AFAIK, there is no way to override this behavior.
OK, I restarted afresh without specifying that environment variable and checked the contents of that file. Here it is (the log is the same, since nothing in the config changed from last time):
$ grep cuda_path ~/.cache/bazel/_bazel_lie/ca5827d19938fa4a99a005b1b7bd341c/external/local_config_cuda/crosstool/BUILD
cuda_path = "",
I've created this PR: https://github.com/openxla/xla/pull/19113
Can you verify whether it solves your issue, please?
You need to clone the XLA repository, check out the branch test_693735256, and then pass --bazel_options=--override_repository=xla=<XLA repo path> to your python build/build.py command.
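The last step can be sketched in Python. The sketch below only assembles the build.py argument list and leaves cloning XLA and checking out test_693735256 to git. The flag strings come from this thread, while the helper name build_command and the ../xla clone location are hypothetical placeholders.

```python
import shlex
from pathlib import Path
from typing import List


def build_command(xla_repo: Path, extra_args: List[str]) -> List[str]:
    """Assemble a build/build.py invocation that uses a local XLA checkout.

    --override_repository is forwarded to Bazel through --bazel_options, so
    Bazel takes the `xla` repository from xla_repo instead of the copy that
    JAX's workspace would otherwise fetch.
    """
    repo = xla_repo.resolve()  # an absolute path is the safe choice for Bazel
    return [
        "python",
        "build/build.py",
        *extra_args,
        f"--bazel_options=--override_repository=xla={repo}",
    ]


if __name__ == "__main__":
    # Hypothetical location of the XLA clone with test_693735256 checked out.
    cmd = build_command(Path("../xla"), ["--enable_cuda"])
    print(shlex.join(cmd))
```

Printing the joined command rather than running it keeps the sketch side-effect free; pass the list to subprocess.run once the paths are correct.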
Thanks, I'll try that later tonight.
Hi @actionless, did my PR help in your case? If so, I can submit it.
Yup, now it's being detected 👍
$ grep cuda_path ~/.cache/bazel/_bazel_lie/ca5827d19938fa4a99a005b1b7bd341c/external/local_config_cuda/crosstool/BUILD
cuda_path = "external/cuda_nvcc",
and the build gets further. Thanks for your quick fix!
https://github.com/openxla/xla/pull/19113 is merged now.
Description
When building jaxlib with an externally installed copy of CUDA (something required by all package managers and HPC systems), I see the following error:
It's possible I'm passing the wrong flags somewhere. I'm using:
(of course, with ... replaced by the actual paths)
System info (python version, jaxlib version, accelerator, etc.)
Build log