jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[ROCM] x64 mode crashes with "redzone_checker with block dimensions: 1024x1x1: hipError_t" #23506

Open PhilipVinc opened 2 months ago

PhilipVinc commented 2 months ago

Description

I'm running master jax/jaxlib custom-built for ROCm following the instructions online (because no such wheels are available). However, I'm fairly sure this is not caused by my custom wheels but is an issue within XLA/JAX.

The MWE is

import jax
jax.config.update("jax_enable_x64", True)

x = jax.numpy.ones((512, 25))
M = jax.numpy.ones((512, 512))
M @ x

and it crashes at the multiplication with the following error. Do note that the script works fine if I do not enable x64 mode.

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/software/gaia-external/IA/jax/0.4.31/miniconda3/envs/myjax/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py", line 272, in deferring_binary_op
    return binary_op(*args)
           ^^^^^^^^^^^^^^^^
  File "/opt/software/gaia-external/IA/jax/0.4.31/miniconda3/envs/myjax/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/software/gaia-external/IA/jax/0.4.31/miniconda3/envs/myjax/lib/python3.11/site-packages/jax/_src/pjit.py", line 332, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
                                                                ^^^^^^^^^^^^^^^^^^^^
  File "/opt/software/gaia-external/IA/jax/0.4.31/miniconda3/envs/myjax/lib/python3.11/site-packages/jax/_src/pjit.py", line 190, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **p.params)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/software/gaia-external/IA/jax/0.4.31/miniconda3/envs/myjax/lib/python3.11/site-packages/jax/_src/core.py", line 2739, in bind
    return self.bind_with_trace(top_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/software/gaia-external/IA/jax/0.4.31/miniconda3/envs/myjax/lib/python3.11/site-packages/jax/_src/core.py", line 433, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/software/gaia-external/IA/jax/0.4.31/miniconda3/envs/myjax/lib/python3.11/site-packages/jax/_src/core.py", line 939, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/software/gaia-external/IA/jax/0.4.31/miniconda3/envs/myjax/lib/python3.11/site-packages/jax/_src/pjit.py", line 1730, in _pjit_call_impl
    return xc._xla.pjit(
           ^^^^^^^^^^^^^
  File "/opt/software/gaia-external/IA/jax/0.4.31/miniconda3/envs/myjax/lib/python3.11/site-packages/jax/_src/pjit.py", line 1712, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
                         ^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/software/gaia-external/IA/jax/0.4.31/miniconda3/envs/myjax/lib/python3.11/site-packages/jax/_src/pjit.py", line 1642, in _pjit_call_impl_python
    ).compile(compile_options)
      ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/software/gaia-external/IA/jax/0.4.31/miniconda3/envs/myjax/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2295, in compile
    executable = UnloadedMeshExecutable.from_hlo(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/software/gaia-external/IA/jax/0.4.31/miniconda3/envs/myjax/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2807, in from_hlo
    xla_executable = _cached_compilation(
                     ^^^^^^^^^^^^^^^^^^^^
  File "/opt/software/gaia-external/IA/jax/0.4.31/miniconda3/envs/myjax/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2621, in _cached_compilation
    xla_executable = compiler.compile_or_get_cached(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/software/gaia-external/IA/jax/0.4.31/miniconda3/envs/myjax/lib/python3.11/site-packages/jax/_src/compiler.py", line 399, in compile_or_get_cached
    return _compile_and_write_cache(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/software/gaia-external/IA/jax/0.4.31/miniconda3/envs/myjax/lib/python3.11/site-packages/jax/_src/compiler.py", line 627, in _compile_and_write_cache
    executable = backend_compile(
                 ^^^^^^^^^^^^^^^^
  File "/opt/software/gaia-external/IA/jax/0.4.31/miniconda3/envs/myjax/lib/python3.11/site-packages/jax/_src/profiler.py", line 336, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/software/gaia-external/IA/jax/0.4.31/miniconda3/envs/myjax/lib/python3.11/site-packages/jax/_src/compiler.py", line 267, in backend_compile
    return backend.compile(built_c, compile_options=options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to launch ROCm kernel: redzone_checker with block dimensions: 1024x1x1: hipError_t(303)
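
(For context: as far as I understand, the only thing x64 mode changes in this snippet is the default dtype, so the failing case is the float64 GEMM path. A quick check:)

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Under x64 mode the default dtype of jnp.ones is float64 (float32 otherwise),
# so M @ x dispatches to a double-precision GEMM.
print(jnp.ones((512, 25)).dtype)  # float64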

System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
jax:    0.4.31
jaxlib: 0.4.31.dev20240909
numpy:  2.0.2
python: 3.11.8 | packaged by conda-forge | (main, Feb 16 2024, 20:53:32) [GCC 12.3.0]
jax.devices (8 total, 8 local): [RocmDevice(id=0) RocmDevice(id=1) ... RocmDevice(id=6) RocmDevice(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='g1176', release='4.18.0-477.10.1.el8_8.x86_64', version='#1 SMP Wed Apr 5 13:35:01 EDT 2023', machine='x86_64')
mrodden commented 2 months ago

These are typically issues with the hipblaslt autotuning not being able to run its autotuning kernels correctly. I have been seeing a lot of these lately and will open an issue with the XLA folks so they can look into them further.

If you could gather some information for me to forward on to them that would be a big help.

1. Can you confirm that the issue is only present with the x64 flag set?
2. What card model were you running this against? (MI100 / MI250 / etc.)
3. Can you try running the same scenario but with hipblaslt autotuning disabled, by doing export XLA_FLAGS="--xla_gpu_autotune_level=0", and report whether it still fails? (A sketch of this follows below.)
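
For reference, a minimal sketch of (3), assuming the flag is set before JAX initializes its backend (setting os.environ before the first jax import is equivalent to the shell export):

import os

# Disable XLA GPU autotuning (and with it the hipblaslt autotuning kernels)
# before JAX initializes its backend.
os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"

import jax
jax.config.update("jax_enable_x64", True)

x = jax.numpy.ones((512, 25))
M = jax.numpy.ones((512, 512))
print((M @ x)[0, 0])  # expect 512.0 if the crash is avoided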

PhilipVinc commented 2 months ago
  1. Just tried again, and I confirm the issue is only present with the x64 flag set.
  2. It's an HPC node with 4x MI250X accelerators, which rocm-smi detects as 8 AMD INSTINCT MI200 devices. I'm running ROCm 6.0.0 because that's what the IT technicians provide on the machine (I cannot use virtualisation).
  3. Running the script above with export XLA_FLAGS="--xla_gpu_autotune_level=0" works correctly, with no failure.
mrodden commented 2 months ago

Was able to replicate this on an MI300 system.

python3 matmul.py
ROCm path: /opt/rocm-6.0.0/lib
jax_enable_x64: True
jax version: 0.4.31
jaxlib version: 0.4.31
jax:    0.4.31
jaxlib: 0.4.31
numpy:  2.1.1
python: 3.11.10 (main, Sep  8 2024, 14:18:29) [GCC 12.2.1 20221121 (Red Hat 12.2.1-7)]
jax.devices (8 total, 8 local): [RocmDevice(id=0) RocmDevice(id=1) ... RocmDevice(id=6) RocmDevice(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='...', release='5.15.0-119-generic', version='#129-Ubuntu SMP Fri Aug 2 19:25:20 UTC 2024', machine='x86_64')

Traceback (most recent call last):
  File "/jax-buildr/matmul.py", line 38, in <module>
    c = jax.numpy.matmul(b, a)
        ^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to launch ROCm kernel: redzone_checker with block dimensions: 1024x1x1: hipError_t(303)
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Don't mind the ubuntu in 'uname'; I ran it in an AlmaLinux 8 container on an Ubuntu host.

Going to do more testing tomorrow to see if it's only ROCm 6.0.x.

zahiqbal commented 2 months ago

I am not able to reproduce it on ROCm 6.0 on a gfx90a platform (AMD INSTINCT MI200).

docker: rocm/jax:rocm6.0.0-jax0.4.26-py3.11.0

Cloned JAX/XLA:

git clone -b rocm-jaxlib-v0.4.31 https://github.com/ROCm/jax.git
git clone -b rocm-jaxlib-v0.4.31 https://github.com/ROCm/xla.git

Built and installed JAX locally using the command below:

rm -rf dist; python3 -m pip uninstall jax jaxlib jax-rocm60-pjrt jax-rocm60-plugin -y; python3 ./build/build.py --use_clang=false --enable_rocm --build_gpu_plugin --gpu_plugin_rocm_version=60 --rocm_amdgpu_targets=gfx90a --bazel_options=--override_repository=xla=/workspaces/jax_xla/rocm-jaxlib-v0.4.31/xla --rocm_path=/opt/rocm-6.0.0/ && python3 setup.py develop --user && python3 -m pip install dist/*.whl

Installed JAX

jax                0.4.31.dev20240808+a96cefdc0   /workspaces/jax_xla/rocm-jaxlib-v0.4.31/jax
jax-rocm60-pjrt    0.4.31.dev20240913
jax-rocm60-plugin  0.4.31.dev20240913
jaxlib             0.4.31.dev20240913

import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
x = jax.numpy.ones((512, 25))
M = jax.numpy.ones((512, 512))
z = jnp.matmul(M, x)
print(z)

[[512. 512. 512. ... 512. 512. 512.]
 [512. 512. 512. ... 512. 512. 512.]
 [512. 512. 512. ... 512. 512. 512.]
 ...
 [512. 512. 512. ... 512. 512. 512.]
 [512. 512. 512. ... 512. 512. 512.]
 [512. 512. 512. ... 512. 512. 512.]]

PhilipVinc commented 2 months ago

Can I give you anything to help you identify the issue? Is it enough that it shows up on MI300?

A side note: on HPC systems we cannot use virtualisation/Docker (it is intentionally blocked, and we don't have super-user privileges), so I cannot guarantee that my setup is equivalent to yours. I load a centrally installed ROCm 6.0.0 library, so I can give you any detail you might want about it.

pemeliya commented 1 month ago

I was also able to reproduce this problem on an MI300 system with a ROCm 6.0.0 container (while it runs fine under ROCm 6.0.2 and above). Notably, on the rocm-jaxlib-v0.4.31 XLA/JAX branch it also fails without setting the jax_enable_x64 flag, if one disables triton_gemm with --xla_gpu_enable_triton_gemm=false.

We are investigating the issue. As a workaround, one can still use autotuning with the redzone checker disabled, via --xla_gpu_autotune_level=3 (sketched below).
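
A minimal sketch of applying that workaround, assuming the flag is set before JAX starts (a shell export works the same way):

import os

# Autotune level 3 keeps autotuning enabled but runs it without the redzone
# checker whose kernel launch fails here; level 0 would disable autotuning
# entirely.
os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=3"

import jax
jax.config.update("jax_enable_x64", True)
print(jax.numpy.ones((512, 512)) @ jax.numpy.ones((512, 25)))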