Open: PhilipVinc opened this issue 2 months ago
I am not able to reproduce the reported segfault. When I aligned the data to the number of devices, sharding appeared to work correctly. However, if the data dimension is not divisible by the number of devices, it errors out.
Here are the details of my experiments.
System Info:
>>> import jax; jax.print_environment_info()
jax: 0.4.31.dev20240808+a96cefdc0
jaxlib: 0.4.31.dev20240913
numpy: 2.1.1
python: 3.11.0 (main, Apr 9 2024, 03:49:51) [GCC 9.4.0]
jax.devices (16 total, 16 local): [RocmDevice(id=0) RocmDevice(id=1) ... RocmDevice(id=14) RocmDevice(id=15)]
process_count: 1
platform: uname_result(system='Linux', node='hyd-7c-ZT13-03', release='5.15.0-91-generic', version='#101~20.04.1-Ubuntu SMP Thu Nov 16 14:22:28 UTC 2023', machine='x86_64')
Experiment #1: aligned the data to the number of devices, i.e. x has shape 160x10. Data sharding appears to have been done correctly:
python3 multi_device_datasharding.py
devices: [RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3), RocmDevice(id=4), RocmDevice(id=5), RocmDevice(id=6), RocmDevice(id=7), RocmDevice(id=8), RocmDevice(id=9), RocmDevice(id=10), RocmDevice(id=11), RocmDevice(id=12), RocmDevice(id=13), RocmDevice(id=14), RocmDevice(id=15)]
local: [RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3), RocmDevice(id=4), RocmDevice(id=5), RocmDevice(id=6), RocmDevice(id=7), RocmDevice(id=8), RocmDevice(id=9), RocmDevice(id=10), RocmDevice(id=11), RocmDevice(id=12), RocmDevice(id=13), RocmDevice(id=14), RocmDevice(id=15)]
PositionalSharding([[{GPU 0}]
[{GPU 1}]
[{GPU 2}]
[{GPU 3}]
[{GPU 4}]
[{GPU 5}]
[{GPU 6}]
[{GPU 7}]
[{GPU 8}]
[{GPU 9}]
[{GPU 10}]
[{GPU 11}]
[{GPU 12}]
[{GPU 13}]
[{GPU 14}]
[{GPU 15}]], shape=(16, 1))
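For reference, a minimal sketch of how this kind of check can be done (the actual multi_device_datasharding.py is not shown here; the shapes and print calls below are assumptions based on the output above):

import jax
import jax.numpy as jnp

# Shard a (160, 10) array across all 16 devices along dimension 0.
print("devices:", jax.devices())
print("local:", jax.local_devices())
sharding = jax.sharding.PositionalSharding(jax.devices()).reshape(-1, 1)
x = jnp.ones((160, 10))
y = jax.lax.with_sharding_constraint(x, sharding)
print(sharding)                          # prints the PositionalSharding layout shown above
jax.debug.visualize_array_sharding(y)    # optional: per-device layout of the resulting array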
Experiment #2: x is a 120x10 matrix. It errored out with a message saying that the global size of dimension 0 should be divisible by the number of devices (16):
root@hyd-7c-ZT13-03:/workspaces# python3 multi_device_reduction.py
devices: [RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3), RocmDevice(id=4), RocmDevice(id=5), RocmDevice(id=6), RocmDevice(id=7), RocmDevice(id=8), RocmDevice(id=9), RocmDevice(id=10), RocmDevice(id=11), RocmDevice(id=12), RocmDevice(id=13), RocmDevice(id=14), RocmDevice(id=15)]
local: [RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3), RocmDevice(id=4), RocmDevice(id=5), RocmDevice(id=6), RocmDevice(id=7), RocmDevice(id=8), RocmDevice(id=9), RocmDevice(id=10), RocmDevice(id=11), RocmDevice(id=12), RocmDevice(id=13), RocmDevice(id=14), RocmDevice(id=15)]
Traceback (most recent call last):
File "/workspaces/multi_device_reduction.py", line 7, in <module>
y=jax.lax.with_sharding_constraint(x, jax.sharding.PositionalSharding(jax.devices()).reshape(-1, 1))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspaces/jax_xla/rocm-jaxlib-v0.4.31/jax/jax/_src/pjit.py", line 2465, in with_sharding_constraint
outs = [sharding_constraint_p.bind(xf, sharding=s, layout=l,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspaces/jax_xla/rocm-jaxlib-v0.4.31/jax/jax/_src/pjit.py", line 2465, in <listcomp>
outs = [sharding_constraint_p.bind(xf, sharding=s, layout=l,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspaces/jax_xla/rocm-jaxlib-v0.4.31/jax/jax/_src/core.py", line 429, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspaces/jax_xla/rocm-jaxlib-v0.4.31/jax/jax/_src/core.py", line 433, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspaces/jax_xla/rocm-jaxlib-v0.4.31/jax/jax/_src/core.py", line 939, in process_primitive
return primitive.impl(*tracers, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/workspaces/jax_xla/rocm-jaxlib-v0.4.31/jax/jax/_src/pjit.py", line 2480, in _sharding_constraint_impl
return api.jit(_identity_fn, out_shardings=sharding)(x)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: One of pjit outputs was given the sharding of PositionalSharding([[{GPU 0}]
[{GPU 1}]
[{GPU 2}]
[{GPU 3}]
[{GPU 4}]
[{GPU 5}]
[{GPU 6}]
[{GPU 7}]
[{GPU 8}]
[{GPU 9}]
[{GPU 10}]
[{GPU 11}]
[{GPU 12}]
[{GPU 13}]
[{GPU 14}]
[{GPU 15}]], shape=(16, 1)), which implies that the global size of its dimension 0 should be divisible by 16, but it is equal to 120 (full shape: (120, 10))
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
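If uneven data sizes are expected, one way to avoid the divisibility error is to pad dimension 0 up to a multiple of the device count before applying the constraint. A sketch (untested here):

import jax
import jax.numpy as jnp

ndev = jax.device_count()                    # 16 in the runs above
x = jnp.ones((120, 10))
pad = (-x.shape[0]) % ndev                   # rows needed to reach a multiple of ndev (8 here)
x_padded = jnp.pad(x, ((0, pad), (0, 0)))    # shape (128, 10), divisible by 16
sharding = jax.sharding.PositionalSharding(jax.devices()).reshape(-1, 1)
y = jax.lax.with_sharding_constraint(x_padded, sharding)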
Thank you. In my case there are 8 devices, and the size (120) is divisible by 8.
It might be a problem with the ROCm library setup on our HPC system.
Is there some way to debug what is happening inside
#0 0x000015555204c490 in hip::FatBinaryInfo::BuildProgram(int) () from /opt/rocm-6.0.0/lib/libamdhip64.so.6
?
@PhilipVinc Can you please mention how you got JAX 0.4.31 for ROCm, or the steps to build JAX locally?
I am using the following steps to build JAX in a Docker image:
docker: rocm/jax:rocm6.0.0-jax0.4.26-py3.11.0
Cloned JAX/XLA:
git clone -b rocm-jaxlib-v0.4.31 https://github.com/ROCm/jax.git
git clone -b rocm-jaxlib-v0.4.31 https://github.com/ROCm/xla.git
Build/install JAX locally using the command below:
rm -rf dist; python3 -m pip uninstall jax jaxlib jax-rocm60-pjrt jax-rocm60-plugin -y; python3 ./build/build.py --use_clang=false --enable_rocm --build_gpu_plugin --gpu_plugin_rocm_version=60 --rocm_amdgpu_targets=gfx90a --bazel_options=--override_repository=xla=/workspaces/jax_xla/rocm-jaxlib-v0.4.31/xla --rocm_path=/opt/rocm-6.0.0/ && python3 setup.py develop --user && python3 -m pip install dist/*.whl
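After the install, a quick sanity check to confirm the locally built wheels see all GPUs (sketch):

python3 -c "import jax; jax.print_environment_info(); print(jax.devices())"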
Description
Running a simple program with a global reduction causes a segfault.
The MWE
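Something along these lines (a sketch; the exact array shape is not essential):

import jax
import jax.numpy as jnp

x = jnp.ones((120, 10))   # illustrative shape
sharding = jax.sharding.PositionalSharding(jax.devices()).reshape(-1, 1)
x = jax.lax.with_sharding_constraint(x, sharding)
print(jnp.sum(x))         # global reduction over the sharded array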
I'm running on a single node of the CINES Ad Astra HPC system, which has 4x MI250X (seen as 8x MI200 devices). I'm using ROCm 6.0 and a custom-built version of JAX, because no wheels are publicly available for this setup.
Running the program above leads to a segfault with no further information, so I ran it under gdb to get a stack trace.
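Roughly like this (mwe.py stands in for the actual script name):

gdb --args python3 mwe.py
(gdb) run
(gdb) bt    # backtrace after the segfault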
If you need anything else, let me know.
System info (python version, jaxlib version, accelerator, etc.)