jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Building on linux with ppc64le CPU #4493

Open f0uriest opened 4 years ago

f0uriest commented 4 years ago

I'm trying to build jax on a cluster that uses IBM power9 processors (it's a sister cluster to Summit at ORNL). It seems to be failing when trying to build XLA, which is strange because I've been able to install tensorflow just fine. The full output log is here: https://gist.github.com/f0uriest/5f04e2ed9916bb750a9ea679633ac80c

Any ideas? Is there any plan to offer pre-built wheels for the ppc64le architecture?

proutrc commented 3 years ago

@f0uriest a few questions:

What compiler and version are you using? Is it installed in /usr/bin?

f0uriest commented 3 years ago

gcc version 8.3.1 20191121 (Red Hat 8.3.1-5) (GCC) installed in /usr/bin

proutrc commented 3 years ago

Thanks, it looks like we needed to add the patch referenced in this issue: https://github.com/tensorflow/tensorflow/issues/33975

Otherwise, TF does not pick up the proper inclusion paths for the compiler outside /usr/bin.

Specifically, we added the patch to bazel-jax/external/org_tensorflow/third_party/gpus/cuda_configure.bzl, which was in the build cache.

That got us past the inclusion errors.

proutrc commented 3 years ago

@hawkinsp

Just to follow up more - we made it further, but we seem to be hitting similar issues with other packages that assume the location of GCC. My guess is that this will happen at various steps, the first being Tensorflow and now llvm-project.

Recent error within external/llvm-project/mlir:

[4,727 / 5,554] checking cached actions
ERROR: /gpfs/alpine/stf007/scratch/rprout/bazel-build-cache/user-root/b2ebe10a0ad0f6175e81a930563cb9d3/external/llvm-project/mlir/BUILD:173:18: TdGenerate external/llvm-project/mlir/include/mlir/IR/BuiltinLocationAttributes.h.inc failed: (Exit 1): mlir-tblgen failed: error executing command 
  (cd /gpfs/alpine/stf007/scratch/rprout/bazel-build-cache/user-root/b2ebe10a0ad0f6175e81a930563cb9d3/execroot/__main__ && \
  exec env - \
  bazel-out/ppc-opt-exec-50AE0418/bin/external/llvm-project/mlir/mlir-tblgen --gen-attrdef-decls external/llvm-project/mlir/include/mlir/IR/BuiltinLocationAttributes.td -I external/llvm-project/mlir/include -I bazel-out/ppc-opt/bin/external/llvm-project/mlir/include -I external/llvm-project/ -I bazel-out/ppc-opt/bin/external/llvm-project/ -I external/llvm-project/mlir/include/mlir/IR -I bazel-out/ppc-opt/bin/external/llvm-project/mlir/include/mlir/IR -o bazel-out/ppc-opt/bin/external/llvm-project/mlir/include/mlir/IR/BuiltinLocationAttributes.h.inc)
Execution platform: @local_execution_config_platform//:platform
bazel-out/ppc-opt-exec-50AE0418/bin/external/llvm-project/mlir/mlir-tblgen: /lib64/libstdc++.so.6: version `CXXABI_1.3.8' not found (required by bazel-out/ppc-opt-exec-50AE0418/bin/external/llvm-project/mlir/mlir-tblgen)
bazel-out/ppc-opt-exec-50AE0418/bin/external/llvm-project/mlir/mlir-tblgen: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.20' not found (required by bazel-out/ppc-opt-exec-50AE0418/bin/external/llvm-project/mlir/mlir-tblgen)
bazel-out/ppc-opt-exec-50AE0418/bin/external/llvm-project/mlir/mlir-tblgen: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found (required by bazel-out/ppc-opt-exec-50AE0418/bin/external/llvm-project/mlir/mlir-tblgen)
bazel-out/ppc-opt-exec-50AE0418/bin/external/llvm-project/mlir/mlir-tblgen: /lib64/libstdc++.so.6: version `CXXABI_1.3.9' not found (required by bazel-out/ppc-opt-exec-50AE0418/bin/external/llvm-project/mlir/mlir-tblgen)
Target //build:build_wheel failed to build
INFO: Elapsed time: 70.123s, Critical Path: 16.09s
INFO: 68 processes: 55 internal, 13 local.
FAILED: Build did NOT complete successfully
ERROR: Build failed. Not running target
FAILED: Build did NOT complete successfully

Note, it appears our env looks empty (I don't see LD_LIBRARY_PATH or anything). Not sure why this is happening all of a sudden.

So, it is correctly complaining about the system provided GCC (/usr/bin). But, we are trying to use the non-system GCC in /sw/summit/gcc/8.1.1/bin/gcc and /sw/summit/gcc/8.1.1/lib64.

This works up until various points with external packages. We were able to fix the similar Tensorflow issue with the patch we referenced, but are now hitting it with llvm-project.
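To check which libstdc++ actually satisfies the versions named in the loader errors above, something like the following could help. It is a minimal sketch, not something from this thread: it pulls the version tags out of the error text and scans a libstdc++ binary for the tags it carries (roughly what `strings libstdc++.so.6 | grep GLIBCXX` does). The function names are made up for illustration.

```python
import re
from pathlib import Path

# Matches loader messages such as:
#   /lib64/libstdc++.so.6: version `GLIBCXX_3.4.20' not found (required by ...)
_NOT_FOUND = re.compile(r"version `([A-Z]+_[0-9][0-9.]*)' not found")

# Matches the version tags embedded in the shared library itself.
_TAG = re.compile(rb"(?:GLIBCXX|CXXABI)_[0-9]+(?:\.[0-9]+)*")


def required_versions(log_text):
    """Return the set of symbol versions the loader reported as missing."""
    return set(_NOT_FOUND.findall(log_text))


def provided_versions(library_path):
    """Return the GLIBCXX/CXXABI tags found in a libstdc++ binary."""
    data = Path(library_path).read_bytes()
    return {tag.decode("ascii") for tag in _TAG.findall(data)}


def still_missing(log_text, library_path):
    """Versions named in the log that the given library does not contain."""
    return sorted(required_versions(log_text) - provided_versions(library_path))
```

Running `still_missing(log, "/lib64/libstdc++.so.6")` and then the same call against the libstdc++ under /sw/summit/gcc/8.1.1/lib64 (the exact file name there is an assumption) should show whether the module-provided library would satisfy mlir-tblgen if the loader could find it.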

For HPC systems that have the proper GCC in /usr/bin everything seems to go smoothly (like the PPPL folks). That seems like a bad assumption by all these packages though, from an HPC perspective. I realize this is not all within the JAX realm, but it is worth noting.

Is there a way to make our toolchain known to these other packages? And/or is there a change needed in these various packages' BUILD files? I could be thinking about this incorrectly, as my familiarity with Bazel is limited.

If by chance you see a way to help, please do!

hawkinsp commented 3 years ago

I think what that means is that mlir-tblgen was built by Bazel, but the result does not run. This indicates either something misconfigured with the toolchain or a dynamic linker path problem. It's probably the former.

One thing I might try is a non-CUDA build and see if it works. CUDA requires that TF do some trickery with toolchains to support the CUDA compilers.

Another option might be to cross-compile, as I mentioned above. The main difficulty there will be that you will need to target a suitable glibc version.

Yet another option might be to compile in something like a docker container that has a more standard layout. I'm not sure if that's feasible for you.

mrorro commented 3 years ago

I was able to build from main on Traverse without having to do any toolchain modifications, though I did have to manually specify the cuda/cudnn paths.

python build/build.py --enable_cuda --cuda_path /usr/local/cuda-11.3 --cuda_version=11.3 --cudnn_version=8.2.0 --cudnn_path /usr/local/cudnn/cuda-11.3/8.2.0 --noenable_mkl_dnn --cuda_compute_capabilities 7.0 --bazel_path /usr/bin/bazel --target_cpu=ppc

Thanks so much for your help with this! Hope the other ppc users can also get it working

The same worked for me with bazel 3.7.2, whereas I got errors with bazel 3.1.0. Thank you @hawkinsp and all.

proutrc commented 3 years ago

We also got it to work without the toolchain modifications on a smaller test ppc system of ours, which has RHEL8 and GCC 8.3 in /usr/bin.

On Summit, we don't yet have RHEL8 and have GCC 4.8 in /usr/bin. We just can't seem to get the GCC 8.x compiler provided by our module system (i.e., /sw/summit/gcc/8.1.1) known to all the Bazel build recipes.

boegel commented 3 years ago

@proutrc For jaxlib 0.1.70, we have a fix for the mlir-tblgen issue, see this patch.

However for jaxlib 0.1.71, I'm running into the same problem again, and I'm not sure how to fix it there because things have shifted around enough to make the patch we came up with useless... I've opened a dedicated issue on this: #7842

mshafiei commented 2 years ago

I'm getting this error for the most recent tag (jax-v0.2.28). More specifically, the error is as follows,

ERROR: /home/mohammad/.cache/bazel/_bazel_mohammad/5a0cafedcffcc5a6733cc68df657e72e/external/llvm-project/mlir/BUILD.bazel:3356:11: Compiling mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp failed: undeclared inclusion(s) in rule '@llvm-project//mlir:GPUToVulkanTransforms':
this rule is missing dependency declarations for the following files included by 'mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp':
  'bazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/GPUBaseIncGen/mlir/Dialect/GPU/GPUOpsDialect.h.inc'
Target //build:build_wheel failed to build
INFO: Elapsed time: 464.234s, Critical Path: 90.53s
INFO: 1117 processes: 176 internal, 941 local.
FAILED: Build did NOT complete successfully
ERROR: Build failed. Not running target
FAILED: Build did NOT complete successfully

hawkinsp commented 2 years ago

@mshafiei I don't see that error building at that tag. One thing that's worth trying: try running bazel clean using the copy of bazel being used for the build and then try the build again. The LLVM Bazel rules sometimes confuse Bazel.
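The clean-then-rebuild suggestion can be scripted so that the clean always uses the same Bazel binary as the build. This is a rough sketch only: the helper names are invented here, and the only build.py flag it relies on is `--bazel_path`, which appears in the build commands earlier in this thread.

```python
import subprocess


def clean_and_rebuild_commands(bazel_path, build_args):
    """Return the two commands to run, in order: clean first, then the build.

    bazel_path is passed both to `bazel clean` and to build.py so that the
    same Bazel binary handles both steps.
    """
    clean = [bazel_path, "clean"]
    build = ["python", "build/build.py", "--bazel_path", bazel_path, *build_args]
    return [clean, build]


def run_all(commands, cwd="."):
    """Run each command in turn from the source root, stopping at the first failure."""
    for cmd in commands:
        subprocess.run(cmd, cwd=cwd, check=True)
```

For example, `run_all(clean_and_rebuild_commands("/usr/bin/bazel", ["--enable_cuda"]), cwd="jax")` would clean and then rebuild from a checkout in `jax` (the path and flag here are placeholders).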

mshafiei commented 2 years ago

@hawkinsp It worked. Thanks

adamdejl commented 2 years ago

I recently went through building and installing jaxlib 0.1.76 on MIT IBM Satori, a ppc64le machine, and I thought I would share my steps in case anyone happens to find them helpful:

  1. I installed bazel using a community-created binary for ppc64le. In my case, I had to use Bazel 4.2.2, as newer versions were incompatible with the system libraries on IBM Satori (this also meant I could only build an older version of jaxlib, as the more recent ones seem to require Bazel >=5.0)
  2. I set up and activated a dedicated Anaconda environment with cudatoolkit and loaded a system CUDA 11.1 module
  3. I downloaded the source for the desired release (jaxlib-v0.1.76 in my case), unpacked it, and ran python build/build.py --enable_cuda --cuda_path /software/cuda/11.1 --cudnn_path /nobackup/users/<username>/anaconda3/envs/<environment name> in its root directory. The CUDA and cuDNN paths specified in the build command need to be adapted for the specific system.
  4. I installed the resulting .whl package (found in the /dist subdirectory once the build completed successfully) using pip.
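The last step above (picking up the built wheel and installing it) could be sketched roughly as follows. The helper names are hypothetical, and the code assumes the wheel lands in the dist subdirectory of the source tree, as described in step 4.

```python
import sys
from pathlib import Path


def newest_wheel(dist_dir):
    """Return the most recently modified .whl file in dist_dir, or None."""
    wheels = sorted(Path(dist_dir).glob("*.whl"), key=lambda p: p.stat().st_mtime)
    return wheels[-1] if wheels else None


def pip_install_command(wheel_path):
    """Build the pip command that installs the wheel into the current interpreter."""
    return [sys.executable, "-m", "pip", "install", str(wheel_path)]
```

Using `sys.executable -m pip` rather than a bare `pip` keeps the install inside whichever environment is currently activated, which matters when a dedicated conda environment is in use.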

The resulting jaxlib installation seems to work well for our purposes, although it can only be used for running on a GPU. Attempting to run on a CPU triggers the error: 'pwr9' is not a recognized processor for this target (ignoring processor).

C-J-Cundy commented 2 years ago

@proutrc @feifzhou Have you had any success on this issue since last year? I am trying to install JAX on the Lassen PPC system and having a lot of difficulties. I am able to get around most of the undeclared dependency issues by making toolchain files as described above and installing following these commands, but end up getting this error from Eigen during compilation.

Is there perhaps a pre-configured/compiled package available for JAX on PPC/Lassen, similar to the opence distributions of torch and other packages? I've been trying unsuccessfully to install JAX for the last several days now.

mrorro commented 2 years ago

@C-J-Cundy did you try the jaxlib anaconda package for linux-ppc64le?

C-J-Cundy commented 2 years ago

@mrorro Yes, when trying that I get the error ImportError: /lib64/ld64.so.2: version `GLIBC_2.22' not found (required by /g/g92/cundy2/.conda/envs/rljax/lib/python3.8/site-packages/jaxlib/xla_extension.so). It seems like the conda version is built against glibc 2.22, which isn't present on my system (and building it myself breaks everything else on the RHEL7 system).
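Before installing a prebuilt package, one way to see whether the system glibc is new enough is to compare version numbers directly. A small sketch, assuming the required version is taken from the ImportError text; the helper names are illustrative only:

```python
import platform


def parse_version(text):
    """Turn '2.22' or 'GLIBC_2.22' into a comparable tuple such as (2, 22)."""
    digits = text.rsplit("_", 1)[-1]
    return tuple(int(part) for part in digits.split("."))


def glibc_satisfies(required, available):
    """True if the available glibc version is at least the required one."""
    return parse_version(available) >= parse_version(required)


def system_glibc():
    """Return the glibc version the interpreter reports, or '' if unknown."""
    name, version = platform.libc_ver()
    return version if name == "glibc" else ""
```

Comparing tuples rather than strings matters here: as plain strings "2.9" would sort after "2.22", while (2, 9) correctly compares as older than (2, 22). A call such as `glibc_satisfies("GLIBC_2.22", system_glibc())` only makes sense when `system_glibc()` returns a non-empty string.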

allen-adastra commented 1 year ago

Looks like Open-CE 1.9 should allow you to use Jax on powerpc platforms.

https://github.com/open-ce/open-ce/issues/549

dpanici commented 1 year ago

@allen-adastra Do you have an example of how to use it? Is it as simple as doing something like conda install -c https://ftp.osuosl.org/pub/open-ce/current/ jax?

bhuntsman commented 4 months ago

Is anyone still working on this? I tried a build using Bazel 7.2.0 built from the dist archive. I have cuda 11.1.1, cuDNN 8.9.7.29, gcc 9.4.0. The system is running Ubuntu 20.04 and is a Power 8 S822LC HPC (Minsky).

I'm trying to build from the git HEAD:

$ TMP=/tmp /usr/bin/python3.9 build/build.py --python_version=3.9 --noenable_mkl_dnn --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=11

     _   _  __  __
    | | / \ \ \/ /
 _  | |/ _ \ \  /
| |_| / ___ \/  \
 \___/_/   \/_/\_\

Bazel binary path: /usr/local/bin/bazel
Bazel version: 7.2.0
Use clang: no
MKL-DNN enabled: no
Target CPU: ppc64le
Target CPU features: release
CUDA enabled: yes
NCCL enabled: yes
ROCm enabled: no

Building XLA and installing it in the jaxlib source tree...
/usr/local/bin/bazel run --verbose_failures=true //jaxlib/tools:build_wheel -- --output_path=/home/build/repos/jax/dist --jaxlib_git_hash=d3bfd32667d7742d2fc208b75d4085e923d5e03f --cpu=ppc64le --skip_gpu_kernels
INFO: Options provided by the client:
  Inherited 'common' options: --isatty=0 --terminal_columns=80
INFO: Reading rc options for 'run' from /home/build/repos/jax/.bazelrc:
  Inherited 'common' options: --experimental_repo_remote_exec
INFO: Reading rc options for 'run' from /home/build/repos/jax/.bazelrc:
  Inherited 'build' options: --nocheck_visibility --apple_platform_type=macos --macos_minimum_os=10.14 --announce_rc --define open_source_build=true --spawn_strategy=standalone --enable_platform_specific_config --experimental_cc_shared_library --define=no_aws_support=true --define=no_gcp_support=true --define=no_hdfs_support=true --define=no_kafka_support=true --define=no_ignite_support=true --define=grpc_no_ares=true --define=tsl_link_protobuf=true -c opt --config=short_logs --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir. --@xla//xla/python:enable_gpu=false
INFO: Reading rc options for 'run' from /home/build/repos/jax/.jax_configure.bazelrc:
  Inherited 'build' options: --strategy=Genrule=standalone --config=cuda --config=cuda_plugin --repo_env HERMETIC_PYTHON_VERSION=3.9
INFO: Found applicable config definition build:short_logs in file /home/build/repos/jax/.bazelrc: --output_filter=DONT_MATCH_ANYTHING
INFO: Found applicable config definition build:cuda in file /home/build/repos/jax/.bazelrc: --repo_env TF_NEED_CUDA=1 --repo_env TF_NCCL_USE_STUB=1 --action_env TF_CUDA_COMPUTE_CAPABILITIES=sm_50,sm_60,sm_70,sm_80,compute_90 --crosstool_top=@local_config_cuda//crosstool:toolchain --@local_config_cuda//:enable_cuda --@xla//xla/python:enable_gpu=true --@xla//xla/python:jax_cuda_pip_rpaths=true --define=xla_python_enable_gpu=true --linkopt=-Wl,--disable-new-dtags
INFO: Found applicable config definition build:cuda_plugin in file /home/build/repos/jax/.bazelrc: --@xla//xla/python:enable_gpu=false --define=xla_python_enable_gpu=false
INFO: Found applicable config definition build:linux in file /home/build/repos/jax/.bazelrc: --config=posix --copt=-Wno-unknown-warning-option --copt=-Wno-stringop-truncation --copt=-Wno-array-parameter
INFO: Found applicable config definition build:posix in file /home/build/repos/jax/.bazelrc: --copt=-fvisibility=hidden --copt=-Wno-sign-compare --cxxopt=-std=c++17 --host_cxxopt=-std=c++17
Computing main repo mapping: 
Loading: 
Loading: 0 packages loaded
Analyzing: target //jaxlib/tools:build_wheel (0 packages loaded, 0 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (0 packages loaded, 0 targets configured)
[0 / 1] [Prepa] BazelWorkspaceStatusAction stable-status.txt
ERROR: /home/build/.cache/bazel/_bazel_build/dcc0c88d87f0eff69c7a10c9d397572b/external/local_config_python/BUILD:16:11: in py_runtime rule @@local_config_python//:py3_runtime: 
Traceback (most recent call last):
    File "/virtual_builtins_bzl/common/python/py_runtime_rule.bzl", line 40, column 17, in _py_runtime_impl
Error in fail: interpreter_path must be an absolute path
ERROR: /home/build/.cache/bazel/_bazel_build/dcc0c88d87f0eff69c7a10c9d397572b/external/local_config_python/BUILD:16:11: Analysis of target '@@local_config_python//:py3_runtime' failed
ERROR: Analysis of target '//jaxlib/tools:build_wheel' failed; build aborted: Analysis failed
INFO: Elapsed time: 0.253s, Critical Path: 0.00s
INFO: 1 process: 1 internal.
ERROR: Build did NOT complete successfully
FAILED: 
ERROR: Build failed. Not running target
Traceback (most recent call last):
  File "/home/build/repos/jax/build/build.py", line 726, in <module>
    main()
  File "/home/build/repos/jax/build/build.py", line 692, in main
    shell(build_cpu_wheel_command)
  File "/home/build/repos/jax/build/build.py", line 45, in shell
    output = subprocess.check_output(cmd)
  File "/usr/lib/python3.9/subprocess.py", line 424, in check_output
    return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
  File "/usr/lib/python3.9/subprocess.py", line 528, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['/usr/local/bin/bazel', 'run', '--verbose_failures=true', '//jaxlib/tools:build_wheel', '--', '--output_path=/home/build/repos/jax/dist', '--jaxlib_git_hash=d3bfd32667d7742d2fc208b75d4085e923d5e03f', '--cpu=ppc64le', '--skip_gpu_kernels']' returned non-zero exit status 1.

Looks like the error is Error in fail: interpreter_path must be an absolute path. Is that a Bazel problem, jax problem, or what? Many thanks!

dsuarez01 commented 3 months ago

@bhuntsman I had a similar issue as you, and downgraded to Bazel 6.5.0. This resolved the issue of not finding the interpreter path, only to result in this error:

(sum24r) python build/build.py --enable_cuda --cuda_path /nobackup/users/dsuarez/miniconda3/envs/sum24r/pkgs/cuda-toolkit --cuda_version=12.2 --cudnn_version=8.9.6 --cudnn_path /nobackup/users/dsuarez/miniconda3/envs/sum24r --noenable_mkl_dnn --cuda_compute_capabilities 7.0 --bazel_path /nobackup/users/dsuarez/miniconda3/envs/sum24r/bin/bazel --target_cpu=ppc --python_version=3.11
     _   _  __  __
    | | / \ \ \/ /
 _  | |/ _ \ \  /
| |_| / ___ \/  \
 \___/_/   \/_/\_\
Bazel binary path: /nobackup/users/dsuarez/miniconda3/envs/sum24r/bin/bazel
Bazel version: 6.5.0
Use clang: no
MKL-DNN enabled: no
Target CPU: ppc64le
Target CPU features: release
CUDA enabled: yes
CUDA toolkit path: /nobackup/users/dsuarez/miniconda3/envs/sum24r/pkgs/cuda-toolkit
CUDNN library path: /nobackup/users/dsuarez/miniconda3/envs/sum24r
CUDA compute capabilities: 7.0
CUDA version: 12.2
CUDNN version: 8.9.6
NCCL enabled: yes
ROCm enabled: no

Building XLA and installing it in the jaxlib source tree...
/nobackup/users/dsuarez/miniconda3/envs/sum24r/bin/bazel run --verbose_failures=true //jaxlib/tools:build_wheel -- --output_path=/nobackup/users/dsuarez/jax/dist --jaxlib_git_hash=24b42eed5ea4163ac5a3d4a8f8648545572933db --cpu=ppc64le
Starting local Bazel server and connecting to it...
INFO: Options provided by the client:
  Inherited 'common' options: --isatty=0 --terminal_columns=80
INFO: Reading rc options for 'run' from /nobackup/users/dsuarez/jax/.bazelrc:
  Inherited 'common' options: --experimental_repo_remote_exec
INFO: Reading rc options for 'run' from /nobackup/users/dsuarez/jax/.bazelrc:
  Inherited 'build' options: --nocheck_visibility --apple_platform_type=macos --macos_minimum_os=10.14 --announce_rc --define open_source_build=true --spawn_strategy=standalone --enable_platform_specific_config --experimental_cc_shared_library --define=no_aws_support=true --define=no_gcp_support=true --define=no_hdfs_support=true --define=no_kafka_support=true --define=no_ignite_support=true --define=grpc_no_ares=true --define=tsl_link_protobuf=true -c opt --config=short_logs --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir. --@xla//xla/python:enable_gpu=false
INFO: Reading rc options for 'run' from /nobackup/users/dsuarez/jax/.jax_configure.bazelrc:
  Inherited 'build' options: --strategy=Genrule=standalone --action_env CUDA_TOOLKIT_PATH=/nobackup/users/dsuarez/miniconda3/envs/sum24r/pkgs/cuda-toolkit --action_env CUDNN_INSTALL_PATH=/nobackup/users/dsuarez/miniconda3/envs/sum24r --action_env TF_CUDA_PATHS=/nobackup/users/dsuarez/miniconda3/envs/sum24r/pkgs/cuda-toolkit,/nobackup/users/dsuarez/miniconda3/envs/sum24r --action_env TF_CUDA_VERSION=12.2 --action_env TF_CUDNN_VERSION=8.9.6 --cpu=ppc --config=cuda --repo_env HERMETIC_PYTHON_VERSION=3.11
INFO: Found applicable config definition build:short_logs in file /nobackup/users/dsuarez/jax/.bazelrc: --output_filter=DONT_MATCH_ANYTHING
INFO: Found applicable config definition build:cuda in file /nobackup/users/dsuarez/jax/.bazelrc: --repo_env TF_NEED_CUDA=1 --repo_env TF_NCCL_USE_STUB=1 --action_env TF_CUDA_COMPUTE_CAPABILITIES=sm_50,sm_60,sm_70,sm_80,compute_90 --crosstool_top=@local_config_cuda//crosstool:toolchain --@local_config_cuda//:enable_cuda --@xla//xla/python:enable_gpu=true --@xla//xla/python:jax_cuda_pip_rpaths=true --define=xla_python_enable_gpu=true --linkopt=-Wl,--disable-new-dtags
INFO: Found applicable config definition build:cuda in file /nobackup/users/dsuarez/jax/.jax_configure.bazelrc: --action_env TF_CUDA_COMPUTE_CAPABILITIES=7.0
INFO: Found applicable config definition build:linux in file /nobackup/users/dsuarez/jax/.bazelrc: --config=posix --copt=-Wno-unknown-warning-option --copt=-Wno-stringop-truncation --copt=-Wno-array-parameter
INFO: Found applicable config definition build:posix in file /nobackup/users/dsuarez/jax/.bazelrc: --copt=-fvisibility=hidden --copt=-Wno-sign-compare --cxxopt=-std=c++17 --host_cxxopt=-std=c++17
Loading: 
Loading: 
Loading: 
DEBUG: /home/dsuarez/.cache/bazel/_bazel_dsuarez/eb751057d1cbe464324d72cbeb188b2a/external/xla/third_party/py/python_repo.bzl:110:10: Using hermetic Python 3.11
Loading: 
Loading: 
DEBUG: /home/dsuarez/.cache/bazel/_bazel_dsuarez/eb751057d1cbe464324d72cbeb188b2a/external/xla/third_party/repo.bzl:132:14: 
Warning: skipping import of repository 'llvm-raw' because it already exists.
Loading: 
Loading: 
DEBUG: /home/dsuarez/.cache/bazel/_bazel_dsuarez/eb751057d1cbe464324d72cbeb188b2a/external/tsl/third_party/repo.bzl:132:14: 
Warning: skipping import of repository 'nvtx_archive' because it already exists.
DEBUG: /home/dsuarez/.cache/bazel/_bazel_dsuarez/eb751057d1cbe464324d72cbeb188b2a/external/xla/third_party/repo.bzl:132:14: 
Warning: skipping import of repository 'jsoncpp_git' because it already exists.
Loading: 
Loading: 
Loading: 
Loading: 
Loading: 
Loading: 
Loading: 2 packages loaded
Analyzing: target //jaxlib/tools:build_wheel (3 packages loaded, 0 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (36 packages loaded, 10 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (36 packages loaded, 10 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (36 packages loaded, 10 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (40 packages loaded, 211 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (52 packages loaded, 216 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (53 packages loaded, 216 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (53 packages loaded, 216 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (55 packages loaded, 216 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (84 packages loaded, 3140 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (98 packages loaded, 3333 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (113 packages loaded, 3375 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (134 packages loaded, 4022 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (137 packages loaded, 4467 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (144 packages loaded, 4862 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (151 packages loaded, 4948 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (154 packages loaded, 5182 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (182 packages loaded, 7473 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (196 packages loaded, 8839 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (201 packages loaded, 12364 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (221 packages loaded, 13560 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (225 packages loaded, 14932 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (225 packages loaded, 16505 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (226 packages loaded, 16827 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (227 packages loaded, 17184 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (228 packages loaded, 17496 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (230 packages loaded, 17592 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (231 packages loaded, 19698 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (231 packages loaded, 19698 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (234 packages loaded, 20571 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (234 packages loaded, 20571 targets configured)
Analyzing: target //jaxlib/tools:build_wheel (236 packages loaded, 22264 targets configured)
INFO: Repository pypi_numpy instantiated at:
  /nobackup/users/dsuarez/jax/WORKSPACE:29:13: in <toplevel>
  /home/dsuarez/.cache/bazel/_bazel_dsuarez/eb751057d1cbe464324d72cbeb188b2a/external/pypi/requirements.bzl:49:20: in install_deps
Repository rule whl_library defined at:
  /home/dsuarez/.cache/bazel/_bazel_dsuarez/eb751057d1cbe464324d72cbeb188b2a/external/rules_python/python/pip_install/pip_repository.bzl:697:30: in <toplevel>
ERROR: An error occurred during the fetch of repository 'pypi_numpy':
   Traceback (most recent call last):
        File "/home/dsuarez/.cache/bazel/_bazel_dsuarez/eb751057d1cbe464324d72cbeb188b2a/external/rules_python/python/pip_install/pip_repository.bzl", line 596, column 13, in _whl_library_impl
                fail("whl_library %s failed: %s (%s) error code: '%s'" % (rctx.attr.name, result.stdout, result.stderr, result.return_code))
Error in fail: whl_library pypi_numpy failed: Collecting numpy==2.0.0 (from -r /tmp/tmp0sft14zz (line 1))
  Downloading numpy-2.0.0.tar.gz (18.3 MB)
ssError: Command '['/home/dsuarez/.cache/bazel/_bazel_dsuarez/eb751057d1cbe464324d72cbeb188b2a/external/python_ppc64le-unknown-linux-gnu/bin/python3', '-m', 'pip', '--isolated', 'wheel', '--no-deps', '-r', '/tmp/tmp0sft14zz']' returned non-zero exit status 1.
) error code: '1'
ERROR: /nobackup/users/dsuarez/jax/WORKSPACE:29:13: fetching whl_library rule //external:pypi_numpy: Traceback (most recent call last):
        File "/home/dsuarez/.cache/bazel/_bazel_dsuarez/eb751057d1cbe464324d72cbeb188b2a/external/rules_python/python/pip_install/pip_repository.bzl", line 596, column 13, in _whl_library_impl
                fail("whl_library %s failed: %s (%s) error code: '%s'" % (rctx.attr.name, result.stdout, result.stderr, result.return_code))
Error in fail: whl_library pypi_numpy failed: Collecting numpy==2.0.0 (from -r /tmp/tmp0sft14zz (line 1))
  Downloading numpy-2.0.0.tar.gz (18.3 MB)
gs,
subprocess.CalledProcessError: Command '['/home/dsuarez/.cache/bazel/_bazel_dsuarez/eb751057d1cbe464324d72cbeb188b2a/external/python_ppc64le-unknown-linux-gnu/bin/python3', '-m', 'pip', '--isolated', 'wheel', '--no-deps', '-r', '/tmp/tmp0sft14zz']' returned non-zero exit status 1.
) error code: '1'
Analyzing: target //jaxlib/tools:build_wheel (236 packages loaded, 22264 targets configured)
ERROR: /home/dsuarez/.cache/bazel/_bazel_dsuarez/eb751057d1cbe464324d72cbeb188b2a/external/xla/third_party/py/numpy/BUILD:5:6: @xla//third_party/py/numpy:numpy depends on @pypi_numpy//:pkg in repository @pypi_numpy which failed to fetch. no such package '@pypi_numpy//': whl_library pypi_numpy failed: Collecting numpy==2.0.0 (from -r /tmp/tmp0sft14zz (line 1))
  Downloading numpy-2.0.0.tar.gz (18.3 MB)
olated', 'wheel', '--no-deps', '-r', '/tmp/tmp0sft14zz']' returned non-zero exit status 1.
) error code: '1'
ERROR: Analysis of target '//jaxlib/tools:build_wheel' failed; build aborted: 
INFO: Elapsed time: 147.692s
INFO: 0 processes.
FAILED: Build did NOT complete successfully (236 packages loaded, 22264 targets configured)
ERROR: Build failed. Not running target
Traceback (most recent call last):
  File "/nobackup/users/dsuarez/jax/build/build.py", line 749, in <module>
    main()
  File "/nobackup/users/dsuarez/jax/build/build.py", line 700, in main
    shell(build_cpu_wheel_command)
  File "/nobackup/users/dsuarez/jax/build/build.py", line 45, in shell
    output = subprocess.check_output(cmd)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nobackup/users/dsuarez/miniconda3/envs/sum24r/lib/python3.11/subprocess.py", line 466, in check_output
    return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nobackup/users/dsuarez/miniconda3/envs/sum24r/lib/python3.11/subprocess.py", line 571, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['/nobackup/users/dsuarez/miniconda3/envs/sum24r/bin/bazel', 'run', '--verbose_failures=true', '//jaxlib/tools:build_wheel', '--', '--output_path=/nobackup/users/dsuarez/jax/dist', '--jaxlib_git_hash=24b42eed5ea4163ac5a3d4a8f8648545572933db', '--cpu=ppc64le']' returned non-zero exit status 1.

Anyone know what is happening here?

EDIT: for anyone seeing this! Jaxlib 0.4.7 with CUDA 11.8/12.2 support should be available at https://anaconda.org/rocketce/jaxlib

bhuntsman commented 3 months ago

We are getting somewhere. Using Bazel 6.5.0 seems to be the key, and that's also what's indicated in the .bazelversion. The next problem I'm running into is that xla has a dependency on boringssl, which explicitly doesn't support ppc64le. It did in the past but a review of the commit history shows that the developers stripped out all the ppc64le code prior to the bazel integrations, which makes putting it back in rather challenging.
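Since the Bazel version turned out to matter, a quick pre-flight check against the repository's .bazelversion file might save a failed build. This is only a sketch with made-up helper names; it assumes `bazel --version` prints a line of the form `bazel 6.5.0`.

```python
import re
from pathlib import Path


def pinned_bazel_version(repo_root):
    """Read the version pinned in <repo_root>/.bazelversion."""
    return Path(repo_root, ".bazelversion").read_text().strip()


def installed_bazel_version(version_output):
    """Extract '6.5.0' from output such as 'bazel 6.5.0'; None if not found."""
    match = re.search(r"(\d+\.\d+\.\d+)", version_output)
    return match.group(1) if match else None


def versions_match(repo_root, version_output):
    """True if the installed Bazel is exactly the version the repo pins."""
    return pinned_bazel_version(repo_root) == installed_bazel_version(version_output)
```

The `version_output` argument would typically come from something like `subprocess.check_output([bazel_path, "--version"], text=True)`, using the same binary that is later passed to build.py.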

bhuntsman commented 3 months ago

If anyone wants it, I forked boringssl from the commit referenced in the XLA repo and added ppc64le support back to it. You can grab it from bhuntsman/boringssl:xla-ppc64le. I was able to get further using the following:

TMP=/tmp python build/build.py --noenable_mkl_dnn --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=11 --bazel_options=--override_repository=boringssl=/home/build/repos/boringssl

At least for me the next challenge is that the current jaxlib requires CUDA 11.8, and the "latest" for Ubuntu 20.04 on ppc64le is 11.1. jaxlib 0.4.25 still looks like it might support 11.1, and 0.4.25 is good enough for me, but other parts don't support ppc64le yet at that version.

bhuntsman commented 3 months ago

HEY GUYS! I got jax 0.4.25 to compile on ppc64le!

Here's what I did:

Now we have a few challenges to get this integrated, chief among them convincing the boringssl project to bring back ppc64le support. Originally I was looking at jax 0.4.25 because I thought it might support CUDA 11.1, but found that it requires 11.8 as a minimum. Updating this to the latest jax, newer xla, and newer CUDA should be possible, and I'll work on it in the coming days.

hawkinsp commented 3 months ago

Well, that's great! Certainly please upstream things.

I will note that NVIDIA has dropped ppc64 support as of CUDA 12.4, so at some point we're going to have to drop CUDA support on this architecture. We can't support it long term if NVIDIA does not.

For our own builds, we've stopped building or testing anything older than CUDA 12.1. So I'd encourage you to use CUDA 12.4.1, which seems to be the last release NVIDIA will make on that architecture, and that will keep you supported as long as possible by us.

bhuntsman commented 3 months ago

Related question here for anyone who has access to a ppc64le system. It looks like nvidia stopped creating installers for Ubuntu on ppc64le after CUDA 11.1, but continued to support RHEL 8 through CUDA 12.4. However, the runfile installer doesn't specifically indicate that it's only for RHEL, and in fact installs on Ubuntu just fine. Once installed, the kernel modules load fine and nvidia-smi indicates that it can see the cards. nvcc works as well. Unfortunately the resulting binaries do nothing. I had to symlink /usr/lib/powerpc64le-linux-gnu/libmpfr.so.4 -> libmpfr.so.6.0.2 to get cuda-gdb to run, but after loading it up and running a program, I get the following as soon as it tries to hit the GPU:

fatal: CUDBG_ERROR_INITIALIZATION_FAILURE (20): The CUDA driver initialization failed.
(cuda-gdb) info cuda devices
No CUDA devices.

I'd greatly prefer to stay on Ubuntu. Has anyone had any luck at anything like that, or is the only recourse here to be forced onto RHEL in order to have a more up to date CUDA?

Thanks!