jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Build failure: `undeclared inclusion(s) in rule '@xla//xla/stream_executor/cuda:cuda_conditional_kernels'` #19811

Open samuela opened 9 months ago

samuela commented 9 months ago

Description

I'm getting a Bazel error when attempting to build jaxlib version 0.4.24:

[2,325 / 6,066] Compiling src/idl_parser.cpp [for tool]; 16s local ... (64 actions running)
ERROR: /build/output/external/xla/xla/stream_executor/cuda/BUILD:486:13: Compiling xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc failed: undeclared inclusion(s) in rule '@xla//xla/stream_executor/cuda:cuda_conditional_kernels':
this rule is missing dependency declarations for the following files included by 'xla/stream_executor/cuda/cuda_conditional_kernels.cu.cc':
  '/nix/store/bwqz4xqyg8dhc3n23sd46bxvds6hw52z-cuda_nvcc-11.8.89-dev/include/crt/host_config.h'
  '/nix/store/bwqz4xqyg8dhc3n23sd46bxvds6hw52z-cuda_nvcc-11.8.89-dev/include/crt/host_defines.h'
  '/nix/store/bwqz4xqyg8dhc3n23sd46bxvds6hw52z-cuda_nvcc-11.8.89-dev/include/crt/common_functions.h'
  '/nix/store/bwqz4xqyg8dhc3n23sd46bxvds6hw52z-cuda_nvcc-11.8.89-dev/include/crt/math_functions.h'
  '/nix/store/bwqz4xqyg8dhc3n23sd46bxvds6hw52z-cuda_nvcc-11.8.89-dev/include/crt/math_functions.hpp'
  '/nix/store/bwqz4xqyg8dhc3n23sd46bxvds6hw52z-cuda_nvcc-11.8.89-dev/include/crt/device_functions.h'
  '/nix/store/bwqz4xqyg8dhc3n23sd46bxvds6hw52z-cuda_nvcc-11.8.89-dev/include/crt/device_functions.hpp'
  '/nix/store/bwqz4xqyg8dhc3n23sd46bxvds6hw52z-cuda_nvcc-11.8.89-dev/include/crt/device_double_functions.h'
  '/nix/store/bwqz4xqyg8dhc3n23sd46bxvds6hw52z-cuda_nvcc-11.8.89-dev/include/crt/device_double_functions.hpp'
  '/nix/store/bwqz4xqyg8dhc3n23sd46bxvds6hw52z-cuda_nvcc-11.8.89-dev/include/crt/sm_70_rt.h'
  '/nix/store/bwqz4xqyg8dhc3n23sd46bxvds6hw52z-cuda_nvcc-11.8.89-dev/include/crt/sm_70_rt.hpp'
  '/nix/store/bwqz4xqyg8dhc3n23sd46bxvds6hw52z-cuda_nvcc-11.8.89-dev/include/crt/sm_80_rt.h'
  '/nix/store/bwqz4xqyg8dhc3n23sd46bxvds6hw52z-cuda_nvcc-11.8.89-dev/include/crt/sm_80_rt.hpp'
  '/nix/store/bwqz4xqyg8dhc3n23sd46bxvds6hw52z-cuda_nvcc-11.8.89-dev/include/crt/sm_90_rt.h'
  '/nix/store/bwqz4xqyg8dhc3n23sd46bxvds6hw52z-cuda_nvcc-11.8.89-dev/include/crt/sm_90_rt.hpp'
  '/nix/store/bwqz4xqyg8dhc3n23sd46bxvds6hw52z-cuda_nvcc-11.8.89-dev/include/crt/cudacc_ext.h'
/build/output/execroot/__main__/external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc:44: DeprecationWarning: 'pipes' is deprecated and slated for removal in Python 3.13
  import pipes

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.24
jaxlib: 0.4.24
python: 3.11.7
building with CUDA enabled

Happy to provide any other info that might be useful in debugging!
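For triage, it can help to pull the header paths out of the Bazel output and see which directory they share, since that directory is the one the rule does not declare. The following is a minimal sketch, not part of the jax or XLA tooling: the function names `undeclared_headers` and `common_include_dir` are hypothetical, and it assumes the log has the shape shown above (a "missing dependency declarations" line followed by single-quoted absolute paths, one per line).

```python
import posixpath
import re

# Hypothetical helper: collect the quoted absolute paths that Bazel lists
# after its "missing dependency declarations" line.
def undeclared_headers(log_text):
    headers = []
    in_list = False
    for line in log_text.splitlines():
        if "missing dependency declarations" in line:
            # The header list starts on the next line.
            in_list = True
            continue
        if in_list:
            match = re.match(r"\s*'(/[^']+)'\s*$", line)
            if match:
                headers.append(match.group(1))
            else:
                # First line that is not a quoted absolute path ends the list.
                in_list = False
    return headers

# Hypothetical helper: the deepest directory shared by all listed headers,
# or None when no headers were found.
def common_include_dir(headers):
    return posixpath.commonpath(headers) if headers else None
```

Applied to the log above, the shared directory would be the `include/crt` directory of the `cuda_nvcc` store path, which suggests checking whether that directory is among the include paths the CUDA toolchain configuration knows about.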

samuela commented 9 months ago

I have worked around this by adding the nvcc headers to the build environment, but I'll leave the issue open because the error suggests a bug in the Bazel rule definition.

I'm now getting a different error:

ERROR: /build/output/external/xla/xla/service/gpu/BUILD:1449:23: Compiling xla/service/gpu/cub_sort_kernel.cu.cc failed: (Exit 1): crosstool_wrapper_driver_is_not_gcc failed: error executing command (from target @xla//xla/service/gpu:cub_sort_kernel_u64_b64) external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc -MD -MF bazel-out/k8-opt/bin/external/xla/xla/service/gpu/_objs/cub_sort_kernel_u64_b64/cub_sort_kernel.cu.pic.d ... (remaining 166 arguments skipped)
/build/output/execroot/__main__/external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc:44: DeprecationWarning: 'pipes' is deprecated and slated for removal in Python 3.13
  import pipes
In file included from external/xla/xla/service/gpu/cub_sort_kernel.cu.cc:22:
external/xla/xla/service/gpu/gpu_prim_cuda.h:20:10: fatal error: cub/block/block_load.cuh: No such file or directory
   20 | #include "cub/block/block_load.cuh"
      |          ^~~~~~~~~~~~~~~~~~~~~~~~~~
compilation terminated.
Target //jaxlib/tools:build_wheel failed to build

But AFAIK CUB is not mentioned as a build dependency anywhere in the jax/jaxlib documentation. Is this intended? Perhaps my build configuration is flawed somehow?
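One way to narrow this down is to check which include directories in the build environment actually contain `cub/block/block_load.cuh`. The sketch below is only an illustration: `find_header` is a hypothetical helper, and the directories passed to it are whatever the build environment exposes. It also rests on the assumption that CUB lives in a different include directory from the nvcc headers, which may be the case when the CUDA toolkit is split into separate packages.

```python
from pathlib import Path

# Hypothetical helper: return the include directories (as given) that
# contain the header at the given relative path.
def find_header(relative_header, include_dirs):
    return [d for d in include_dirs if (Path(d) / relative_header).is_file()]
```

For example, `find_header("cub/block/block_load.cuh", dirs)` returns an empty list when none of the directories in `dirs` provides CUB, which would match the `No such file or directory` error above.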