Hi,
I intended to use the multi-GPU capabilities of JAX to work with sharded arrays. However, I ran into two issues:
1) The array.at[].set() syntax is lowered to an all-gather operation, which leads to OOM errors because my arrays are too large to fit on a single device. Technically, an all-gather should not be necessary for a scatter, because the updates can be applied independently on each device (see the sketch after this list).
2) As a workaround, I tried to implement the update with jax.lax.scatter, hoping that it would not be lowered to an all-gather (which actually works). However, I encountered a second bug: jax.lax.scatter produces a different result when used with sharded arrays than with ordinary non-sharded arrays.
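For reference, here is a minimal sketch of the .at[].set() form meant in 1); the helper name update_with_at and the shapes are placeholders, not my real sizes:

import jax.numpy as jnp

def update_with_at(x):
    # Sketch of the update from 1): set the first three rows to ones.
    # On a sharded x, this is the form that gets lowered to an all-gather for me.
    return x.at[0:3, :].set(jnp.ones((3, x.shape[1]), dtype=x.dtype))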
Here is an MWE:
import os

# Emulate 4 devices on the CPU backend so the sharding behavior is reproducible
# without multiple GPUs.
os.environ['JAX_PLATFORM_NAME'] = 'cpu'
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4 --xla_gpu_force_compilation_parallelism=4'

import jax
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
import jax.numpy as jnp


def main():
    print(f"{jax.devices()=}")
    print(f"{jax.device_count()=}")

    # Shard the leading axis of the (8, 2) array across the 4 devices.
    sharding = PositionalSharding(
        devices=mesh_utils.create_device_mesh(
            (4, 1),
            devices=jax.devices(),
        )
    )
    print(f"{sharding=}")

    # does not work with sharded matrix
    arr = jnp.zeros(
        (8, 2),
        dtype=jnp.float32,
        device=sharding,
    )
    arr = jax.device_put(arr, sharding)

    # this would work
    # arr = jnp.zeros((8, 2), dtype=jnp.float32)

    s = slice(0, 3)

    def fn(x):
        # Scatter a (3, x.shape[1]) block of ones into rows 0..2 of x.
        b = s.indices(x.shape[0])
        indices = jnp.arange(b[0], b[1], b[2])[:, None]
        updates = jnp.ones((3, x.shape[1]), dtype=jnp.float32)
        res = jax.lax.scatter(
            x,
            indices,
            updates,
            jax.lax.ScatterDimensionNumbers(
                update_window_dims=(1,),
                inserted_window_dims=(0,),
                scatter_dims_to_operand_dims=(0,),
            ),
        )
        return res

    jit_fn = (
        jax.jit(fn)
        .lower(arr)
        .compile()
    )
    # print("compiled")
    # print(jit_fn.as_text())

    r = jit_fn(arr)
    print(f"{r=}")
    print(f"{r.sum()=}")


if __name__ == '__main__':
    main()
When executed with a sharded array, only the first of the three indexed rows is set to ones; the second and third rows are not updated as one would expect. The expected output, with the first three rows all set to ones, can be obtained by running the same function on a non-sharded array.
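To make the mismatch concrete, here is a small sketch that reuses fn and arr from the MWE above and compares against a non-sharded baseline:

# Sketch: compare the sharded result against a non-sharded baseline.
# fn and arr are the ones defined in the MWE above.
baseline = jax.jit(fn)(jnp.zeros((8, 2), dtype=jnp.float32))  # rows 0-2 become ones
sharded = jax.jit(fn)(arr)                                    # sharded input
print(jnp.array_equal(baseline, sharded))  # False, per the behavior described above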
To see the all-gather behavior of the normal array.at[].set() syntax, swap the jax.lax.scatter call for the equivalent .at[].set() update and uncomment the print of the jitted HLO code in the MWE. With jax.lax.scatter there is no all-gather, but the result is incorrect.
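For checking the lowering, here is a sketch of how I inspect the compiled HLO; it reuses fn and arr from the MWE and update_with_at from the sketch above, and matching on the string "all-gather" is only a heuristic:

# Sketch: heuristic check for all-gather ops in the compiled HLO text.
# Per the description above, I expect True for the .at[].set() variant and
# False for the jax.lax.scatter variant when arr is sharded.
for name, f in [("lax.scatter", fn), (".at[].set()", update_with_at)]:
    hlo_text = jax.jit(f).lower(arr).compile().as_text()
    print(name, "all-gather" in hlo_text)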
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.31
jaxlib: 0.4.31
numpy: 1.26.4
python: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:36:13) [GCC 12.3.0]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='URL REDACTED FOR PRIVACY', release='5.15.0-113-generic', version='#123-Ubuntu SMP Mon Jun 10 08:16:17 UTC 2024', machine='x86_64')
$ nvidia-smi
Wed Aug 14 10:36:51 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.06 Driver Version: 555.42.06 CUDA Version: 12.5 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 4090 On | 00000000:2C:00.0 Off | Off |
| 0% 60C P0 47W / 450W | 396MiB / 24564MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 3293702 C python 386MiB |
+-----------------------------------------------------------------------------------------+