Open giladqm opened 4 months ago
Can you share the compilation errors you're getting in external/xla/xla/stream_executor/gpu/redzone_allocator_kernel_cuda.cc?
namespace stream_executor {
// Maintains a cache of pointers to loaded kernels
template
static absl::Mutex kernel_ptr_cache_mutex(absl::kConstInit);
static auto& kernel_ptr_cache ABSL_GUARDED_BY(kernel_ptr_cache_mutex) =
*new absl::node_hash_map<KernelPtrCacheKey, TypedKernel
CHECK(it != kernel_ptr_cache.end()); return &it->second; } // PTX blob for the function which checks that every byte in // input_buffer (length is buffer_length) is equal to redzone_pattern. // // On mismatch, increment the counter pointed to by out_mismatch_cnt_ptr. // // Generated from: // global void redzone_checker(unsigned char input_buffer, // unsigned char redzone_pattern, // unsigned long long buffer_length, // int out_mismatched_ptr) { // unsigned long long idx = threadIdx.x + blockIdx.x blockDim.x; // if (idx >= buffer_length) return; // if (input_buffer[idx] != redzone_pattern) atomicAdd(out_mismatched_ptr, 1); // } // // Code must compile for the oldest GPU XLA may be compiled for. static const char redzone_checker_ptx = R"( .version 4.2 .target sm_30 .address_size 64
.visible .entry redzone_checker( .param .u64 input_buffer, .param .u8 redzone_pattern, .param .u64 buffer_length, .param .u64 out_mismatch_cnt_ptr ) { .reg .pred %p<3>; .reg .b16 %rs<3>; .reg .b32 %r<6>; .reg .b64 %rd<8>;
ld.param.u64 %rd6, [buffer_length]; ld.param.u64 %rd4, [input_buffer]; cvta.to.global.u64 %rd2, %rd4; add.s64 %rd7, %rd2, %rd3; ld.global.u8 %rs2, [%rd7]; setp.eq.s16 %p2, %rs2, %rs1; @%p2 bra LBB6_3; ld.param.u64 %rd5, [out_mismatch_cnt_ptr]; ld.param.u8 %rs1, [redzone_pattern]; ld.param.u64 %rd4, [input_buffer]; cvta.to.global.u64 %rd2, %rd4; add.s64 %rd7, %rd2, %rd3; ld.global.u8 %rs2, [%rd7]; setp.eq.s16 %p2, %rs2, %rs1; @%p2 bra LBB6_3; ld.param.u64 %rd5, [out_mismatch_cnt_ptr]; cvta.to.global.u64 %rd1, %rd5; atom.global.add.u32 %r5, [%rd1], 1; LBB6_3: ret; } )";
absl::StatusOr<const ComparisonKernel> GetComparisonKernel(
StreamExecutor executor, GpuAsmOpts gpu_asm_opts) {
absl::Span
return LoadKernelOrGetPtr<DeviceMemory
This snippet doesn't contain any compilation errors AFAICT. Can you upload the output of the compiler to a gist?
I'm sorry but I don't understand the request. Can you be more specific and include the linux terminal commands you want me to run?
The message you posted initially
Error limit reached.
100 errors detected in the compilation of "external/xla/xla/stream_executor/gpu/redzone_allocator_kernel_cuda.cc".
is usually preceded by compilation error messages, describing what went wrong while compiling jaxlib. If you upload the full output of build.py
, that would include the error messages as well.
cc : @nouiz
This is what I get now (I'm in a different Docker image now):
(base) root@8c1c1dd5a763:~/jax# python3 build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12
_ _ __ __
| | / \ \ \/ /
| |/ \ \ / | |_| / \/ \ \// \//_\
Bazel binary path: ./bazel-6.5.0-linux-arm64 Bazel version: 6.5.0 Python binary path: /root/miniconda3/bin/python3 Python version: 3.12 Use clang: no MKL-DNN enabled: yes Target CPU: aarch64 Target CPU features: release CUDA enabled: yes NCCL enabled: yes ROCm enabled: no
Building XLA and installing it in the jaxlib source tree...
./bazel-6.5.0-linux-arm64 run --verbose_failures=true //jaxlib/tools:build_wheel -- --output_path=/root/jax/dist --jaxlib_git_hash=ffdb9bb0b0755e66f55995cafa2cf0946ed66598 --cpu=aarch64 --skip_gpu_kernels
INFO: Options provided by the client:
Inherited 'common' options: --isatty=0 --terminal_columns=80
INFO: Reading rc options for 'run' from /root/jax/.bazelrc:
Inherited 'common' options: --experimental_repo_remote_exec
INFO: Reading rc options for 'run' from /root/jax/.bazelrc:
Inherited 'build' options: --nocheck_visibility --apple_platform_type=macos --macos_minimum_os=10.14 --announce_rc --define open_source_build=true --spawn_strategy=standalone --enable_platform_specific_config --experimental_cc_shared_library --define=no_aws_support=true --define=no_gcp_support=true --define=no_hdfs_support=true --define=no_kafka_support=true --define=no_ignite_support=true --define=grpc_no_ares=true --define=tsl_link_protobuf=true -c opt --config=short_logs --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir. --@xla//xla/python:enable_gpu=false
INFO: Reading rc options for 'run' from /root/jax/.jax_configure.bazelrc:
Inherited 'build' options: --strategy=Genrule=standalone --config=mkl_open_source_only --config=cuda --config=cuda_plugin --repo_env HERMETIC_PYTHON_VERSION=3.12
INFO: Found applicable config definition build:short_logs in file /root/jax/.bazelrc: --output_filter=DONT_MATCH_ANYTHING
INFO: Found applicable config definition build:mkl_open_source_only in file /root/jax/.bazelrc: --define=tensorflow_mkldnn_contraction_kernel=1
INFO: Found applicable config definition build:cuda in file /root/jax/.bazelrc: --repo_env TF_NEED_CUDA=1 --repo_env TF_NCCL_USE_STUB=1 --action_env TF_CUDA_COMPUTE_CAPABILITIES=sm_50,sm_60,sm_70,sm_80,compute_90 --crosstool_top=@local_config_cuda//crosstool:toolchain --@local_config_cuda//:enable_cuda --@xla//xla/python:enable_gpu=true --@xla//xla/python:jax_cuda_pip_rpaths=true --define=xla_python_enable_gpu=true --linkopt=-Wl,--disable-new-dtags
INFO: Found applicable config definition build:cuda_plugin in file /root/jax/.bazelrc: --@xla//xla/python:enable_gpu=false --define=xla_python_enable_gpu=false
INFO: Found applicable config definition build:linux in file /root/jax/.bazelrc: --config=posix --copt=-Wno-unknown-warning-option --copt=-Wno-stringop-truncation --copt=-Wno-array-parameter
INFO: Found applicable config definition build:posix in file /root/jax/.bazelrc: --copt=-fvisibility=hidden --copt=-Wno-sign-compare --cxxopt=-std=c++17 --host_cxxopt=-std=c++17
Loading:
INFO: Repository local_config_cuda instantiated at:
/root/jax/WORKSPACE:45:15: in
Okay, from this it looks like your CUDA installation is missing development headers:
Could not find any cuda.h matching version '' in any subdirectory:
JAX-Toolbox has nightly JAX container for ARM: https://github.com/NVIDIA/JAX-Toolbox
For example: ghcr.io/nvidia/jax:jax
for the latest nightly.
If you want to build JAX yourself, this container already contains CUDA:
docker pull nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04
I almost always use those two containers for JAX development.
@nouiz thanks, I tried those two options without success. I'm using a GH200 and I'm trying to use jax with the GPU, but it always fails.
I'll also note that we (JAX) release CUDA arm wheels on pypi which should just work on GH200. Try:
pip install jax jaxlib jax-cuda12-plugin jax-cuda12-pjrt
(The more usual pip install jax[cuda12]
won't work because NVIDIA doesn't release ARM wheels of CUDA, last I checked.)
We released ARM wheels last week, but they aren't tested. So let's try jax[cuda12]
docker run -it --gpus all ubuntu
apt-get update; apt-get install -y python3-pip python3.12-venv
python3 -m venv path/to/venv
source path/to/venv/bin/activate
pip install jax[cuda12] # works
python3 -c "import jax; jax.numpy.zeros(3)" # fail with cudnn init error.
This installed cuDNN 9.1.1. To my knowledge, cuDNN 8 isn't supported on Grace Hopper. @hawkinsp Are the JAX wheels for ARM also built with cuDNN 8? Any idea when a cuDNN 9 version can be created?
Thanks for the update, I'll check it right away
@hawkinsp After executing
pip install jax jaxlib jax-cuda12-plugin jax-cuda12-pjrt
I'm trying to run the following code:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Specify the index of the GPU you want to use
import jax
import jax.numpy as jnp
def main():
# Explicitly place arrays on GPU using jax.device_put
gpu_device = jax.devices("gpu")[0] # Use the first GPU
a = jax.random.normal(jax.random.PRNGKey(0), (size, size))
b = jax.random.normal(jax.random.PRNGKey(1), (size, size))
a_gpu = jax.device_put(a, device=gpu_device)
b_gpu = jax.device_put(b, device=gpu_device)
# Run matrix multiplication on GPU
result = jnp.dot(a_gpu, b_gpu)
# Print the result
print("Result of matrix multiplication:")
print(result)
if __name__ == "__main__":
main()
but get
RuntimeError: Unable to initialize backend 'cuda': FAILED_PRECONDITION: No visible GPU devices. (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
@nouiz does this mean I can't yet run jax with GPU acceleration on GH200?
It is possible. Can you give the exact command line you use to start the docker container? What is the output of nvidia-smi in it? Can you try the jax container we provide? It should work and won't ask you to compile JAX.
Last week I used the docker jax:jax here: https://github.com/NVIDIA/JAX-Toolbox I don't mind trying it again. or do u mean nvcr.io/nvidia/jax:24.04-maxtext-py3 (from here: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/jax)?
Last week I used the docker jax:jax here: https://github.com/NVIDIA/JAX-Toolbox I don't mind trying it again. or do u mean nvcr.io/nvidia/jax:24.04-maxtext-py3 (from here: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/jax)?
@nouiz
$ docker pull nvcr.io/nvidia/jax:24.04-maxtext-py3
24.04-maxtext-py3: Pulling from nvidia/jax
no matching manifest for linux/arm64/v8 in the manifest list entries
@nouiz
gilad@gracehopper:~$ docker pull ghcr.io/nvidia/jax:jax
jax: Pulling from nvidia/jax
works @nouiz
gilad@gracehopper:~$ docker pull ghcr.io/nvidia/jax:jax jax: Pulling from nvidia/jax
works @nouiz
gilad@gracehopper:~$ docker run -it --gpus all ghcr.io/nvidia/jax:jax
CUDA Version 12.4.1
Container image Copyright (c) 2016-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
This container image and its contents are governed by the NVIDIA Deep Learning Container License. By pulling and using the container, you accept the terms and conditions of this license: https://developer.nvidia.com/ngc/nvidia-deep-learning-container-license
A copy of this license is made available in this container at /NGC-DL-CONTAINER-LICENSE for your convenience.
WARNING: Your shm is currenly less than 1GB. This may cause SIGBUS errors. To avoid this problem, you can manually set the shm size in docker with:
docker run ... --shm-size=1g ...
root@f85914843395:/# nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Mar_28_02:24:28_PDT_2024
Cuda compilation tools, release 12.4, V12.4.131
Build cuda_12.4.r12.4/compiler.34097967_0
root@f85914843395:/# nvidia-smi
Tue May 21 17:27:31 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15 Driver Version: 550.54.15 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GH200 480GB On | 00000009:01:00.0 Off | On |
| N/A 23C P0 62W / 900W | 6MiB / 97871MiB | N/A Default |
| | | Enabled |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+ | MIG devices: | +------------------+----------------------------------+-----------+-----------------------+ | GPU GI CI MIG | Memory-Usage | Vol| Shared | | ID ID Dev | BAR1-Usage | SM Unc| CE ENC DEC OFA JPG | | | | ECC| | |==================+==================================+===========+=======================| | No MIG devices found | +-----------------------------------------------------------------------------------------+
+-----------------------------------------------------------------------------------------+ | Processes: | | GPU GI CI PID Type Process name GPU Memory | | ID ID Usage | |=========================================================================================| | No running processes found | +-----------------------------------------------------------------------------------------+ root@f85914843395:/#
@nouiz I still get: root@f85914843395:~# python jax_program.py 2024-05-21 18:06:49.257055: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:282] failed call to cuInit: CUDA_ERROR_SYSTEM_NOT_READY: system not yet initialized Traceback (most recent call last): File "/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py", line 679, in backends backend = _init_backend(platform) File "/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py", line 761, in _init_backend backend = registration.factory() File "/usr/local/lib/python3.10/dist-packages/jax/_src/xla_bridge.py", line 509, in factory return xla_client.make_c_api_client(plugin_name, options, None) File "/usr/local/lib/python3.10/dist-packages/jaxlib/xla_client.py", line 190, in make_c_api_client return _xla.get_c_api_client(plugin_name, options, distributed_client) jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: No visible GPU devices.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/root/jax_program.py", line 23, in
From the output of nvidia-smi, the issue seems to be that MIG is enabled but no MIG "instance" has been created. If that is the case, it would make all software fail on that node. Can you ask your admins how they set up MIG and how to have a MIG instance created?
@nouiz you were right! My colleague fixed the issue:
(base) nikola@gracehopper:~$ sudo nvidia-smi mig -lgi
+-------------------------------------------------------+
| GPU instances: |
| GPU Name Profile Instance Placement |
| ID ID Start:Size |
|=======================================================|
| 0 MIG 7g.96gb 0 0 0:8 |
+-------------------------------------------------------+
(base) nikola@gracehopper:~$ echo $CUDA_VISIBLE_DEVICES
(base) nikola@gracehopper:~$ nvidia-smi -L
GPU 0: NVIDIA GH200 480GB (UUID: GPU-d8731c65-c898-919e-74c9-286b27400dac)
MIG 7g.96gb Device 0: (UUID: MIG-7baedeb1-c0d7-53ba-9926-2e341a42b470)
But now when I run the following code:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Specify the index of the GPU you want to use
import jax
import jax.numpy as jnp
# Let's define a simple matrix multiplication function
def matmul_on_gpu(a, b):
return jnp.dot(a, b)
# Main function to demonstrate GPU acceleration
def main():
# Create some random matrices
size = 1000
a = jax.random.normal(jax.random.PRNGKey(0), (size, size))
b = jax.random.normal(jax.random.PRNGKey(1), (size, size))
# Run matrix multiplication on GPU
result = matmul_on_gpu(a, b)
# Print the result
print("Result of matrix multiplication:")
print(result)
if __name__ == "__main__":
main()
I get the following error:
root@ce139f7a5d68:~# python jax_program.py
2024-05-22 19:00:07.301687: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:474] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2024-05-22 19:00:07.301822: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:478] Memory usage: 100780867584 bytes free, 101468602368 bytes total.
2024-05-22 19:00:07.302174: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:474] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2024-05-22 19:00:07.302277: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:478] Memory usage: 100780867584 bytes free, 101468602368 bytes total.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/root/jax_program.py", line 26, in <module>
main()
File "/root/jax_program.py", line 15, in main
a = jax.random.normal(jax.random.PRNGKey(0), (size, size))
File "/usr/local/lib/python3.10/dist-packages/jax/_src/random.py", line 240, in PRNGKey
return _return_prng_keys(True, _key('PRNGKey', seed, impl))
File "/usr/local/lib/python3.10/dist-packages/jax/_src/random.py", line 202, in _key
return prng.random_seed(seed, impl=impl)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/prng.py", line 595, in random_seed
seeds_arr = jnp.asarray(np.int64(seeds))
File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 2217, in asarray
return array(a, dtype=dtype, copy=bool(copy), order=order) # type: ignore
File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 2172, in array
out_array: Array = lax_internal._convert_element_type(
File "/usr/local/lib/python3.10/dist-packages/jax/_src/lax/lax.py", line 560, in _convert_element_type
return convert_element_type_p.bind(operand, new_dtype=new_dtype,
File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 444, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 447, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 935, in process_primitive
return primitive.impl(*tracers, **params)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py", line 87, in apply_primitive
outs = fun(*args)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
@nouiz is this the cuDNN problem you mentioned? what can I do to check it?
Yes I have cuDNN version 9.1.0.
Yes, you need to downgrade to CUDNN 8.9 for now. JAX doesn't yet release with CUDNN 9.
Before downgrading, which container are you using, and how was JAX installed? If you use the JAX-Toolbox jax container, you have a good combination of JAX (nightly), cuDNN (9.1.1), and CUDA 12.4.1.
Did you try with the JAX-toolbox container without setting CUDA_VISIBLE_DEVICES? Why do you try to set it? If there is only 1 GPU, normally JAX will just find and use it. So you don't need to set it. The MIG listing isn't the same as normal GPU. Also, you can't do multi-gpu across MIGs.
Description
I'm trying to run some code utilizing my GH200 without success. Unable to build jaxlib for my GPU.
System info (python version, jaxlib version, accelerator, etc.)
root@470c73980644:~/jax# nvidia-smi Sun May 19 12:13:00 2024
+-----------------------------------------------------------------------------------------+ | NVIDIA-SMI 550.54.15 Driver Version: 550.54.15 CUDA Version: 12.4 | |-----------------------------------------+------------------------+----------------------+ | GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | | Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | | | | MIG M. | |=========================================+========================+======================| | 0 NVIDIA GH200 480GB On | 00000009:01:00.0 Off | On | | N/A 23C P0 62W / 900W | 5MiB / 97871MiB | N/A Default | | | | Enabled | +-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+ | MIG devices: | +------------------+----------------------------------+-----------+-----------------------+ | GPU GI CI MIG | Memory-Usage | Vol| Shared | | ID ID Dev | BAR1-Usage | SM Unc| CE ENC DEC OFA JPG | | | | ECC| | |==================+==================================+===========+=======================| | No MIG devices found | +-----------------------------------------------------------------------------------------+
+-----------------------------------------------------------------------------------------+ | Processes: | | GPU GI CI PID Type Process name GPU Memory | | ID ID Usage | |=========================================================================================| | No running processes found | +-----------------------------------------------------------------------------------------+ root@470c73980644:~/jax# nvcc --version nvcc: NVIDIA (R) Cuda compiler driver Copyright (c) 2005-2024 NVIDIA Corporation Built on Thu_Mar_28_02:24:28_PDT_2024 Cuda compilation tools, release 12.4, V12.4.131 Build cuda_12.4.r12.4/compiler.34097967_0
the error i get: Error limit reached. 100 errors detected in the compilation of "external/xla/xla/stream_executor/gpu/redzone_allocator_kernel_cuda.cc". Compilation terminated. Target //jaxlib/tools:build_gpu_plugin_wheel failed to build INFO: Elapsed time: 7.262s, Critical Path: 4.88s INFO: 73 processes: 73 internal. FAILED: Build did NOT complete successfully ERROR: Build failed. Not running target Traceback (most recent call last): File "/root/jax/build/build.py", line 733, in
main()
File "/root/jax/build/build.py", line 727, in main
shell(build_pjrt_plugin_command)
File "/root/jax/build/build.py", line 45, in shell
output = subprocess.check_output(cmd)
File "/usr/lib/python3.10/subprocess.py", line 421, in check_output
return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
File "/usr/lib/python3.10/subprocess.py", line 526, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['/usr/local/bin/bazel', 'run', '--verbose_failures=true', '//jaxlib/tools:build_gpu_plugin_wheel', '--', '--output_path=/root/jax/dist', '--jaxlib_git_hash=45a7c22e932fee257016bf0da1022be146ed6095', '--cpu=aarch64', '--cuda_version=12']' returned non-zero exit status 1.