jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jaxlib 0.4.32, external CUDA, gcc: error: unrecognized command-line option '--cuda-path=external/cuda_nvcc' #23689

Open adamjstewart opened 3 weeks ago

adamjstewart commented 3 weeks ago

Description

When building jaxlib with an externally installed copy of CUDA (something required by all package managers and HPC systems), I see the following error:

gcc: error: unrecognized command-line option '--cuda-path=external/cuda_nvcc'

It's possible I'm passing the wrong flags somewhere. I'm using:

> python3 build/build.py --enable_cuda --cuda_compute_capabilities=8.0 --bazel_options=--repo_env=LOCAL_CUDA_PATH=... --bazel_options=--repo_env=LOCAL_CUDNN_PATH=... --bazel_options=--repo_env=LOCAL_NCCL_PATH=...

(of course, with ... replaced by the actual paths)

System info (python version, jaxlib version, accelerator, etc.)

Build log

ybaturina commented 3 weeks ago

Hi @adamjstewart, the GCC compiler is not officially supported by JAX; I recommend using Clang. You can pass the Clang path with the --clang_path option.
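A minimal sketch of the reporter's command switched to Clang. It assumes Clang lives at the hypothetical path `/usr/bin/clang-18`. The `LOCAL_*_PATH` defaults (`/opt/cuda`, `/opt/cudnn`, `/opt/nccl`) are also placeholders, not values from this thread. Collecting the flags in a bash array keeps each flag a single argument, and it lets you print and check the exact command before starting the build.

```bash
#!/usr/bin/env bash
set -euo pipefail

# Hypothetical placeholder paths: replace with your actual installations.
CLANG_PATH="${CLANG_PATH:-/usr/bin/clang-18}"
LOCAL_CUDA_PATH="${LOCAL_CUDA_PATH:-/opt/cuda}"
LOCAL_CUDNN_PATH="${LOCAL_CUDNN_PATH:-/opt/cudnn}"
LOCAL_NCCL_PATH="${LOCAL_NCCL_PATH:-/opt/nccl}"

# The original command from the issue, with --clang_path added.
cmd=(
  python3 build/build.py
  --enable_cuda
  --cuda_compute_capabilities=8.0
  "--clang_path=${CLANG_PATH}"
  "--bazel_options=--repo_env=LOCAL_CUDA_PATH=${LOCAL_CUDA_PATH}"
  "--bazel_options=--repo_env=LOCAL_CUDNN_PATH=${LOCAL_CUDNN_PATH}"
  "--bazel_options=--repo_env=LOCAL_NCCL_PATH=${LOCAL_NCCL_PATH}"
)

# Print the shell-quoted command for review; launch it with: "${cmd[@]}"
printf '%q ' "${cmd[@]}"
printf '\n'
```

For example, `CLANG_PATH=/path/to/clang bash build_clang.sh` (the script name is hypothetical) prints the command with your own Clang path substituted.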

ybaturina commented 3 weeks ago

If you absolutely need to use GCC, we have experimental support that can be enabled like this:

--bazel_options=--action_env=CUDA_NVCC="1" --bazel_options=--@local_config_cuda//:cuda_compiler=nvcc
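A sketch of the same approach for the experimental GCC path. It appends the two flags from this comment to the reporter's command. It also adds `--use_clang=false`, which comes from the maintainer's later test command rather than from this comment. The shell strips the quotes in `CUDA_NVCC="1"`, so the array holds `CUDA_NVCC=1`. This is the same value that appears in the subcommand environment later in this thread. The `LOCAL_*_PATH` defaults are hypothetical placeholders.

```bash
#!/usr/bin/env bash
set -euo pipefail

# Hypothetical placeholders: replace with the real external install paths.
LOCAL_CUDA_PATH="${LOCAL_CUDA_PATH:-/opt/cuda}"
LOCAL_CUDNN_PATH="${LOCAL_CUDNN_PATH:-/opt/cudnn}"
LOCAL_NCCL_PATH="${LOCAL_NCCL_PATH:-/opt/nccl}"

# Experimental GCC + NVCC flags from this comment; --use_clang=false is
# borrowed from the maintainer's later command in this thread.
nvcc_flags=(
  --use_clang=false
  --bazel_options=--action_env=CUDA_NVCC=1
  --bazel_options=--@local_config_cuda//:cuda_compiler=nvcc
)

cmd=(
  python3 build/build.py
  --enable_cuda
  --cuda_compute_capabilities=8.0
  "--bazel_options=--repo_env=LOCAL_CUDA_PATH=${LOCAL_CUDA_PATH}"
  "--bazel_options=--repo_env=LOCAL_CUDNN_PATH=${LOCAL_CUDNN_PATH}"
  "--bazel_options=--repo_env=LOCAL_NCCL_PATH=${LOCAL_NCCL_PATH}"
  "${nvcc_flags[@]}"
)

# Print the final command line for review; run it with: "${cmd[@]}"
printf '%q ' "${cmd[@]}"
printf '\n'
```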

adamjstewart commented 2 weeks ago

I tried adding these flags but I still see the exact same error:

gcc: error: unrecognized command-line option '--cuda-path=external/cuda_nvcc'

ybaturina commented 1 week ago

Could you paste the full stack trace here, please? I'd like to make sure that the CUDA_NVCC value is recognized by Bazel.
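One way to check this is sketched below, under stated assumptions. First, rerun the build with Bazel's `--subcommands` flag, passed through as `--bazel_options=--subcommands`. This flag prints each action's environment and command line in the `SUBCOMMAND:` format shown later in this thread. Then capture the output to a file and search it. The helper name `check_subcommand_log` and the file name `build.log` are hypothetical.

```bash
#!/usr/bin/env bash
set -euo pipefail

# Hypothetical helper: inspect a log captured with, for example,
#   python3 build/build.py ... --bazel_options=--subcommands 2>&1 | tee build.log
check_subcommand_log() {
  local log="$1"
  # Bazel lists the action environment after "exec env -", one indented
  # VAR=value per line, so match only indented CUDA_NVCC=1 entries
  # (this skips the --action_env=CUDA_NVCC=1 echo of the build command).
  if grep -Eq '^[[:space:]]+CUDA_NVCC=1' "$log"; then
    echo "CUDA_NVCC=1 reached the action environment"
  else
    echo "CUDA_NVCC=1 NOT found in any action environment"
  fi
  # Count logged lines that still carry the Clang-style --cuda-path flag.
  local n
  n=$(grep -c -e '--cuda-path=' "$log" || true)
  echo "lines with --cuda-path: ${n}"
}
```

Running `check_subcommand_log build.log` reports two things. It says whether `CUDA_NVCC=1` reached any action. It also counts how many logged lines still contain `--cuda-path=`, the Clang-style flag that GCC rejects.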

adamjstewart commented 1 week ago

Here you go:

ybaturina commented 1 week ago

Hmm, one more suggestion: try this:

--bazel_options=--action_env=TF_NVCC_CLANG="1" --bazel_options=--@local_config_cuda//:cuda_compiler=nvcc

Your build fails because GCC is unable to compile the CUDA dependencies; they need to be compiled with the NVCC compiler.

adamjstewart commented 1 week ago

Still the same issue:

gcc: error: unrecognized command-line option ‘--cuda-path=external/cuda_nvcc’

ybaturina commented 4 days ago

This is what I've tried:

python3.10 build/build.py --enable_cuda --use_clang=false --bazel_options=--repo_env=CC="/dt9/usr/bin/gcc" --bazel_options=--repo_env=TF_SYSROOT="/dt9" --bazel_options=--action_env=CUDA_NVCC="1" --bazel_options=--@local_config_cuda//:cuda_compiler=nvcc

The subcommand I got:

SUBCOMMAND: # //jaxlib:cpu_feature_guard.so [action 'Compiling jaxlib/cpu_feature_guard.c', configuration: 988f5a730e2bd9c88c71efcc9c7f0d36ad2ec3c5f71c922aabaf7614ff994b0f, execution platform: @local_execution_config_platform//:platform]
(cd /home/ybaturina/.cache/bazel/_bazel_ybaturina/ead9107e8e47a1c42911a02736d63d03/execroot/__main__ && \
  exec env - \
    CUDA_NVCC=1 \
    PATH=/home/kbuilder/.local/bin:/usr/local/bin/python3.10:/home/ybaturina/.local/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin \
    PWD=/proc/self/cwd \
  external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc -MD -MF bazel-out/k8-opt/bin/jaxlib/_objs/cpu_feature_guard.so/cpu_feature_guard.pic.d '-frandom-seed=bazel-out/k8-opt/bin/jaxlib/_objs/cpu_feature_guard.so/cpu_feature_guard.pic.o' '-DBAZEL_CURRENT_REPOSITORY=""' -iquote . -iquote bazel-out/k8-opt/bin -iquote external/python_x86_64-unknown-linux-gnu -iquote bazel-out/k8-opt/bin/external/python_x86_64-unknown-linux-gnu -isystem external/python_x86_64-unknown-linux-gnu/include -isystem bazel-out/k8-opt/bin/external/python_x86_64-unknown-linux-gnu/include -isystem external/python_x86_64-unknown-linux-gnu/include/python3.10 -isystem bazel-out/k8-opt/bin/external/python_x86_64-unknown-linux-gnu/include/python3.10 -isystem external/python_x86_64-unknown-linux-gnu/include/python3.10m -isystem bazel-out/k8-opt/bin/external/python_x86_64-unknown-linux-gnu/include/python3.10m -Wno-builtin-macro-redefined '-D__DATE__="redacted"' '-D__TIMESTAMP__="redacted"' '-D__TIME__="redacted"' -fPIC -U_FORTIFY_SOURCE '-D_FORTIFY_SOURCE=1' -fstack-protector -Wall -fno-omit-frame-pointer -no-canonical-prefixes -DNDEBUG -g0 -O2 -ffunction-sections -fdata-sections '-fvisibility=hidden' -Wno-sign-compare -Wno-unknown-warning-option -Wno-stringop-truncation -Wno-array-parameter '-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.' -mavx -fno-strict-aliasing -fexceptions '-fvisibility=hidden' '--sysroot=/dt9' -c jaxlib/cpu_feature_guard.c -o bazel-out/k8-opt/bin/jaxlib/_objs/cpu_feature_guard.so/cpu_feature_guard.pic.o)

In my build, the --cuda-path option was not passed to the NVCC compiler.

I suspect that something in your machine's environment variables is interfering with the subcommand configuration. Since JAX doesn't officially support compiling with GCC, I strongly recommend using Clang instead.
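To follow up on the environment-variable hypothesis, here is a quick sketch that lists toolchain-related variables in the current shell. You can then compare the list against a machine where the build works. The helper name and the variable-name pattern are assumptions chosen to cover common compiler and CUDA settings. The list is not exhaustive and is not taken from JAX's build scripts.

```bash
#!/usr/bin/env bash
set -euo pipefail

# Hypothetical helper: print environment variables that commonly affect
# C/C++/CUDA toolchain selection, sorted by name. The pattern is an assumption.
list_toolchain_env() {
  env \
    | grep -E '^(CC|CXX|CPP|LD|CFLAGS|CXXFLAGS|CPPFLAGS|LDFLAGS|CUDA[A-Z0-9_]*|NVCC[A-Z0-9_]*|TF_[A-Z0-9_]*|CLANG[A-Z0-9_]*|GCC_HOST_COMPILER_PATH)=' \
    | sort || true  # grep exits 1 when nothing matches; treat that as "empty"
}

list_toolchain_env
```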