google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Error Building Jaxlib v0.4.30 on Jetson Orin #22155

Open jwknaup opened 1 month ago

jwknaup commented 1 month ago

Description

When I attempt to build jaxlib v0.4.30 on a Jetson Orin with the command

python build/build.py --enable_cuda --cuda_path /usr/local/cuda-12.2 --cudnn_path /usr/lib/aarch64-linux-gnu --cuda_version 12.2 --cudnn_version 8

I get the error

ERROR: /home/nvidia/.cache/bazel/_bazel_nvidia/a5643b5cc286b9b13a96818003a4a7dd/external/xla/xla/stream_executor/gpu/BUILD:349:19: Compiling xla/stream_executor/gpu/gpu_timer_kernel_cuda.cu.cc failed: (Exit 4): crosstool_wrapper_driver_is_not_gcc failed: error executing command (from target @xla//xla/stream_executor/gpu:gpu_timer_kernel_cuda)

The full output is below:

python build/build.py --enable_cuda --cuda_path /usr/local/cuda-12.2 --cudnn_path /usr/lib/aarch64-linux-gnu --cuda_version 12.2 --cudnn_version 8 

     _   _  __  __
    | | / \ \ \/ /
 _  | |/ _ \ \  /
| |_| / ___ \/  \
 \___/_/   \/_/\_\

Bazel binary path: ./bazel-6.5.0-linux-arm64
Bazel version: 6.5.0
Python binary path: /home/nvidia/jax/jax_env/bin/python
Python version: 3.10
Use clang: no
MKL-DNN enabled: yes
Target CPU: aarch64
Target CPU features: release
CUDA enabled: yes
CUDA toolkit path: /usr/local/cuda-12.2
CUDNN library path: /usr/lib/aarch64-linux-gnu
CUDA version: 12.2
CUDNN version: 8
NCCL enabled: yes
ROCm enabled: no

Building XLA and installing it in the jaxlib source tree...
./bazel-6.5.0-linux-arm64 run --verbose_failures=true //jaxlib/tools:build_wheel -- --output_path=/home/nvidia/jax/dist --jaxlib_git_hash=f4158ace933482844c145a6b919bf5dc86e084ba --cpu=aarch64
INFO: Options provided by the client:
  Inherited 'common' options: --isatty=0 --terminal_columns=80
INFO: Reading rc options for 'run' from /home/nvidia/jax/.bazelrc:
  Inherited 'common' options: --experimental_repo_remote_exec
INFO: Reading rc options for 'run' from /home/nvidia/jax/.bazelrc:
  Inherited 'build' options: --nocheck_visibility --apple_platform_type=macos --macos_minimum_os=10.14 --announce_rc --define open_source_build=true --spawn_strategy=standalone --enable_platform_specific_config --experimental_cc_shared_library --define=no_aws_support=true --define=no_gcp_support=true --define=no_hdfs_support=true --define=no_kafka_support=true --define=no_ignite_support=true --define=grpc_no_ares=true --define=tsl_link_protobuf=true -c opt --config=short_logs --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir. --@xla//xla/python:enable_gpu=false
INFO: Reading rc options for 'run' from /home/nvidia/jax/.jax_configure.bazelrc:
  Inherited 'build' options: --strategy=Genrule=standalone --action_env CUDA_TOOLKIT_PATH=/usr/local/cuda-12.2 --action_env CUDNN_INSTALL_PATH=/usr/lib/aarch64-linux-gnu --action_env TF_CUDA_PATHS=/usr/local/cuda-12.2,/usr/lib/aarch64-linux-gnu --action_env TF_CUDA_VERSION=12.2 --action_env TF_CUDNN_VERSION=8 --config=mkl_open_source_only --config=cuda --repo_env HERMETIC_PYTHON_VERSION=3.10
INFO: Found applicable config definition build:short_logs in file /home/nvidia/jax/.bazelrc: --output_filter=DONT_MATCH_ANYTHING
INFO: Found applicable config definition build:mkl_open_source_only in file /home/nvidia/jax/.bazelrc: --define=tensorflow_mkldnn_contraction_kernel=1
INFO: Found applicable config definition build:cuda in file /home/nvidia/jax/.bazelrc: --repo_env TF_NEED_CUDA=1 --repo_env TF_NCCL_USE_STUB=1 --action_env TF_CUDA_COMPUTE_CAPABILITIES=sm_50,sm_60,sm_70,sm_80,compute_90 --crosstool_top=@local_config_cuda//crosstool:toolchain --@local_config_cuda//:enable_cuda --@xla//xla/python:enable_gpu=true --@xla//xla/python:jax_cuda_pip_rpaths=true --define=xla_python_enable_gpu=true --linkopt=-Wl,--disable-new-dtags
INFO: Found applicable config definition build:linux in file /home/nvidia/jax/.bazelrc: --config=posix --copt=-Wno-unknown-warning-option --copt=-Wno-stringop-truncation --copt=-Wno-array-parameter
INFO: Found applicable config definition build:posix in file /home/nvidia/jax/.bazelrc: --copt=-fvisibility=hidden --copt=-Wno-sign-compare --cxxopt=-std=c++17 --host_cxxopt=-std=c++17
Loading: 
DEBUG: /home/nvidia/.cache/bazel/_bazel_nvidia/a5643b5cc286b9b13a96818003a4a7dd/external/bazel_tools/tools/cpp/lib_cc_configure.bzl:118:10: 
Auto-Configuration Warning: 'TMP' environment variable is not set, using 'C:\Windows\Temp' as default
DEBUG: /home/nvidia/.cache/bazel/_bazel_nvidia/a5643b5cc286b9b13a96818003a4a7dd/external/bazel_tools/tools/cpp/lib_cc_configure.bzl:118:10: 
Auto-Configuration Warning: 'TMP' environment variable is not set, using 'C:\Windows\Temp' as default
Loading: 
Loading: 2 packages loaded
Analyzing: target //jaxlib/tools:build_wheel (3 packages loaded, 0 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (36 packages loaded, 10 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (36 packages loaded, 10 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (36 packages loaded, 10 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (55 packages loaded, 216 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (57 packages loaded, 546 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (91 packages loaded, 3986 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (123 packages loaded, 4533 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (137 packages loaded, 4966 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (156 packages loaded, 5504 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (159 packages loaded, 7525 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (180 packages loaded, 9740 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (208 packages loaded, 13196 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (219 packages loaded, 16021 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (225 packages loaded, 18175 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (236 packages loaded, 20734 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (238 packages loaded, 22135 targets configured)
INFO: Analyzed target //jaxlib/tools:build_wheel (239 packages loaded, 23156 targets configured).
 checking cached actions
INFO: Found 1 target...
[1 / 5] [Prepa] BazelWorkspaceStatusAction stable-status.txt
[24 / 3,232] Compiling third_party/f2reduce/f2reduce.cpp; 0s local ... (2 actions, 1 running)
[56 / 3,297] Compiling third_party/f2reduce/f2reduce.cpp; 1s local ... (4 actions, 3 running)
[182 / 3,363] Compiling llvm/lib/Demangle/RustDemangle.cpp [for tool]; 0s local ... (4 actions, 3 running)
[792 / 3,598] [Prepa] Executing genrule @local_config_cuda//cuda:cuda-lib; 4s ... (5 actions, 3 running)
[1,699 / 3,610] Executing genrule @local_config_cuda//cuda:cuda-lib; 1s local ... (5 actions, 4 running)
[1,824 / 3,642] Executing genrule @local_config_cuda//cuda:cuda-lib; 2s local ... (6 actions, 5 running)
[1,839 / 3,658] Executing genrule @local_config_cuda//cuda:cuda-lib; 3s local ... (6 actions running)
[1,844 / 3,661] Executing genrule @local_config_cuda//cuda:cuda-lib; 4s local ... (6 actions running)
[1,846 / 3,661] Executing genrule @local_config_cuda//cuda:cuda-lib; 5s local ... (6 actions running)
[1,854 / 3,704] Executing genrule @local_config_cuda//cuda:cuda-lib; 6s local ... (6 actions running)
[1,860 / 3,719] Compiling src/google/protobuf/wire_format_lite.cc [for tool]; 1s local ... (6 actions running)
[1,869 / 3,762] Compiling src/google/protobuf/wire_format_lite.cc [for tool]; 2s local ... (4 actions running)
[1,870 / 3,806] Compiling llvm/utils/TableGen/Attributes.cpp [for tool]; 2s local ... (6 actions running)
[2,561 / 5,511] Compiling llvm/lib/Demangle/ItaniumDemangle.cpp; 3s local ... (4 actions running)
[3,603 / 5,645] Compiling llvm/lib/Demangle/ItaniumDemangle.cpp; 4s local ... (5 actions running)
[3,675 / 5,873] Compiling llvm/lib/Demangle/ItaniumDemangle.cpp; 5s local ... (4 actions running)
[3,707 / 6,267] Compiling llvm/lib/Demangle/ItaniumDemangle.cpp; 6s local ... (6 actions running)
[3,709 / 6,267] Compiling llvm/lib/Demangle/ItaniumDemangle.cpp; 7s local ... (6 actions running)
[3,712 / 6,267] Compiling llvm/utils/TableGen/Basic/CodeGenIntrinsics.cpp [for tool]; 4s local ... (6 actions running)
[3,715 / 6,417] Compiling llvm/utils/TableGen/Basic/CodeGenIntrinsics.cpp [for tool]; 5s local ... (6 actions running)
[3,723 / 6,923] Compiling xla/service/gpu/stream_executor_util_kernel.cu.cc; 4s local ... (5 actions running)
[3,725 / 7,584] Compiling xla/service/gpu/stream_executor_util_kernel.cu.cc; 5s local ... (6 actions running)
[3,726 / 7,584] Compiling xla/service/gpu/stream_executor_util_kernel.cu.cc; 6s local ... (6 actions running)
[3,731 / 7,584] Compiling xla/service/gpu/stream_executor_util_kernel.cu.cc; 7s local ... (6 actions running)
[3,735 / 7,584] Compiling xla/service/gpu/stream_executor_util_kernel.cu.cc; 8s local ... (6 actions running)
[3,736 / 7,584] Compiling xla/service/gpu/stream_executor_util_kernel.cu.cc; 9s local ... (6 actions, 5 running)
[3,747 / 8,596] Compiling xla/service/gpu/stream_executor_util_kernel.cu.cc; 11s local ... (6 actions, 5 running)
[3,754 / 8,683] Compiling llvm/utils/TableGen/DirectiveEmitter.cpp [for tool]; 5s local ... (6 actions, 5 running)
.
.
.
[9,790 / 14,610] Compiling xla/service/gpu/kernels/topk_kernel_bfloat16.cu.cc; 468s local ... (6 actions running)
ERROR: /home/nvidia/.cache/bazel/_bazel_nvidia/a5643b5cc286b9b13a96818003a4a7dd/external/xla/xla/stream_executor/gpu/BUILD:349:19: Compiling xla/stream_executor/gpu/gpu_timer_kernel_cuda.cu.cc failed: (Exit 4): crosstool_wrapper_driver_is_not_gcc failed: error executing command (from target @xla//xla/stream_executor/gpu:gpu_timer_kernel_cuda) 
  (cd /home/nvidia/.cache/bazel/_bazel_nvidia/a5643b5cc286b9b13a96818003a4a7dd/execroot/__main__ && \
  exec env - \
    CUDA_TOOLKIT_PATH=/usr/local/cuda-12.2 \
    CUDNN_INSTALL_PATH=/usr/lib/aarch64-linux-gnu \
    LD_LIBRARY_PATH=/home/nvidia/colcon_ws/install/zed_components/lib:/home/nvidia/colcon_ws/install/zed_interfaces/lib:/home/nvidia/colcon_ws/install/vesc_driver/lib:/home/nvidia/colcon_ws/install/vesc_ackermann/lib:/home/nvidia/colcon_ws/install/custom_ackermann/lib:/home/nvidia/colcon_ws/install/vesc_msgs/lib:/home/nvidia/colcon_ws/install/teleop_tools_msgs/lib:/opt/ros/humble/opt/rviz_ogre_vendor/lib:/opt/ros/humble/lib/aarch64-linux-gnu:/opt/ros/humble/lib:/usr/local/cuda-12.2/lib64: \
    PATH=/home/nvidia/jax/jax_env/bin:/opt/ros/humble/bin:/usr/local/cuda-12.2/bin:/home/nvidia/.local/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin:/snap/bin \
    PWD=/proc/self/cwd \
    TF_CUDA_COMPUTE_CAPABILITIES=sm_50,sm_60,sm_70,sm_80,compute_90 \
    TF_CUDA_PATHS=/usr/local/cuda-12.2,/usr/lib/aarch64-linux-gnu \
    TF_CUDA_VERSION=12.2 \
    TF_CUDNN_VERSION=8 \
  external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc -MD -MF bazel-out/aarch64-opt/bin/external/xla/xla/stream_executor/gpu/_objs/gpu_timer_kernel_cuda/gpu_timer_kernel_cuda.cu.pic.d '-frandom-seed=bazel-out/aarch64-opt/bin/external/xla/xla/stream_executor/gpu/_objs/gpu_timer_kernel_cuda/gpu_timer_kernel_cuda.cu.pic.o' '-DEIGEN_MAX_ALIGN_BYTES=64' -DEIGEN_ALLOW_UNALIGNED_SCALARS '-DEIGEN_USE_AVX512_GEMM_KERNELS=0' -DHAVE_SYS_UIO_H -DTF_USE_SNAPPY '-DGOOGLE_CUDA=1' '-DBAZEL_CURRENT_REPOSITORY="xla"' -iquote external/xla -iquote bazel-out/aarch64-opt/bin/external/xla -iquote external/tsl -iquote bazel-out/aarch64-opt/bin/external/tsl -iquote external/eigen_archive -iquote bazel-out/aarch64-opt/bin/external/eigen_archive -iquote external/ml_dtypes -iquote bazel-out/aarch64-opt/bin/external/ml_dtypes -iquote external/com_google_absl -iquote bazel-out/aarch64-opt/bin/external/com_google_absl -iquote external/nsync -iquote bazel-out/aarch64-opt/bin/external/nsync -iquote external/local_config_cuda -iquote bazel-out/aarch64-opt/bin/external/local_config_cuda -iquote external/com_google_protobuf -iquote bazel-out/aarch64-opt/bin/external/com_google_protobuf -iquote external/zlib -iquote bazel-out/aarch64-opt/bin/external/zlib -iquote external/double_conversion -iquote bazel-out/aarch64-opt/bin/external/double_conversion -iquote external/snappy -iquote bazel-out/aarch64-opt/bin/external/snappy -iquote external/com_googlesource_code_re2 -iquote bazel-out/aarch64-opt/bin/external/com_googlesource_code_re2 -Ibazel-out/aarch64-opt/bin/external/ml_dtypes/_virtual_includes/float8 -Ibazel-out/aarch64-opt/bin/external/ml_dtypes/_virtual_includes/intn -Ibazel-out/aarch64-opt/bin/external/local_config_cuda/cuda/_virtual_includes/cuda_headers_virtual -isystem external/eigen_archive -isystem bazel-out/aarch64-opt/bin/external/eigen_archive -isystem external/eigen_archive/mkl_include -isystem bazel-out/aarch64-opt/bin/external/eigen_archive/mkl_include 
-isystem external/ml_dtypes -isystem bazel-out/aarch64-opt/bin/external/ml_dtypes -isystem external/ml_dtypes/ml_dtypes -isystem bazel-out/aarch64-opt/bin/external/ml_dtypes/ml_dtypes -isystem external/nsync/public -isystem bazel-out/aarch64-opt/bin/external/nsync/public -isystem external/local_config_cuda/cuda -isystem bazel-out/aarch64-opt/bin/external/local_config_cuda/cuda -isystem external/local_config_cuda/cuda/cuda/include -isystem bazel-out/aarch64-opt/bin/external/local_config_cuda/cuda/cuda/include -isystem external/com_google_protobuf/src -isystem bazel-out/aarch64-opt/bin/external/com_google_protobuf/src -isystem external/zlib -isystem bazel-out/aarch64-opt/bin/external/zlib -Wno-builtin-macro-redefined '-D__DATE__="redacted"' '-D__TIMESTAMP__="redacted"' '-D__TIME__="redacted"' -fPIC -U_FORTIFY_SOURCE '-D_FORTIFY_SOURCE=1' -fstack-protector -Wall -fno-omit-frame-pointer -no-canonical-prefixes -fno-canonical-system-headers -DNDEBUG -g0 -O2 -ffunction-sections -fdata-sections '-fvisibility=hidden' -Wno-sign-compare -Wno-unknown-warning-option -Wno-stringop-truncation -Wno-array-parameter '-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.' '-std=c++17' -x cuda '-DGOOGLE_CUDA=1' '--cuda-gpu-arch=sm_50' '--cuda-gpu-arch=sm_60' '--cuda-gpu-arch=sm_70' '--cuda-gpu-arch=sm_80' '--cuda-include-ptx=sm_90' '--cuda-gpu-arch=sm_90' '-Xcuda-fatbinary=--compress-all' '-nvcc_options=expt-relaxed-constexpr' -c external/xla/xla/stream_executor/gpu/gpu_timer_kernel_cuda.cu.cc -o bazel-out/aarch64-opt/bin/external/xla/xla/stream_executor/gpu/_objs/gpu_timer_kernel_cuda/gpu_timer_kernel_cuda.cu.pic.o)
# Configuration: 39ab92a7342907b6f4e73040033499a7cf871005fd5792c463496975725a1dc4
# Execution platform: @local_execution_config_platform//:platform
In file included from external/xla/xla/stream_executor/stream_executor.h:2,
                 from external/xla/xla/stream_executor/stream_executor_common.h:29,
                 from external/xla/xla/stream_executor/gpu/gpu_executor.h:60,
                 from external/xla/xla/stream_executor/gpu/gpu_timer_kernel_cuda.cu.cc:19:
external/com_google_absl/absl/log/log.h:199: warning: "LOG" redefined
  199 | #define LOG(severity) ABSL_LOG_INTERNAL_LOG_IMPL(_##severity)
      | 
In file included from external/tsl/tsl/platform/logging.h:26,
                 from external/xla/xla/stream_executor/device_memory.h:33,
                 from external/xla/xla/stream_executor/blas.h:35,
                 from external/xla/xla/stream_executor/gpu/gpu_executor.h:44,
                 from external/xla/xla/stream_executor/gpu/gpu_timer_kernel_cuda.cu.cc:19:
external/tsl/tsl/platform/default/logging.h:165: note: this is the location of the previous definition
  165 | #define LOG(severity) _TF_LOG_##severity
      | 
In file included from external/xla/xla/stream_executor/stream_executor.h:2,
                 from external/xla/xla/stream_executor/stream_executor_common.h:29,
                 from external/xla/xla/stream_executor/gpu/gpu_executor.h:60,
                 from external/xla/xla/stream_executor/gpu/gpu_timer_kernel_cuda.cu.cc:19:
external/com_google_absl/absl/log/log.h:237: warning: "LOG_EVERY_N" redefined
  237 | #define LOG_EVERY_N(severity, n) \
      | 
In file included from external/tsl/tsl/platform/logging.h:26,
                 from external/xla/xla/stream_executor/device_memory.h:33,
                 from external/xla/xla/stream_executor/blas.h:35,
                 from external/xla/xla/stream_executor/gpu/gpu_executor.h:44,
                 from external/xla/xla/stream_executor/gpu/gpu_timer_kernel_cuda.cu.cc:19:
external/tsl/tsl/platform/default/logging.h:278: note: this is the location of the previous definition
  278 | #define LOG_EVERY_N(severity, n)                       \
      | 
In file included from external/xla/xla/stream_executor/stream_executor.h:2,
                 from external/xla/xla/stream_executor/stream_executor_common.h:29,
                 from external/xla/xla/stream_executor/gpu/gpu_executor.h:60,
                 from external/xla/xla/stream_executor/gpu/gpu_timer_kernel_cuda.cu.cc:19:
external/com_google_absl/absl/log/log.h:245: warning: "LOG_FIRST_N" redefined
  245 | #define LOG_FIRST_N(severity, n) \
      | 
In file included from external/tsl/tsl/platform/logging.h:26,
                 from external/xla/xla/stream_executor/device_memory.h:33,
                 from external/xla/xla/stream_executor/blas.h:35,
                 from external/xla/xla/stream_executor/gpu/gpu_executor.h:44,
                 from external/xla/xla/stream_executor/gpu/gpu_timer_kernel_cuda.cu.cc:19:
external/tsl/tsl/platform/default/logging.h:284: note: this is the location of the previous definition
  284 | #define LOG_FIRST_N(severity, n)                       \
      | 
In file included from external/xla/xla/stream_executor/stream_executor.h:2,
                 from external/xla/xla/stream_executor/stream_executor_common.h:29,
                 from external/xla/xla/stream_executor/gpu/gpu_executor.h:60,
                 from external/xla/xla/stream_executor/gpu/gpu_timer_kernel_cuda.cu.cc:19:
external/com_google_absl/absl/log/log.h:253: warning: "LOG_EVERY_POW_2" redefined
  253 | #define LOG_EVERY_POW_2(severity) \
      | 
In file included from external/tsl/tsl/platform/logging.h:26,
                 from external/xla/xla/stream_executor/device_memory.h:33,
                 from external/xla/xla/stream_executor/blas.h:35,
                 from external/xla/xla/stream_executor/gpu/gpu_executor.h:44,
                 from external/xla/xla/stream_executor/gpu/gpu_timer_kernel_cuda.cu.cc:19:
external/tsl/tsl/platform/default/logging.h:290: note: this is the location of the previous definition
  290 | #define LOG_EVERY_POW_2(severity)                         \
      | 
In file included from external/xla/xla/stream_executor/stream_executor.h:2,
                 from external/xla/xla/stream_executor/stream_executor_common.h:29,
                 from external/xla/xla/stream_executor/gpu/gpu_executor.h:60,
                 from external/xla/xla/stream_executor/gpu/gpu_timer_kernel_cuda.cu.cc:19:
external/com_google_absl/absl/log/log.h:265: warning: "LOG_EVERY_N_SEC" redefined
  .
.
.
  536 | #define DCHECK_GT(x, y) _TF_DCHECK_NOP(x, y)
      | 
external/com_google_absl/absl/status/status.h(796): warning #2810-D: ignoring return value type with "nodiscard" attribute
      *this = std::move(new_status);
            ^

Remark: The warnings can be suppressed with "-diag-suppress <warning-number>"

external/com_google_absl/absl/status/internal/statusor_internal.h(240): warning #2810-D: ignoring return value type with "nodiscard" attribute
        status_ = OkStatus();
                ^

external/com_google_absl/absl/status/internal/statusor_internal.h(247): warning #2810-D: ignoring return value type with "nodiscard" attribute
      status_ = static_cast<absl::Status>(std::forward<U>(v));
              ^

/usr/lib/gcc/aarch64-linux-gnu/11/include/arm_neon.h(38): error: identifier "__Int8x8_t" is undefined
  typedef __Int8x8_t int8x8_t;
          ^

/usr/lib/gcc/aarch64-linux-gnu/11/include/arm_neon.h(39): error: identifier "__Int16x4_t" is undefined
  typedef __Int16x4_t int16x4_t;
          ^

/usr/lib/gcc/aarch64-linux-gnu/11/include/arm_neon.h(40): error: identifier "__Int32x2_t" is undefined
  typedef __Int32x2_t int32x2_t;
          ^

/usr/lib/gcc/aarch64-linux-gnu/11/include/arm_neon.h(41): error: identifier "__Int64x1_t" is undefined
  typedef __Int64x1_t int64x1_t;
          ^

.
.
.
/usr/lib/gcc/aarch64-linux-gnu/11/include/arm_neon.h(1308): error: identifier "__builtin_aarch64_addhn2v4si" is undefined
    return (uint16x8_t) __builtin_aarch64_addhn2v4si ((int16x4_t) __a,
                        ^

/usr/lib/gcc/aarch64-linux-gnu/11/include/arm_neon.h(1317): error: identifier "__builtin_aarch64_addhn2v2di" is undefined
    return (uint32x4_t) __builtin_aarch64_addhn2v2di ((int32x2_t) __a,
                        ^

Error limit reached.
100 errors detected in the compilation of "external/xla/xla/stream_executor/gpu/gpu_timer_kernel_cuda.cu.cc".
Compilation terminated.
Target //jaxlib/tools:build_wheel failed to build
INFO: Elapsed time: 1536.946s, Critical Path: 471.27s
INFO: 3903 processes: 22 internal, 3881 local.
FAILED: Build did NOT complete successfully
ERROR: Build failed. Not running target
Traceback (most recent call last):
  File "/home/nvidia/jax/build/build.py", line 726, in <module>
    main()
  File "/home/nvidia/jax/build/build.py", line 692, in main
    shell(build_cpu_wheel_command)
  File "/home/nvidia/jax/build/build.py", line 45, in shell
    output = subprocess.check_output(cmd)
  File "/usr/lib/python3.10/subprocess.py", line 421, in check_output
    return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
  File "/usr/lib/python3.10/subprocess.py", line 526, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['./bazel-6.5.0-linux-arm64', 'run', '--verbose_failures=true', '//jaxlib/tools:build_wheel', '--', '--output_path=/home/nvidia/jax/dist', '--jaxlib_git_hash=f4158ace933482844c145a6b919bf5dc86e084ba', '--cpu=aarch64']' returned non-zero exit status 1.
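The decisive lines in this log are the Bazel "ERROR: ... Compiling ... .cu.cc failed" action line and the flood of "arm_neon.h(...): error: identifier ... is undefined" diagnostics after it. Together they show GCC's NEON header being parsed while a CUDA source is compiled (note "-x cuda" in the failing command). To help compare another log against this report, here is a minimal sketch of a hypothetical triage helper, triage_build_log. It is not part of JAX, XLA, or Bazel. It extracts those two signals with regular expressions that are assumptions fitted to the line formats shown above, so they may need adjusting for other toolchains.

```python
import re
from typing import List, Optional

# Bazel action-failure line (format assumed from the log above), e.g.
# "ERROR: .../BUILD:349:19: Compiling xla/.../gpu_timer_kernel_cuda.cu.cc failed: (Exit 4): ..."
FAILED_COMPILE = re.compile(r"^ERROR: .*?: Compiling (\S+) failed: \(Exit (\d+)\)")

# nvcc-style diagnostic pointing into GCC's NEON header, e.g.
# '/usr/lib/gcc/aarch64-linux-gnu/11/include/arm_neon.h(38): error: identifier "__Int8x8_t" is undefined'
NEON_UNDEFINED = re.compile(r'arm_neon\.h\(\d+\): error: identifier "(\w+)" is undefined')


def triage_build_log(log_text: str) -> dict:
    """Hypothetical helper: summarize the first failing compile action and any arm_neon.h errors."""
    failed_source: Optional[str] = None
    exit_code: Optional[int] = None
    neon_identifiers: List[str] = []
    for line in log_text.splitlines():
        failed = FAILED_COMPILE.search(line)
        if failed and failed_source is None:
            # Keep only the first failing action; later ones are usually fallout.
            failed_source, exit_code = failed.group(1), int(failed.group(2))
        neon = NEON_UNDEFINED.search(line)
        if neon:
            neon_identifiers.append(neon.group(1))
    # In this build, .cu.cc files are compiled as CUDA ("-x cuda" in the failing command).
    is_cuda_source = failed_source is not None and failed_source.endswith(".cu.cc")
    return {
        "failed_source": failed_source,
        "exit_code": exit_code,
        "neon_errors": len(neon_identifiers),
        "first_neon_identifier": neon_identifiers[0] if neon_identifiers else None,
        "looks_like_neon_under_cuda": is_cuda_source and bool(neon_identifiers),
    }
```

Run on the full log from this report, the helper should flag xla/stream_executor/gpu/gpu_timer_kernel_cuda.cu.cc with exit code 4 and list the undefined NEON identifiers. A positive result only means a log matches this failure pattern. It does not prove that the absl patch discussed in the comments will fix that particular build.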

System info (python version, jaxlib version, accelerator, etc.)

Bazel binary path: ./bazel-6.5.0-linux-arm64
Bazel version: 6.5.0
Python binary path: /home/nvidia/jax/jax_env/bin/python
Python version: 3.10
Use clang: no
MKL-DNN enabled: yes
Target CPU: aarch64
Target CPU features: release
CUDA enabled: yes
CUDA toolkit path: /usr/local/cuda-12.2
CUDNN library path: /usr/lib/aarch64-linux-gnu
CUDA version: 12.2
CUDNN version: 8
NCCL enabled: yes
ROCm enabled: no

I am using a Jetson Orin Nano (Tegra GPU) running JetPack 6 with CUDA 12.2 and cuDNN 8.
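A build like this takes about 25 minutes before it fails (the log reports 1536.946s elapsed), so it helps to save the complete output to a file for reports like this one. The traceback above shows that build/build.py runs Bazel through subprocess.check_output, which raises CalledProcessError when the build fails. Below is a minimal sketch of a hypothetical wrapper, run_and_log. It is an assumption for illustration and not part of JAX's build scripts. It runs a command with stderr merged into stdout, echoes each line to the terminal, copies it to a log file, and returns the exit code instead of raising.

```python
import subprocess
from typing import Sequence


def run_and_log(cmd: Sequence[str], log_path: str) -> int:
    """Hypothetical helper: run cmd, tee its combined output to log_path, return the exit code."""
    with open(log_path, "w", encoding="utf-8") as log:
        # Merge stderr into stdout so progress, warnings, and ERROR lines land in one file.
        with subprocess.Popen(
            list(cmd),
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            encoding="utf-8",
            errors="replace",  # tolerate non-UTF-8 bytes in compiler output
        ) as proc:
            # Echo each line to the terminal and copy it to the log as it arrives.
            for line in proc.stdout:
                print(line, end="")
                log.write(line)
        # Leaving the Popen context waits for the process, so returncode is set.
        # Unlike subprocess.check_output, nothing is raised on failure; the log is kept.
        return proc.returncode


# Example (flags copied from the report above):
# run_and_log(["python", "build/build.py", "--enable_cuda",
#              "--cuda_path", "/usr/local/cuda-12.2",
#              "--cudnn_path", "/usr/lib/aarch64-linux-gnu",
#              "--cuda_version", "12.2", "--cudnn_version", "8"],
#             "jaxlib_build.log")
```

Merging the two streams keeps lines in roughly the order they appear on the terminal. Output buffering inside the child processes can still reorder some of them.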

kylestach commented 1 month ago

I'm seeing the same issue on a Jetson AGX Orin with JetPack 6. I think it should be fixable with the patch from https://github.com/tensorflow/tensorflow/issues/62490 (applied to xla -> tsl -> absl). I'm building with the patch now and will comment again if it fixes the issue.

kylestach commented 1 month ago

FYI, the above workaround works as a one-off fix. It would be nice to get the fix upstreamed once https://github.com/abseil/abseil-cpp/issues/1665 is resolved.

jwknaup commented 1 month ago

Thanks, Kyle! It worked for me as well.

JuanCarlos-TiqueRangel commented 1 month ago

@kylestach Hi, I am facing the same issue with a Jetson Orin Nano. Did you figure it out?

adamjstewart commented 2 weeks ago

Note that this has been fixed upstream in absl: https://github.com/abseil/abseil-cpp/pull/1732. Maybe we can bump the vendored copy of absl before the next release?

nouiz commented 2 weeks ago

There is already a PR to add this patch to XLA: https://github.com/openxla/xla/pull/15687. It should be fixed soon.

adamjstewart commented 2 weeks ago

Did you link the wrong PR?

nouiz commented 2 weeks ago

Updated my previous comment.

adamjstewart commented 2 weeks ago

In that case, can we bump the vendored copy of XLA? Or just stop vendoring things so it's no longer a JAX bug?

nouiz commented 2 weeks ago

JAX seems to update its XLA commit every day or so. I checked, and current JAX upstream points to an XLA commit that includes the patch. So installing today's or tomorrow's JAX nightly, or building from upstream, should get you the fix.

If you try, can you confirm the fix works for you?