jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Can't build jaxlib in GH200 #21299

Open giladqm opened 4 months ago

giladqm commented 4 months ago

Description

I'm trying to run some code on my GH200 without success: I'm unable to build jaxlib for this GPU.

System info (python version, jaxlib version, accelerator, etc.)

root@470c73980644:~/jax# nvidia-smi
Sun May 19 12:13:00 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GH200 480GB             On  |   00000009:01:00.0 Off |                   On |
| N/A   23C    P0             62W /  900W |       5MiB /  97871MiB |     N/A      Default |
|                                         |                        |              Enabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| MIG devices:                                                                            |
+------------------+----------------------------------+-----------+-----------------------+
| GPU  GI  CI  MIG |                     Memory-Usage |        Vol|      Shared           |
|      ID  ID  Dev |                       BAR1-Usage | SM     Unc| CE ENC DEC OFA JPG    |
|                  |                                  |        ECC|                       |
|==================+==================================+===========+=======================|
|  No MIG devices found                                                                   |
+-----------------------------------------------------------------------------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
root@470c73980644:~/jax# nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Mar_28_02:24:28_PDT_2024
Cuda compilation tools, release 12.4, V12.4.131
Build cuda_12.4.r12.4/compiler.34097967_0

The error I get:

Error limit reached.
100 errors detected in the compilation of "external/xla/xla/stream_executor/gpu/redzone_allocator_kernel_cuda.cc".
Compilation terminated.
Target //jaxlib/tools:build_gpu_plugin_wheel failed to build
INFO: Elapsed time: 7.262s, Critical Path: 4.88s
INFO: 73 processes: 73 internal.
FAILED: Build did NOT complete successfully
ERROR: Build failed. Not running target
Traceback (most recent call last):
  File "/root/jax/build/build.py", line 733, in <module>
    main()
  File "/root/jax/build/build.py", line 727, in main
    shell(build_pjrt_plugin_command)
  File "/root/jax/build/build.py", line 45, in shell
    output = subprocess.check_output(cmd)
  File "/usr/lib/python3.10/subprocess.py", line 421, in check_output
    return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
  File "/usr/lib/python3.10/subprocess.py", line 526, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['/usr/local/bin/bazel', 'run', '--verbose_failures=true', '//jaxlib/tools:build_gpu_plugin_wheel', '--', '--output_path=/root/jax/dist', '--jaxlib_git_hash=45a7c22e932fee257016bf0da1022be146ed6095', '--cpu=aarch64', '--cuda_version=12']' returned non-zero exit status 1.

superbobry commented 4 months ago

Can you share the compilation errors you're getting in external/xla/xla/stream_executor/gpu/redzone_allocator_kernel_cuda.cc?

giladqm commented 4 months ago

Thanks for the quick reply...

#include "absl/base/call_once.h"
#include "absl/base/const_init.h"
#include "absl/base/thread_annotations.h"
#include "absl/container/node_hash_map.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "xla/stream_executor/cuda/cuda_asm_compiler.h"
#include "xla/stream_executor/cuda/cuda_driver.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/gpu/redzone_allocator_kernel.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/stream_executor_pimpl.h"
#include "xla/stream_executor/typed_kernel_factory.h"
#include "tsl/platform/statusor.h"

namespace stream_executor {
// Maintains a cache of pointers to loaded kernels
template static absl::StatusOr<TypedKernel> LoadKernelOrGetPtr(
    StreamExecutor executor, absl::string_view kernel_name,
    absl::string_view ptx, absl::Span cubin_data) {
  using KernelPtrCacheKey =
      std::tuple<CUcontext, absl::string_view, absl::string_view>;

  static absl::Mutex kernel_ptr_cache_mutex(absl::kConstInit);
  static auto& kernel_ptr_cache ABSL_GUARDED_BY(kernel_ptr_cache_mutex) =
      *new absl::node_hash_map<KernelPtrCacheKey, TypedKernel>();
  CUcontext current_context = cuda::CurrentContextOrDie();
  KernelPtrCacheKey kernel_ptr_cache_key{current_context, kernel_name, ptx};
  absl::MutexLock lock(&kernel_ptr_cache_mutex);
  auto it = kernel_ptr_cache.find(kernel_ptr_cache_key);
  if (it == kernel_ptr_cache.end()) {
    TF_ASSIGN_OR_RETURN(TypedKernel loaded,
                        (TypedKernelFactory::Create(executor, kernel_name, ptx,
                                                    cubin_data)));
    it = kernel_ptr_cache.emplace(kernel_ptr_cache_key, std::move(loaded)).first;
  }

  CHECK(it != kernel_ptr_cache.end());
  return &it->second;
}

// PTX blob for the function which checks that every byte in
// input_buffer (length is buffer_length) is equal to redzone_pattern.
//
// On mismatch, increment the counter pointed to by out_mismatch_cnt_ptr.
//
// Generated from:
// __global__ void redzone_checker(unsigned char* input_buffer,
//                                 unsigned char redzone_pattern,
//                                 unsigned long long buffer_length,
//                                 int* out_mismatched_ptr) {
//   unsigned long long idx = threadIdx.x + blockIdx.x * blockDim.x;
//   if (idx >= buffer_length) return;
//   if (input_buffer[idx] != redzone_pattern) atomicAdd(out_mismatched_ptr, 1);
// }
//
// Code must compile for the oldest GPU XLA may be compiled for.
static const char redzone_checker_ptx = R"(
.version 4.2
.target sm_30
.address_size 64

.visible .entry redzone_checker(
  .param .u64 input_buffer,
  .param .u8 redzone_pattern,
  .param .u64 buffer_length,
  .param .u64 out_mismatch_cnt_ptr
)
{
  .reg .pred %p<3>;
  .reg .b16 %rs<3>;
  .reg .b32 %r<6>;
  .reg .b64 %rd<8>;

  ld.param.u64 %rd6, [buffer_length];
  ld.param.u64 %rd4, [input_buffer];
  cvta.to.global.u64 %rd2, %rd4;
  add.s64 %rd7, %rd2, %rd3;
  ld.global.u8 %rs2, [%rd7];
  setp.eq.s16 %p2, %rs2, %rs1;
  @%p2 bra LBB6_3;
  ld.param.u64 %rd5, [out_mismatch_cnt_ptr];
  ld.param.u8 %rs1, [redzone_pattern];
  ld.param.u64 %rd4, [input_buffer];
  cvta.to.global.u64 %rd2, %rd4;
  add.s64 %rd7, %rd2, %rd3;
  ld.global.u8 %rs2, [%rd7];
  setp.eq.s16 %p2, %rs2, %rs1;
  @%p2 bra LBB6_3;
  ld.param.u64 %rd5, [out_mismatch_cnt_ptr];
  cvta.to.global.u64 %rd1, %rd5;
  atom.global.add.u32 %r5, [%rd1], 1;
LBB6_3:
  ret;
}
)";

absl::StatusOr<const ComparisonKernel> GetComparisonKernel(
    StreamExecutor executor, GpuAsmOpts gpu_asm_opts) {
  absl::Span compiled_ptx = {};
  absl::StatusOr<absl::Span> compiled_ptx_or = CompileGpuAsmOrGetCached(
      executor->device_ordinal(), redzone_checker_ptx, gpu_asm_opts);
  if (compiled_ptx_or.ok()) {
    compiled_ptx = compiled_ptx_or.value();
  } else {
    static absl::once_flag ptxas_not_found_logged;
    absl::call_once(ptxas_not_found_logged, [&]() {
      LOG(WARNING) << compiled_ptx_or.status()
                   << "\nRelying on driver to perform ptx compilation. "
                   << "\nModify $PATH to customize ptxas location."
                   << "\nThis message will be only logged once.";
    });
  }

  return LoadKernelOrGetPtr<DeviceMemory, uint8_t, uint64_t, DeviceMemory>(
      executor, "redzone_checker", redzone_checker_ptx, compiled_ptx);
}
}  // namespace stream_executor
.reg .b16 %rs<3>;

superbobry commented 4 months ago

This snippet doesn't contain any compilation errors AFAICT. Can you upload the output of the compiler to a gist?

giladqm commented 4 months ago

I'm sorry but I don't understand the request. Can you be more specific and include the linux terminal commands you want me to run?

superbobry commented 4 months ago

The message you posted initially

Error limit reached.
100 errors detected in the compilation of "external/xla/xla/stream_executor/gpu/redzone_allocator_kernel_cuda.cc".

is usually preceded by compilation error messages, describing what went wrong while compiling jaxlib. If you upload the full output of build.py, that would include the error messages as well.
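
One way to capture the full build output, including the compiler errors that precede the "Error limit reached" line, is to redirect both stdout and stderr to a file. The following is a minimal sketch, not part of JAX's build tooling: `run_and_log` is a hypothetical helper name, and the `build.py` invocation in the trailing comment is only an assumed example.

```python
import subprocess


def run_and_log(cmd, log_path):
    """Run cmd, write combined stdout/stderr to log_path, return the exit code."""
    with open(log_path, "w", encoding="utf-8") as log:
        proc = subprocess.run(
            cmd,
            stdout=log,
            stderr=subprocess.STDOUT,  # compiler errors usually arrive on stderr
            check=False,               # keep going so the log is always written
        )
    return proc.returncode


# Assumed usage; adjust the command to the build you are actually running:
# run_and_log(["python3", "build/build.py", "--enable_cuda"], "build.log")
```

The resulting log file can then be uploaded as a gist in one piece.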

mjsML commented 4 months ago

cc : @nouiz

giladqm commented 4 months ago

This is what I get now (I'm in a different Docker image now):

(base) root@8c1c1dd5a763:~/jax# python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12

     _   _  __  __
    | | / \ \ \/ /
 _  | |/ _ \ \  /
| |_| / ___ \/  \
 \___/_/   \/_/\_\

Bazel binary path: ./bazel-6.5.0-linux-arm64
Bazel version: 6.5.0
Python binary path: /root/miniconda3/bin/python3
Python version: 3.12
Use clang: no
MKL-DNN enabled: yes
Target CPU: aarch64
Target CPU features: release
CUDA enabled: yes
NCCL enabled: yes
ROCm enabled: no

Building XLA and installing it in the jaxlib source tree...
./bazel-6.5.0-linux-arm64 run --verbose_failures=true //jaxlib/tools:build_wheel -- --output_path=/root/jax/dist --jaxlib_git_hash=ffdb9bb0b0755e66f55995cafa2cf0946ed66598 --cpu=aarch64 --skip_gpu_kernels
INFO: Options provided by the client:
  Inherited 'common' options: --isatty=0 --terminal_columns=80
INFO: Reading rc options for 'run' from /root/jax/.bazelrc:
  Inherited 'common' options: --experimental_repo_remote_exec
INFO: Reading rc options for 'run' from /root/jax/.bazelrc:
  Inherited 'build' options: --nocheck_visibility --apple_platform_type=macos --macos_minimum_os=10.14 --announce_rc --define open_source_build=true --spawn_strategy=standalone --enable_platform_specific_config --experimental_cc_shared_library --define=no_aws_support=true --define=no_gcp_support=true --define=no_hdfs_support=true --define=no_kafka_support=true --define=no_ignite_support=true --define=grpc_no_ares=true --define=tsl_link_protobuf=true -c opt --config=short_logs --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir. --@xla//xla/python:enable_gpu=false
INFO: Reading rc options for 'run' from /root/jax/.jax_configure.bazelrc:
  Inherited 'build' options: --strategy=Genrule=standalone --config=mkl_open_source_only --config=cuda --config=cuda_plugin --repo_env HERMETIC_PYTHON_VERSION=3.12
INFO: Found applicable config definition build:short_logs in file /root/jax/.bazelrc: --output_filter=DONT_MATCH_ANYTHING
INFO: Found applicable config definition build:mkl_open_source_only in file /root/jax/.bazelrc: --define=tensorflow_mkldnn_contraction_kernel=1
INFO: Found applicable config definition build:cuda in file /root/jax/.bazelrc: --repo_env TF_NEED_CUDA=1 --repo_env TF_NCCL_USE_STUB=1 --action_env TF_CUDA_COMPUTE_CAPABILITIES=sm_50,sm_60,sm_70,sm_80,compute_90 --crosstool_top=@local_config_cuda//crosstool:toolchain --@local_config_cuda//:enable_cuda --@xla//xla/python:enable_gpu=true --@xla//xla/python:jax_cuda_pip_rpaths=true --define=xla_python_enable_gpu=true --linkopt=-Wl,--disable-new-dtags
INFO: Found applicable config definition build:cuda_plugin in file /root/jax/.bazelrc: --@xla//xla/python:enable_gpu=false --define=xla_python_enable_gpu=false
INFO: Found applicable config definition build:linux in file /root/jax/.bazelrc: --config=posix --copt=-Wno-unknown-warning-option --copt=-Wno-stringop-truncation --copt=-Wno-array-parameter
INFO: Found applicable config definition build:posix in file /root/jax/.bazelrc: --copt=-fvisibility=hidden --copt=-Wno-sign-compare --cxxopt=-std=c++17 --host_cxxopt=-std=c++17
Loading:
INFO: Repository local_config_cuda instantiated at:
  /root/jax/WORKSPACE:45:15: in <toplevel>
  /root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/xla/workspace2.bzl:121:19: in workspace
  /root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/workspace2.bzl:601:19: in workspace
  /root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/workspace2.bzl:72:19: in _tf_toolchains
Repository rule cuda_configure defined at:
  /root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/third_party/gpus/cuda_configure.bzl:1542:33: in <toplevel>
ERROR: An error occurred during the fetch of repository 'local_config_cuda':
   Traceback (most recent call last):
    File "/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/third_party/gpus/cuda_configure.bzl", line 1491, column 38, in _cuda_autoconf_impl
      _create_local_cuda_repository(repository_ctx)
    File "/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/third_party/gpus/cuda_configure.bzl", line 1040, column 35, in _create_local_cuda_repository
      cuda_config = _get_cuda_config(repository_ctx)
    File "/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/third_party/gpus/cuda_configure.bzl", line 716, column 30, in _get_cuda_config
      config = find_cuda_config(repository_ctx, ["cuda", "cudnn"])
    File "/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/third_party/gpus/cuda_configure.bzl", line 693, column 26, in find_cuda_config
      exec_result = execute(repository_ctx, [python_bin, repository_ctx.attr._find_cuda_config] + cuda_libraries)
    File "/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/third_party/remote_config/common.bzl", line 230, column 13, in execute
      fail(
Error in fail: Repository command failed
Could not find any cuda.h matching version '' in any subdirectory:
  '' 'include' 'include/cuda' 'include/-linux-gnu' 'extras/CUPTI/include' 'include/cuda/CUPTI' 'local/cuda/extras/CUPTI/include' 'targets/x86_64-linux/include'
of:
  '/lib' '/usr' '/usr/lib/aarch64-linux-gnu' '/usr/lib/aarch64-linux-gnu/libfakeroot'
ERROR: /root/jax/WORKSPACE:45:15: fetching cuda_configure rule //external:local_config_cuda: Traceback (most recent call last):
    File "/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/third_party/gpus/cuda_configure.bzl", line 1491, column 38, in _cuda_autoconf_impl
      _create_local_cuda_repository(repository_ctx)
    File "/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/third_party/gpus/cuda_configure.bzl", line 1040, column 35, in _create_local_cuda_repository
      cuda_config = _get_cuda_config(repository_ctx)
    File "/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/third_party/gpus/cuda_configure.bzl", line 716, column 30, in _get_cuda_config
      config = find_cuda_config(repository_ctx, ["cuda", "cudnn"])
    File "/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/third_party/gpus/cuda_configure.bzl", line 693, column 26, in find_cuda_config
      exec_result = execute(repository_ctx, [python_bin, repository_ctx.attr._find_cuda_config] + cuda_libraries)
    File "/root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/third_party/remote_config/common.bzl", line 230, column 13, in execute
      fail(
Error in fail: Repository command failed
Could not find any cuda.h matching version '' in any subdirectory:
  '' 'include' 'include/cuda' 'include/-linux-gnu' 'extras/CUPTI/include' 'include/cuda/CUPTI' 'local/cuda/extras/CUPTI/include' 'targets/x86_64-linux/include'
of:
  '/lib' '/usr' '/usr/lib/aarch64-linux-gnu' '/usr/lib/aarch64-linux-gnu/libfakeroot'
INFO: Repository rules_cc instantiated at:
  /root/jax/WORKSPACE:48:15: in <toplevel>
  /root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/xla/workspace1.bzl:12:19: in workspace
  /root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/tsl/workspace1.bzl:30:14: in workspace
  /root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/com_github_grpc_grpc/bazel/grpc_deps.bzl:158:21: in grpc_deps
Repository rule http_archive defined at:
  /root/.cache/bazel/_bazel_root/deb80d6610824a92deeac7b7fd0f3e3c/external/bazel_tools/tools/build_defs/repo/http.bzl:372:31: in <toplevel>
ERROR: Skipping '@xla//xla/python:enable_gpu': no such package '@local_config_cuda//cuda': Repository command failed
Could not find any cuda.h matching version '' in any subdirectory:
  '' 'include' 'include/cuda' 'include/-linux-gnu' 'extras/CUPTI/include' 'include/cuda/CUPTI' 'local/cuda/extras/CUPTI/include' 'targets/x86_64-linux/include'
of:
  '/lib' '/usr' '/usr/lib/aarch64-linux-gnu' '/usr/lib/aarch64-linux-gnu/libfakeroot'
WARNING: Target pattern parsing failed.
ERROR: @xla//xla/python:enable_gpu :: Error loading option @xla//xla/python:enable_gpu: no such package '@local_config_cuda//cuda': Repository command failed
Could not find any cuda.h matching version '' in any subdirectory:
  '' 'include' 'include/cuda' 'include/-linux-gnu' 'extras/CUPTI/include' 'include/cuda/CUPTI' 'local/cuda/extras/CUPTI/include' 'targets/x86_64-linux/include'
of:
  '/lib' '/usr' '/usr/lib/aarch64-linux-gnu' '/usr/lib/aarch64-linux-gnu/libfakeroot'
Traceback (most recent call last):
  File "/root/jax/build/build.py", line 733, in <module>
    main()
  File "/root/jax/build/build.py", line 699, in main
    shell(build_cpu_wheel_command)
  File "/root/jax/build/build.py", line 45, in shell
    output = subprocess.check_output(cmd)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/subprocess.py", line 466, in check_output
    return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.12/subprocess.py", line 571, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['./bazel-6.5.0-linux-arm64', 'run', '--verbose_failures=true', '//jaxlib/tools:build_wheel', '--', '--output_path=/root/jax/dist', '--jaxlib_git_hash=ffdb9bb0b0755e66f55995cafa2cf0946ed66598', '--cpu=aarch64', '--skip_gpu_kernels']' returned non-zero exit status 2.

superbobry commented 4 months ago

Okay, from this it looks like your CUDA installation is missing development headers:

Could not find any cuda.h matching version '' in any subdirectory:
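
A quick way to confirm whether the development headers are present is to look for cuda.h directly. The sketch below is an illustration only: `find_cuda_header` is a hypothetical helper, and the root and sub-directory lists are assumptions (for example `/usr/local/cuda` and the `targets/sbsa-linux/include` layout often used on ARM servers), not the exact search list used by the build.

```python
import os

# Assumed candidate locations; the build's own search list may differ.
CANDIDATE_ROOTS = ["/usr/local/cuda", "/usr", "/opt/cuda"]
CANDIDATE_SUBDIRS = [
    "include",
    "targets/sbsa-linux/include",    # assumed layout for ARM server toolkits
    "targets/x86_64-linux/include",
]


def find_cuda_header(roots=CANDIDATE_ROOTS, subdirs=CANDIDATE_SUBDIRS):
    """Return every existing <root>/<subdir>/cuda.h path, in search order."""
    found = []
    for root in roots:
        for sub in subdirs:
            path = os.path.join(root, sub, "cuda.h")
            if os.path.isfile(path):
                found.append(path)
    return found


# Example: print(find_cuda_header()) prints [] when no header was found
# under the assumed locations.
```

An empty result from this check would be consistent with a runtime-only CUDA image that ships without the toolkit headers.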

nouiz commented 4 months ago

JAX-Toolbox has nightly JAX containers for ARM: https://github.com/NVIDIA/JAX-Toolbox For example, ghcr.io/nvidia/jax:jax is the latest nightly.

If you want to build JAX yourself, this container already contains CUDA:

docker pull nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04

I almost always use those two containers for JAX development.

giladqm commented 4 months ago

@nouiz thanks, I tried those two options without success. I'm using a GH200 and I'm trying to use jax with the GPU, but it always fails.

hawkinsp commented 4 months ago

I'll also note that we (JAX) release CUDA arm wheels on pypi which should just work on GH200. Try:

pip install jax jaxlib jax-cuda12-plugin jax-cuda12-pjrt

(The more usual pip install jax[cuda12] won't work because NVIDIA doesn't release ARM wheels of CUDA, last I checked.)
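
After running the pip command above, it can help to confirm which of the four distributions actually landed in the active environment. This is a minimal sketch using only the standard library; `installed_versions` is a hypothetical helper name, and the package list simply mirrors the command above.

```python
from importlib import metadata

# Distribution names taken from the pip command above.
PACKAGES = ["jax", "jaxlib", "jax-cuda12-plugin", "jax-cuda12-pjrt"]


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Example: print each package with its version (None means not installed).
# for name, version in installed_versions().items():
#     print(name, version)
```

A `None` entry means the package is missing from the environment the interpreter is running in, which is worth ruling out before debugging the GPU itself.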

nouiz commented 4 months ago

We released ARM wheels last week, but they aren't tested yet. So let's try jax[cuda12]:

docker run -it --gpus all ubuntu
apt-get update; apt-get install -y python3-pip python3.12-venv
python3 -m venv path/to/venv
source path/to/venv/bin/activate
pip install jax[cuda12] # works
python3 -c "import jax; jax.numpy.zeros(3)" # fail with cudnn init error.

This installed cuDNN 9.1.1. To my knowledge, cuDNN 8 isn't supported on Grace Hopper. @hawkinsp Are the JAX wheels for ARM also built with cuDNN 8? Any idea when a cuDNN 9 version can be created?

giladqm commented 4 months ago

Thanks for the update, I'll check it right away

giladqm commented 4 months ago

@hawkinsp After executing

pip install jax jaxlib jax-cuda12-plugin jax-cuda12-pjrt

I'm trying to run the following code:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Specify the index of the GPU you want to use

import jax
import jax.numpy as jnp

def main():
    size = 1000  # Matrix dimension used for both operands
    # Explicitly place arrays on GPU using jax.device_put
    gpu_device = jax.devices("gpu")[0]  # Use the first GPU
    a = jax.random.normal(jax.random.PRNGKey(0), (size, size))
    b = jax.random.normal(jax.random.PRNGKey(1), (size, size))
    a_gpu = jax.device_put(a, device=gpu_device)
    b_gpu = jax.device_put(b, device=gpu_device)

    # Run matrix multiplication on GPU
    result = jnp.dot(a_gpu, b_gpu)

    # Print the result
    print("Result of matrix multiplication:")
    print(result)

if __name__ == "__main__":
    main()

but I get:

RuntimeError: Unable to initialize backend 'cuda': FAILED_PRECONDITION: No visible GPU devices. (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
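
The error message itself suggests setting JAX_PLATFORMS=cpu to skip the failing backend. That does not fix the GPU, but it can confirm that the rest of the installation works. A minimal sketch, assuming the script is launched in a child process; `cpu_only_env` is a hypothetical helper name and `jax_program.py` is only an example script name.

```python
import os


def cpu_only_env(base_env=None):
    """Return a copy of the environment with JAX_PLATFORMS=cpu set."""
    env = dict(os.environ if base_env is None else base_env)
    env["JAX_PLATFORMS"] = "cpu"  # tells JAX to skip the CUDA plugin
    return env


# Assumed usage: run the script with the CPU backend only.
# import subprocess, sys
# subprocess.run([sys.executable, "jax_program.py"], env=cpu_only_env(), check=False)
```

The variable has to be set before JAX initializes its backends, which is why the sketch passes it to a fresh process instead of changing it after `import jax`.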

giladqm commented 4 months ago

@nouiz does this mean I can't yet run jax with GPU acceleration on GH200?

nouiz commented 4 months ago

It is possible. Can you give the exact command line you use to start the docker container? What is the output of nvidia-smi in it? Can you try the jax container we provide? It should work and won't ask you to compile JAX.

giladqm commented 4 months ago

Last week I used the jax:jax Docker image from https://github.com/NVIDIA/JAX-Toolbox. I don't mind trying it again. Or do you mean nvcr.io/nvidia/jax:24.04-maxtext-py3 (from https://catalog.ngc.nvidia.com/orgs/nvidia/containers/jax)?

giladqm commented 4 months ago

$ docker pull nvcr.io/nvidia/jax:24.04-maxtext-py3
24.04-maxtext-py3: Pulling from nvidia/jax
no matching manifest for linux/arm64/v8 in the manifest list entries

@nouiz

giladqm commented 4 months ago

gilad@gracehopper:~$ docker pull ghcr.io/nvidia/jax:jax
jax: Pulling from nvidia/jax

works @nouiz

giladqm commented 4 months ago

gilad@gracehopper:~$ docker run -it --gpus all ghcr.io/nvidia/jax:jax

==========
== CUDA ==
==========

CUDA Version 12.4.1

Container image Copyright (c) 2016-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

This container image and its contents are governed by the NVIDIA Deep Learning Container License. By pulling and using the container, you accept the terms and conditions of this license: https://developer.nvidia.com/ngc/nvidia-deep-learning-container-license

A copy of this license is made available in this container at /NGC-DL-CONTAINER-LICENSE for your convenience.

WARNING: Your shm is currenly less than 1GB. This may cause SIGBUS errors. To avoid this problem, you can manually set the shm size in docker with:

docker run ... --shm-size=1g ...

root@f85914843395:/# nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Mar_28_02:24:28_PDT_2024
Cuda compilation tools, release 12.4, V12.4.131
Build cuda_12.4.r12.4/compiler.34097967_0
root@f85914843395:/# nvidia-smi
Tue May 21 17:27:31 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GH200 480GB             On  |   00000009:01:00.0 Off |                   On |
| N/A   23C    P0             62W /  900W |       6MiB /  97871MiB |     N/A      Default |
|                                         |                        |              Enabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| MIG devices:                                                                            |
+------------------+----------------------------------+-----------+-----------------------+
| GPU  GI  CI  MIG |                     Memory-Usage |        Vol|      Shared           |
|      ID  ID  Dev |                       BAR1-Usage | SM     Unc| CE ENC DEC OFA JPG    |
|                  |                                  |        ECC|                       |
|==================+==================================+===========+=======================|
|  No MIG devices found                                                                   |
+-----------------------------------------------------------------------------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
root@f85914843395:/#

giladqm commented 4 months ago

@nouiz I still get:

root@f85914843395:~# python jax_program.py
2024-05-21 18:06:49.257055: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_SYSTEM_NOT_READY: system not yet initialized
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py", line 679, in backends
    backend = _init_backend(platform)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py", line 761, in _init_backend
    backend = registration.factory()
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py", line 509, in factory
    return xla_client.make_c_api_client(plugin_name, options, None)
  File "/usr/local/lib/python3.10/dist-packages/jaxlib/xla_client.py", line 190, in make_c_api_client
    return _xla.get_c_api_client(plugin_name, options, distributed_client)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: No visible GPU devices.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/root/jax_program.py", line 23, in <module>
    main()
  File "/root/jax_program.py", line 9, in main
    gpu_device = jax.devices("gpu")[0]  # Use the first GPU
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py", line 872, in devices
    return get_backend(backend).devices()
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py", line 806, in get_backend
    return _get_backend_uncached(platform)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py", line 786, in _get_backend_uncached
    bs = backends()
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py", line 695, in backends
    raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'cuda': FAILED_PRECONDITION: No visible GPU devices. (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

nouiz commented 4 months ago

From the output of nvidia-smi, the issue seems to be that MIG is enabled but no MIG instance has been created. If so, that would make all software fail on that node. Can you ask your admins how they set up MIG and how to have a MIG instance created?
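
The condition described here (MIG mode on, no MIG instance) can be checked from text output alone. The sketch below is a hypothetical helper, not an NVIDIA or JAX API. It assumes the MIG mode string comes from `nvidia-smi --query-gpu=mig.mode.current --format=csv,noheader` and that the listing comes from `nvidia-smi -L`, where physical GPUs appear as `GPU <n>: ...` lines and MIG devices as indented `MIG ...` lines.

```python
def count_devices(listing):
    """Count physical GPUs and MIG devices in `nvidia-smi -L` style output."""
    gpus = 0
    migs = 0
    for line in listing.splitlines():
        stripped = line.strip()
        if stripped.startswith("GPU "):
            gpus += 1
        elif stripped.startswith("MIG "):
            migs += 1
    return gpus, migs


def mig_enabled_without_instances(mig_mode, listing):
    """True when MIG mode is enabled but the listing shows no MIG device."""
    gpus, migs = count_devices(listing)
    return mig_mode.strip().lower() == "enabled" and gpus > 0 and migs == 0
```

When this returns `True`, CUDA applications have no usable device until an administrator either creates a MIG instance or disables MIG mode.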

giladqm commented 4 months ago

@nouiz you were right! My colleague fixed the issue:

(base) nikola@gracehopper:~$  sudo nvidia-smi mig -lgi
+-------------------------------------------------------+
| GPU instances:                                        |
| GPU   Name             Profile  Instance   Placement  |
|                          ID       ID       Start:Size |
|=======================================================|
|   0  MIG 7g.96gb          0        0          0:8     |
+-------------------------------------------------------+
(base) nikola@gracehopper:~$ echo $CUDA_VISIBLE_DEVICES

(base) nikola@gracehopper:~$ nvidia-smi -L
GPU 0: NVIDIA GH200 480GB (UUID: GPU-d8731c65-c898-919e-74c9-286b27400dac)
  MIG 7g.96gb     Device  0: (UUID: MIG-7baedeb1-c0d7-53ba-9926-2e341a42b470)

But now when I run the following code:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Specify the index of the GPU you want to use

import jax
import jax.numpy as jnp

# Let's define a simple matrix multiplication function
def matmul_on_gpu(a, b):
    return jnp.dot(a, b)

# Main function to demonstrate GPU acceleration
def main():
    # Create some random matrices
    size = 1000
    a = jax.random.normal(jax.random.PRNGKey(0), (size, size))
    b = jax.random.normal(jax.random.PRNGKey(1), (size, size))

    # Run matrix multiplication on GPU
    result = matmul_on_gpu(a, b)

    # Print the result
    print("Result of matrix multiplication:")
    print(result)

if __name__ == "__main__":
    main()

I get the following error:

root@ce139f7a5d68:~# python jax_program.py 
2024-05-22 19:00:07.301687: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:474] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2024-05-22 19:00:07.301822: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:478] Memory usage: 100780867584 bytes free, 101468602368 bytes total.
2024-05-22 19:00:07.302174: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:474] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2024-05-22 19:00:07.302277: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:478] Memory usage: 100780867584 bytes free, 101468602368 bytes total.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/root/jax_program.py", line 26, in <module>
    main()
  File "/root/jax_program.py", line 15, in main
    a = jax.random.normal(jax.random.PRNGKey(0), (size, size))
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/random.py", line 240, in PRNGKey
    return _return_prng_keys(True, _key('PRNGKey', seed, impl))
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/random.py", line 202, in _key
    return prng.random_seed(seed, impl=impl)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/prng.py", line 595, in random_seed
    seeds_arr = jnp.asarray(np.int64(seeds))
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 2217, in asarray
    return array(a, dtype=dtype, copy=bool(copy), order=order)  # type: ignore
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 2172, in array
    out_array: Array = lax_internal._convert_element_type(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/lax/lax.py", line 560, in _convert_element_type
    return convert_element_type_p.bind(operand, new_dtype=new_dtype,
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 444, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 447, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 935, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

@nouiz is this the cuDNN problem you mentioned? What can I do to check it?

giladqm commented 4 months ago

Yes I have cuDNN version 9.1.0.

hawkinsp commented 4 months ago

Yes, you need to downgrade to CUDNN 8.9 for now. JAX doesn't yet release with CUDNN 9.
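
If cuDNN was installed through pip, its major version can be read from package metadata before deciding whether to downgrade. This is a sketch under that assumption: the distribution name `nvidia-cudnn-cu12` is assumed to be the pip package in use, `pip_cudnn_major` is a hypothetical helper, and a cuDNN installed by the system package manager or baked into a container will not be detected this way.

```python
from importlib import metadata


def major_version(version):
    """Return the leading integer of a dotted version string."""
    return int(version.split(".")[0])


def pip_cudnn_major(dist_name="nvidia-cudnn-cu12"):
    """Major version of the pip-installed cuDNN wheel, or None if not installed."""
    try:
        return major_version(metadata.version(dist_name))
    except metadata.PackageNotFoundError:
        return None


# Example: a result of 9 means cuDNN 9 is installed via pip, while the
# comment above says the released JAX wheels expect cuDNN 8.9.
```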

nouiz commented 4 months ago

Before downgrading: which container are you using, and how was JAX installed? If you use the JAX-Toolbox jax container, you have a good combination of JAX (nightly), cuDNN (9.1.1), and CUDA 12.4.1.

Did you try the JAX-Toolbox container without setting CUDA_VISIBLE_DEVICES? Why are you setting it? If there is only one GPU, JAX will normally just find and use it, so you don't need to set it. MIG devices aren't listed the same way as normal GPUs. Also, you can't do multi-GPU across MIG instances.
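
To test the suggestion of not setting CUDA_VISIBLE_DEVICES, the script can be launched with that variable removed from the environment. A minimal sketch; `env_without_cuda_visible_devices` is a hypothetical helper name and `jax_program.py` is only an example script name.

```python
import os


def env_without_cuda_visible_devices(base_env=None):
    """Return a copy of the environment with CUDA_VISIBLE_DEVICES removed."""
    env = dict(os.environ if base_env is None else base_env)
    env.pop("CUDA_VISIBLE_DEVICES", None)  # let CUDA enumerate devices itself
    return env


# Assumed usage: run the script in a child process without the variable.
# import subprocess, sys
# subprocess.run([sys.executable, "jax_program.py"],
#                env=env_without_cuda_visible_devices(), check=False)
```

Note that a script which sets `os.environ["CUDA_VISIBLE_DEVICES"]` itself, as the snippets in this thread do, overrides this; that line would need to be removed from the script as well.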