PhilipVinc opened this issue 2 months ago (status: Open)
These are typically issues with the hipblaslt autotuning not being able to run its autotuning kernels correctly. I have been seeing a lot of these lately and will be opening an issue to the XLA folks to look into them more.
If you could gather some information for me to forward on to them that would be a big help.
1) Can you confirm that the issue is only present with the x64 flag set?
2) What card model were you running this against? (MI100 / MI250 / etc)
3) Can you try running the same scenario but disabling the hipblaslt autotuning by doing `export XLA_FLAGS="--xla_gpu_autotune_level=0"`, and report if it still fails or not? (A Python sketch of the same check follows this list.)
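For question 3, here is a minimal Python sketch of the same check (the shapes are placeholders for the failing workload, and `XLA_FLAGS` must be set before `jax` is first imported):

```python
import os

# Disable XLA's GPU autotuning (and with it the hipblaslt autotuning
# kernels). This must happen before importing jax, since XLA_FLAGS is
# read when the GPU backend initializes.
os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # the reported failure needs x64

# Placeholder workload: a float64 matmul like the one in the report.
a = jnp.ones((512, 512))
b = jnp.ones((512, 512))
print(jnp.matmul(b, a))
```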
2. `rocm-smi` detects them as 8 AMD INSTINCT MI200 devices.
2b. I'm running with ROCm 6.0.0 because that's what the IT technicians provide on the machine (I cannot use virtualisation).
3. `export XLA_FLAGS="--xla_gpu_autotune_level=0"` works correctly with no failure.

Was able to replicate this on an MI300 system.
```
python3 matmul.py
ROCm path: /opt/rocm-6.0.0/lib
jax_enable_x64: True
jax version: 0.4.31
jaxlib version: 0.4.31
jax:    0.4.31
jaxlib: 0.4.31
numpy:  2.1.1
python: 3.11.10 (main, Sep 8 2024, 14:18:29) [GCC 12.2.1 20221121 (Red Hat 12.2.1-7)]
jax.devices (8 total, 8 local): [RocmDevice(id=0) RocmDevice(id=1) ... RocmDevice(id=6) RocmDevice(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='...', release='5.15.0-119-generic', version='#129-Ubuntu SMP Fri Aug 2 19:25:20 UTC 2024', machine='x86_64')
Traceback (most recent call last):
  File "/jax-buildr/matmul.py", line 38, in <module>
    c = jax.numpy.matmul(b, a)
        ^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to launch ROCm kernel: redzone_checker with block dimensions: 1024x1x1: hipError_t(303)
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```
Don't mind the 'ubuntu' in uname; I ran it in an AlmaLinux 8 container on an Ubuntu host.
Going to do more testing tomorrow to see if it's only ROCm 6.0.x.
I am not able to reproduce it with ROCm 6.0 on the gfx90a platform (AMD INSTINCT MI200).
docker: `rocm/jax:rocm6.0.0-jax0.4.26-py3.11.0`
Cloned JAX/XLA:

```
git clone -b rocm-jaxlib-v0.4.31 https://github.com/ROCm/jax.git
git clone -b rocm-jaxlib-v0.4.31 https://github.com/ROCm/xla.git
```
Built/installed JAX locally using the command below:

```
rm -rf dist; python3 -m pip uninstall jax jaxlib jax-rocm60-pjrt jax-rocm60-plugin -y; python3 ./build/build.py --use_clang=false --enable_rocm --build_gpu_plugin --gpu_plugin_rocm_version=60 --rocm_amdgpu_targets=gfx90a --bazel_options=--override_repository=xla=/workspaces/jax_xla/rocm-jaxlib-v0.4.31/xla --rocm_path=/opt/rocm-6.0.0/ && python3 setup.py develop --user && python3 -m pip install dist/*.whl
```
Installed JAX versions:

```
jax                0.4.31.dev20240808+a96cefdc0   /workspaces/jax_xla/rocm-jaxlib-v0.4.31/jax
jax-rocm60-pjrt    0.4.31.dev20240913
jax-rocm60-plugin  0.4.31.dev20240913
jaxlib             0.4.31.dev20240913
```
Test script and its output:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

x = jax.numpy.ones((512, 25))
M = jax.numpy.ones((512, 512))
z = jnp.matmul(M, x)
print(z)
```

```
[[512. 512. 512. ... 512. 512. 512.]
 [512. 512. 512. ... 512. 512. 512.]
 [512. 512. 512. ... 512. 512. 512.]
 ...
 [512. 512. 512. ... 512. 512. 512.]
 [512. 512. 512. ... 512. 512. 512.]
 [512. 512. 512. ... 512. 512. 512.]]
```
Can I give you anything to help you identify the issue? Is it enough that it shows up on MI300?
A side note: on HPC systems we cannot use virtualisation/Docker (and we don't have superuser privileges), so I cannot guarantee that my setup is equivalent to yours. I load a centrally installed ROCm 6.0.0 library, so I can give you any detail you might want about it. But Docker is intentionally blocked.
I was also able to reproduce this problem on an MI300 system with a ROCm 6.0.0 container (while it runs fine under ROCm 6.0.2 and above). In fact, on the rocm-jaxlib-v0.4.31 XLA/JAX branch it also fails without the jax_enable_x64 flag when one disables triton_gemm with `--xla_gpu_enable_triton_gemm=false`.
We are investigating the issue. As a workaround, one can still use autotuning with the redzone checker disabled via `--xla_gpu_autotune_level=3`.
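For example, from Python (a sketch; appending preserves any `XLA_FLAGS` already set in the environment, and the variable must be set before `jax` is imported):

```python
import os

# Workaround sketch: keep autotuning but skip the redzone checker by
# lowering the autotune level to 3, as described above. Append rather
# than overwrite so any flags already in the environment are preserved.
existing = os.environ.get("XLA_FLAGS", "")
os.environ["XLA_FLAGS"] = (existing + " --xla_gpu_autotune_level=3").strip()

import jax  # flags are read when the GPU backend initializes
```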
Description
I'm running master jax/jaxlib custom-built for ROCm following the instructions online (because no such wheels are available). However, I'm fairly sure this is not caused by my custom wheels but is an issue within XLA/JAX.
The MWE is:
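A sketch of the MWE, reconstructed from the traceback above (the array shapes and the extra prints are assumptions; the x64 flag and the `c = jax.numpy.matmul(b, a)` line come from the error output):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # the crash only happens with x64 on

print("jax_enable_x64:", jax.config.jax_enable_x64)
jax.print_environment_info()

# Shapes are assumptions; any float64 matmul hits the failing path.
a = jnp.ones((512, 512))
b = jnp.ones((512, 512))
c = jax.numpy.matmul(b, a)  # crashes here with the redzone_checker error
print(c)
```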
It crashes at the multiplication with the redzone_checker error shown in the log above. Do note that the script works fine if I do not enable x64 mode.
System info (python version, jaxlib version, accelerator, etc.): see the `python3 matmul.py` log above.