jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[ROCM] Multi-device reduction causes segfault #23565

Open PhilipVinc opened 1 week ago

PhilipVinc commented 1 week ago

Description

Running a simple program with a global reduction causes a segfault.

The MWE

import jax
print("devices:", jax.devices())
print("local:", jax.local_devices())

x = jax.numpy.ones((120, 10))
y = jax.lax.with_sharding_constraint(
    x, jax.sharding.PositionalSharding(jax.devices()).reshape(-1, 1))
print(y.sharding)
y.sum()

I'm running on a single node of the CINES Ad Astra HPC system, which has 4x MI250X (seen as 8x MI200). I'm using ROCm 6.0 and a custom-built version of JAX, because there are no publicly available wheels.

Running the program above leads to a segfault with no further information:

(myjax) [cad14908] fvicentini@g1265:~/prove$ /opt/software/gaia-external/IA/jax/0.4.31/miniconda3/envs/myjax/bin/python crash.py
devices: [RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3), RocmDevice(id=4), RocmDevice(id=5), RocmDevice(id=6), RocmDevice(id=7)]
local: [RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3), RocmDevice(id=4), RocmDevice(id=5), RocmDevice(id=6), RocmDevice(id=7)]
PositionalSharding([[{GPU 0}]
                    [{GPU 1}]
                    [{GPU 2}]
                    [{GPU 3}]
                    [{GPU 4}]
                    [{GPU 5}]
                    [{GPU 6}]
                    [{GPU 7}]], shape=(8, 1))
Segmentation fault (core dumped)
(myjax) [cad14908] fvicentini@g1265:~/prove$

So I ran it under gdb to get a stack trace:

                    [{GPU 7}]], shape=(8, 1))
[Detaching after vfork from child process 1287214]
[New Thread 0x14e23afff700 (LWP 1287227)]
[New Thread 0x14e23adfe700 (LWP 1287228)]
[New Thread 0x14e23abfd700 (LWP 1287229)]
[New Thread 0x14e23a9fc700 (LWP 1287230)]
[New Thread 0x14e23a7fb700 (LWP 1287231)]

Thread 480 "python" received signal SIGSEGV, Segmentation fault.
[Switching to Thread 0x14e23abfd700 (LWP 1287229)]
0x000015555204c490 in hip::FatBinaryInfo::BuildProgram(int) () from /opt/rocm-6.0.0/lib/libamdhip64.so.6
(gdb) bt
#0  0x000015555204c490 in hip::FatBinaryInfo::BuildProgram(int) () from /opt/rocm-6.0.0/lib/libamdhip64.so.6
#1  0x000015555204fd2e in hip::Function::getStatFuncAttr(hipFuncAttributes*, int) () from /opt/rocm-6.0.0/lib/libamdhip64.so.6
#2  0x000015555200d6a7 in hip::StatCO::getStatFuncAttr(hipFuncAttributes*, void const*, int) () from /opt/rocm-6.0.0/lib/libamdhip64.so.6
#3  0x0000155552140341 in hipFuncGetAttributes () from /opt/rocm-6.0.0/lib/libamdhip64.so.6
#4  0x000015545264627f in ncclInitKernelsForDevice(int, unsigned long*) () from /opt/rocm-6.0.0/lib/librccl.so.1
#5  0x000015545267c298 in ncclCommInitRankFunc(ncclAsyncJob*) () from /opt/rocm-6.0.0/lib/librccl.so.1
#6  0x0000155452674f37 in ncclAsyncJobMain(void*) () from /opt/rocm-6.0.0/lib/librccl.so.1
#7  0x000015555510d1ca in start_thread () from /lib64/libpthread.so.0
#8  0x00001555545efe73 in clone () from /lib64/libc.so.6
(gdb) quit

If you need anything else, let me know.

System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
jax:    0.4.31
jaxlib: 0.4.31.dev20240909
numpy:  2.0.2
python: 3.11.8 | packaged by conda-forge | (main, Feb 16 2024, 20:53:32) [GCC 12.3.0]
jax.devices (8 total, 8 local): [RocmDevice(id=0) RocmDevice(id=1) ... RocmDevice(id=6) RocmDevice(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='g1265', release='4.18.0-477.10.1.el8_8.x86_64', version='#1 SMP Wed Apr 5 13:35:01 EDT 2023', machine='x86_64')
zahiqbal commented 6 days ago

I am not able to reproduce the segfault as reported. When I aligned the data to the number of devices, the sharding appeared to work fine. However, if the data dimension is not divisible by the number of devices, it errors out.

Here are the details of my experiments.

System Info:

>>> import jax; jax.print_environment_info()
jax:    0.4.31.dev20240808+a96cefdc0
jaxlib: 0.4.31.dev20240913
numpy:  2.1.1
python: 3.11.0 (main, Apr  9 2024, 03:49:51) [GCC 9.4.0]
jax.devices (16 total, 16 local): [RocmDevice(id=0) RocmDevice(id=1) ... RocmDevice(id=14) RocmDevice(id=15)]
process_count: 1
platform: uname_result(system='Linux', node='hyd-7c-ZT13-03', release='5.15.0-91-generic', version='#101~20.04.1-Ubuntu SMP Thu Nov 16 14:22:28 UTC 2023', machine='x86_64')

Experiment #1: I aligned the sharded data to the number of devices, i.e. x has shape 160x10. Data sharding appears to have worked (output below).
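
The script is essentially the MWE above with x resized to 160x10 so that dimension 0 divides evenly across the 16 devices, roughly:

import jax

print("devices:", jax.devices())
print("local:", jax.local_devices())

# 160 rows split evenly across 16 devices (10 rows per device).
x = jax.numpy.ones((160, 10))
y = jax.lax.with_sharding_constraint(
    x, jax.sharding.PositionalSharding(jax.devices()).reshape(-1, 1))
print(y.sharding)
y.sum()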

python3 multi_device_datasharding.py
devices: [RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3), RocmDevice(id=4), RocmDevice(id=5), RocmDevice(id=6), RocmDevice(id=7), RocmDevice(id=8), RocmDevice(id=9), RocmDevice(id=10), RocmDevice(id=11), RocmDevice(id=12), RocmDevice(id=13), RocmDevice(id=14), RocmDevice(id=15)]
local: [RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3), RocmDevice(id=4), RocmDevice(id=5), RocmDevice(id=6), RocmDevice(id=7), RocmDevice(id=8), RocmDevice(id=9), RocmDevice(id=10), RocmDevice(id=11), RocmDevice(id=12), RocmDevice(id=13), RocmDevice(id=14), RocmDevice(id=15)]
PositionalSharding([[{GPU 0}]
                    [{GPU 1}]
                    [{GPU 2}]
                    [{GPU 3}]
                    [{GPU 4}]
                    [{GPU 5}]
                    [{GPU 6}]
                    [{GPU 7}]
                    [{GPU 8}]
                    [{GPU 9}]
                    [{GPU 10}]
                    [{GPU 11}]
                    [{GPU 12}]
                    [{GPU 13}]
                    [{GPU 14}]
                    [{GPU 15}]], shape=(16, 1))

Experiment #2: x is a 120x10 matrix. This errored out with a message saying that the global size of dimension 0 should be divisible by the number of devices (16):

root@hyd-7c-ZT13-03:/workspaces# python3 multi_device_reduction.py
devices: [RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3), RocmDevice(id=4), RocmDevice(id=5), RocmDevice(id=6), RocmDevice(id=7), RocmDevice(id=8), RocmDevice(id=9), RocmDevice(id=10), RocmDevice(id=11), RocmDevice(id=12), RocmDevice(id=13), RocmDevice(id=14), RocmDevice(id=15)]
local: [RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3), RocmDevice(id=4), RocmDevice(id=5), RocmDevice(id=6), RocmDevice(id=7), RocmDevice(id=8), RocmDevice(id=9), RocmDevice(id=10), RocmDevice(id=11), RocmDevice(id=12), RocmDevice(id=13), RocmDevice(id=14), RocmDevice(id=15)]
Traceback (most recent call last):
  File "/workspaces/multi_device_reduction.py", line 7, in <module>
    y=jax.lax.with_sharding_constraint(x, jax.sharding.PositionalSharding(jax.devices()).reshape(-1, 1))
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspaces/jax_xla/rocm-jaxlib-v0.4.31/jax/jax/_src/pjit.py", line 2465, in with_sharding_constraint
    outs = [sharding_constraint_p.bind(xf, sharding=s, layout=l,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspaces/jax_xla/rocm-jaxlib-v0.4.31/jax/jax/_src/pjit.py", line 2465, in <listcomp>
    outs = [sharding_constraint_p.bind(xf, sharding=s, layout=l,
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspaces/jax_xla/rocm-jaxlib-v0.4.31/jax/jax/_src/core.py", line 429, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspaces/jax_xla/rocm-jaxlib-v0.4.31/jax/jax/_src/core.py", line 433, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspaces/jax_xla/rocm-jaxlib-v0.4.31/jax/jax/_src/core.py", line 939, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspaces/jax_xla/rocm-jaxlib-v0.4.31/jax/jax/_src/pjit.py", line 2480, in _sharding_constraint_impl
    return api.jit(_identity_fn, out_shardings=sharding)(x)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: One of pjit outputs was given the sharding of PositionalSharding([[{GPU 0}]
                    [{GPU 1}]
                    [{GPU 2}]
                    [{GPU 3}]
                    [{GPU 4}]
                    [{GPU 5}]
                    [{GPU 6}]
                    [{GPU 7}]
                    [{GPU 8}]
                    [{GPU 9}]
                    [{GPU 10}]
                    [{GPU 11}]
                    [{GPU 12}]
                    [{GPU 13}]
                    [{GPU 14}]
                    [{GPU 15}]], shape=(16, 1)), which implies that the global size of its dimension 0 should be divisible by 16, but it is equal to 120 (full shape: (120, 10))
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
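
If the leading dimension cannot be changed, one possible workaround (a rough sketch, not tested on this setup) is to pad dimension 0 up to the next multiple of the device count before applying the constraint; zero padding does not change the sum:

import jax
import jax.numpy as jnp

devices = jax.devices()
x = jnp.ones((120, 10))

# Pad dimension 0 up to the next multiple of the device count so that
# PositionalSharding can split it evenly; the extra zeros leave the sum unchanged.
pad = (-x.shape[0]) % len(devices)
x_padded = jnp.pad(x, ((0, pad), (0, 0)))

sharding = jax.sharding.PositionalSharding(devices).reshape(-1, 1)
y = jax.lax.with_sharding_constraint(x_padded, sharding)
print(y.sum())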
PhilipVinc commented 6 days ago

Thank you. I had 8 devices, and the given size (120) is divisible by 8.

It might be a problem with the setup of the ROCm libraries on our HPC system.

Is there some way to debug what is happening inside of

#0  0x000015555204c490 in hip::FatBinaryInfo::BuildProgram(int) () from /opt/rocm-6.0.0/lib/libamdhip64.so.6

?
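
In the meantime, I can try to get more verbose logging out of HIP and RCCL before the crash, something along these lines (assuming the usual HIP/NCCL logging variables are honoured by this ROCm build):

import os

# Request verbose logging from HIP and RCCL before JAX initializes the GPU backend.
# AMD_LOG_LEVEL controls HIP runtime logging; RCCL reads the NCCL_DEBUG variables.
os.environ["AMD_LOG_LEVEL"] = "4"
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["NCCL_DEBUG_SUBSYS"] = "INIT"

import jax

x = jax.numpy.ones((120, 10))
y = jax.lax.with_sharding_constraint(
    x, jax.sharding.PositionalSharding(jax.devices()).reshape(-1, 1))
print(y.sum())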

zahiqbal commented 6 days ago

@PhilipVinc Can you please mention how you got JAX 0.4.31 for ROCm, or the steps you used to build JAX locally?

I am using the following steps to build JAX in a Docker image.

docker: rocm/jax:rocm6.0.0-jax0.4.26-py3.11.0

Cloned JAX/XLA:

git clone -b rocm-jaxlib-v0.4.31 https://github.com/ROCm/jax.git
git clone -b rocm-jaxlib-v0.4.31 https://github.com/ROCm/xla.git

Build/install JAX locally using the command below:

rm -rf dist
python3 -m pip uninstall jax jaxlib jax-rocm60-pjrt jax-rocm60-plugin -y
python3 ./build/build.py --use_clang=false --enable_rocm --build_gpu_plugin \
    --gpu_plugin_rocm_version=60 --rocm_amdgpu_targets=gfx90a \
    --bazel_options=--override_repository=xla=/workspaces/jax_xla/rocm-jaxlib-v0.4.31/xla \
    --rocm_path=/opt/rocm-6.0.0/ \
  && python3 setup.py develop --user \
  && python3 -m pip install dist/*.whl
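
After installing, a quick sanity check (assuming the freshly built wheels were installed into the active environment):

import jax

# Should list RocmDevice entries if the ROCm PJRT plugin was picked up correctly.
print(jax.devices())
jax.print_environment_info()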