jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jax.numpy.nonzero extremely slow on sharded array #23675

Open muchanem opened 6 days ago

muchanem commented 6 days ago

Description

When jnp.nonzero (or jnp.argwhere, or sparse.BCOO.fromdense) is run on a mostly-zero array that is sharded across devices, the operation is extremely slow; this doesn't happen without sharding. This is a little odd to me, since there shouldn't be much (or any) communication between devices for this op: it can be trivially split across the devices. I'm actually unsure whether the op ever completes; I've let it run for 10+ minutes with no result.

import jax
import jax.numpy as jnp
import numpy as np

arr = jnp.zeros((2**14, 2**14))
# generate along axis 0 in a "ragged" way
axis_0_idxs = np.concatenate([(np.zeros((np.random.randint(16,64),),dtype=np.int64) + i) for i in range(2**14)])
axis_1_idxs = np.random.randint(0,2**14-1,(axis_0_idxs.shape[0]))
arr = arr.at[axis_0_idxs, axis_1_idxs].set(1E-2)

from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding
sharding = PositionalSharding(mesh_utils.create_device_mesh((jax.local_device_count(),1)))
arr = jax.device_put(arr, sharding)

# takes a few seconds
np.nonzero(np.asarray(arr))

# extremely slow
jnp.nonzero(arr)

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.31
jaxlib: 0.4.31
numpy:  1.26.4
python: 3.11.9 (main, Apr 21 2024, 09:01:17) [GCC 9.4.0]
jax.devices (2 total, 2 local): [CudaDevice(id=0) CudaDevice(id=1)]
process_count: 1
platform: uname_result(system='Linux', node='m001', release='5.4.0-190-generic', version='#210-Ubuntu SMP Fri Jul 5 17:03:38 UTC 2024', machine='x86_64')

$ nvidia-smi
Mon Sep 16 16:49:16 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.28.03              Driver Version: 560.28.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 NVL                On  |   00000000:2A:00.0 Off |                    0 |
| N/A   42C    P0             90W /  400W |     538MiB /  95830MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 NVL                On  |   00000000:3D:00.0 Off |                    0 |
| N/A   47C    P0             94W /  400W |     538MiB /  95830MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A   2761011      C   ...ons/3.11.9/envs/nnmf-jax/bin/python        526MiB |
|    1   N/A  N/A   2761011      C   ...ons/3.11.9/envs/nnmf-jax/bin/python        526MiB |
+-----------------------------------------------------------------------------------------+
muchanem commented 5 days ago

This seems related to the size calculation, which would explain the communication overhead that makes the op slow. JITing it with a specified size speeds up the computation significantly (though it is still much, much slower than NumPy on CPU).
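
A minimal sketch of that workaround, continuing the repro above. The size passed here is just the number of indices set in the repro (an upper bound on the true nonzero count), and nonzero_static is an illustrative name:

# Giving jnp.nonzero a static `size` (padding unused slots with `fill_value`)
# makes the output shape known at trace time, so the call can be jitted.
@jax.jit
def nonzero_static(x):
    return jnp.nonzero(x, size=axis_0_idxs.shape[0], fill_value=-1)

rows, cols = nonzero_static(arr)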

yashk2810 commented 4 days ago

After a warmup loop, it's pretty quick (tested on a TPU):

[screenshot of timing results]
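
For context, a minimal sketch of the kind of warmup-then-measure loop being described; the iteration count and timing code are illustrative, not taken from the screenshot:

import time

# First call pays tracing and any one-time dispatch/compilation cost.
rows, cols = jnp.nonzero(arr)
rows.block_until_ready()

# Subsequent calls measure the steady-state cost.
n_iters = 5
start = time.perf_counter()
for _ in range(n_iters):
    rows, cols = jnp.nonzero(arr)
    rows.block_until_ready()
print(f"mean time per call: {(time.perf_counter() - start) / n_iters:.3f} s")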