jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Scatter on Sharded Matrices has bugs #23052

Open ymahlau opened 1 month ago

ymahlau commented 1 month ago

Description

Hi, I intended to use the multi-GPU capabilities of JAX to work with sharded arrays. However, I ran into two issues:

1) Updating a sharded array with the `array.at[].set()` syntax is lowered to an all-gather operation, which leads to OOM errors because my arrays are too large to fit on a single device. In principle, a scatter should not need an all-gather: each device can apply the updates that fall into its own shard independently.

2) As a workaround, I tried to implement the update with `jax.lax.scatter`, hoping that it would not be lowered to an all-gather (and indeed it is not). However, this exposed a second bug: `jax.lax.scatter` produces a different result for a sharded array than for an equivalent non-sharded array.
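The point in item 1, that each device can apply its share of the update locally, can be illustrated with `shard_map`. This is a minimal sketch under stated assumptions, not JAX's own scatter lowering: it handles only a contiguous row slice `[start, stop)` of a row-sharded 2-D array and replaces the scatter with a masked per-shard `jnp.where`. The helper `set_rows_sharded`, the mesh axis name `"x"` and the simulated 4-device CPU setup are illustrative choices; the import fallback covers both the newer top-level `jax.shard_map` and the older `jax.experimental.shard_map` module.

```python
import os

# Assumption: simulate 4 CPU devices; this must happen before JAX is imported.
os.environ.setdefault("JAX_PLATFORMS", "cpu")
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:  # newer JAX releases expose shard_map at the top level
    from jax import shard_map
except ImportError:  # older releases such as 0.4.31
    from jax.experimental.shard_map import shard_map

# Use 4 devices if available, otherwise fall back to a single device.
devices = jax.devices()[:4] if len(jax.devices()) >= 4 else jax.devices()[:1]
mesh = Mesh(np.array(devices), ("x",))
row_sharding = NamedSharding(mesh, P("x", None))


def set_rows_sharded(x, start, stop, value):
    """Hypothetical helper: set rows [start, stop) of a row-sharded array to value."""
    rows_per_shard = x.shape[0] // len(devices)

    def local_update(block):
        # Global index of the first row held by this shard.
        offset = jax.lax.axis_index("x") * rows_per_shard
        global_rows = offset + jnp.arange(rows_per_shard)
        mask = (global_rows >= start) & (global_rows < stop)
        # Each shard only touches its own rows; no collective is issued here.
        return jnp.where(mask[:, None], jnp.asarray(value, block.dtype), block)

    return shard_map(
        local_update, mesh=mesh, in_specs=P("x", None), out_specs=P("x", None)
    )(x)


arr = jax.device_put(jnp.zeros((8, 2), dtype=jnp.float32), row_sharding)
result = jax.jit(lambda x: set_rows_sharded(x, 0, 3, 1.0))(arr)
print(result)
```

Because every shard derives the global index of its rows from `jax.lax.axis_index`, the update itself needs no communication between devices. Scattering at arbitrary, non-contiguous indices would need a more general per-shard filter of the index list.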

Here is an MWE:

```python
import os
# Simulate 4 CPU devices so the repro runs on a single machine
os.environ['JAX_PLATFORM_NAME'] = 'cpu'
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4 --xla_gpu_force_compilation_parallelism=4'

import jax
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
import jax.numpy as jnp

def main():

    print(f"{jax.devices()=}")
    print(f"{jax.device_count()=}")

    # Shard the rows of the array across the 4 devices
    sharding = PositionalSharding(
        devices=mesh_utils.create_device_mesh(
            (4, 1),
            devices=jax.devices(),
        )
    )
    print(f"{sharding=}")

    # does not work with sharded matrix
    arr = jnp.zeros(
        (8, 2),
        dtype=jnp.float32,
        device=sharding,
    )
    arr = jax.device_put(arr, sharding)

    # this would work
    # arr = jnp.zeros((8, 2), dtype=jnp.float32)

    s = slice(0, 3)
    def fn(x):
        # Row indices 0, 1, 2 as a (3, 1) array of index vectors
        b = s.indices(x.shape[0])
        indices = jnp.arange(b[0], b[1], b[2])[:, None]

        # One full row of ones per index
        updates = jnp.ones((3, x.shape[1]), dtype=jnp.float32)

        res = jax.lax.scatter(
            x,
            indices,
            updates,
            jax.lax.ScatterDimensionNumbers(
                update_window_dims=(1,),
                inserted_window_dims=(0,),
                scatter_dims_to_operand_dims=(0,)
            )
        )
        return res

    jit_fn = (
        jax.jit(fn)
        .lower(arr)
        .compile()
    )

    # Uncomment to inspect the compiled HLO
    # print("compiled")
    # print(jit_fn.as_text())

    r = jit_fn(arr)
    print(f"{r=}")
    print(f"{r.sum()=}")

if __name__ == '__main__':
    main()
```

When the MWE is run with the sharded array, only the first row (index 0) is set to one; the second and third rows, which should also be updated, stay zero:

```
jax.devices()=[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]
jax.device_count()=4
sharding=PositionalSharding([[{CPU 0}]
                    [{CPU 1}]
                    [{CPU 2}]
                    [{CPU 3}]], shape=(4, 1))
r=Array([[1., 1.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)
r.sum()=Array(2., dtype=float32)
```

Expected output (obtained by running the same code on a non-sharded array):

```
r=Array([[1., 1.],
       [1., 1.],
       [1., 1.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)
r.sum()=Array(6., dtype=float32)
```
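On a single, unsharded device the scatter dimension numbers from the MWE can be cross-checked against the `.at[].set()` path. This is a minimal sketch, assuming the default device and the same `(8, 2)` float32 array; the helper name `scatter_rows` is illustrative. It shows that `update_window_dims=(1,)`, `inserted_window_dims=(0,)` and `scatter_dims_to_operand_dims=(0,)` write row `i` of `updates` into the operand row named by `indices[i]`, which is what yields the expected output above.

```python
import numpy as np
import jax
import jax.numpy as jnp


def scatter_rows(x, rows, value):
    """Hypothetical helper: set the given rows of a 2-D array via lax.scatter."""
    # indices has shape (num_rows, 1): one 1-element index vector per update.
    indices = jnp.asarray(rows, dtype=jnp.int32)[:, None]
    # updates has shape (num_rows, x.shape[1]): one full row per index.
    updates = jnp.full((len(rows), x.shape[1]), value, dtype=x.dtype)
    dnums = jax.lax.ScatterDimensionNumbers(
        update_window_dims=(1,),            # dim 1 of updates is the row window
        inserted_window_dims=(0,),          # operand dim 0 has window size 1
        scatter_dims_to_operand_dims=(0,),  # each index addresses operand dim 0
    )
    return jax.lax.scatter(x, indices, updates, dnums)


x = jnp.zeros((8, 2), dtype=jnp.float32)
via_scatter = jax.jit(lambda a: scatter_rows(a, [0, 1, 2], 1.0))(x)
via_at_set = jax.jit(lambda a: a.at[0:3].set(1.0))(x)
print(via_scatter.sum(), via_at_set.sum())
```

Both paths should agree on an unsharded array; the bug report is that the first path disagrees once the operand is sharded.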

To see the all-gather produced by the `array.at[].set()` syntax, use `x.at[s].set(...)` in `fn` instead of `jax.lax.scatter` and uncomment the two `print` lines that dump the compiled HLO. With `jax.lax.scatter` there is no all-gather, but the result is incorrect.
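The check described above can be scripted. This is a minimal sketch, assuming the same simulated 4-device CPU setup as the MWE but building the row sharding with `Mesh`/`NamedSharding` instead of `PositionalSharding`; the helper `has_all_gather` is hypothetical and does only a crude text search of the compiled HLO. Whether an all-gather actually appears depends on the JAX/XLA version, so the script reports the finding rather than asserting it.

```python
import os

# Assumption: 4 simulated CPU devices, as in the MWE; set before importing JAX.
os.environ.setdefault("JAX_PLATFORMS", "cpu")
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Use 4 devices if available, otherwise fall back to a single device.
devices = jax.devices()[:4] if len(jax.devices()) >= 4 else jax.devices()[:1]
mesh = Mesh(np.array(devices), ("x",))
arr = jax.device_put(
    jnp.zeros((8, 2), dtype=jnp.float32), NamedSharding(mesh, P("x", None))
)


def has_all_gather(hlo_text):
    """Hypothetical helper: True if the HLO text mentions an all-gather op."""
    return "all-gather" in (hlo_text or "")


s = slice(0, 3)
compiled = jax.jit(lambda x: x.at[s].set(1.0)).lower(arr).compile()
print("all-gather in compiled HLO:", has_all_gather(compiled.as_text()))
result = compiled(arr)
print(result)
```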

System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.4.31
jaxlib: 0.4.31
numpy:  1.26.4
python: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:36:13) [GCC 12.3.0]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='URL REDACTED FOR PRIVACY', release='5.15.0-113-generic', version='#123-Ubuntu SMP Mon Jun 10 08:16:17 UTC 2024', machine='x86_64')

$ nvidia-smi
Wed Aug 14 10:36:51 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.06              Driver Version: 555.42.06      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4090        On  |   00000000:2C:00.0 Off |                  Off |
|  0%   60C    P0             47W /  450W |     396MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A   3293702      C   python                                        386MiB |
+-----------------------------------------------------------------------------------------+
```
ymahlau commented 1 month ago

The problematic all-gather operation is also described in issue #20381.