jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Functions like `jax.numpy.concatenate` are not compatible with `jax.sharding`, resulting in improper functioning #24576

Open kYangLi opened 4 weeks ago

kYangLi commented 4 weeks ago

Description

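import jax
import jax.numpy as jnp

# 4x2 device mesh (requires 8 devices); arrays are split along 'data'
# and replicated along 'model'.
mesh = jax.make_mesh((4, 2), ('data', 'model'))
data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec('data',)
)

# Both inputs contain only ones, so the concatenation should too.
arr1 = jax.device_put(jnp.ones((4,)), data_sharding)
arr2 = jax.device_put(jnp.ones((12,)), data_sharding)
arr12 = jnp.concatenate([arr1, arr2])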

print(f'Array 1: {arr1}')
print(f'Array 2: {arr2}')
print(f'Array 1&2: {arr12}')

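This prints the following. The concatenated array should contain sixteen values of 1.0, but every element is 2.0: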

Array 1: [1. 1. 1. 1.]
Array 2: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
Array 1&2: [2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
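
One way to narrow this down might be to look at how each array is laid out across devices. The sketch below is not part of the original report: `describe_shards` and `make_data_sharding` are hypothetical helper names, and the mesh is built from whatever devices are visible (falling back to a `model` axis of size 1 when the device count is odd), so it is an approximation of the 4x2 setup above rather than an exact copy.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def make_data_sharding():
    """Hypothetical helper: ('data', 'model') mesh over all visible devices."""
    devices = np.array(jax.devices())
    # Assumption: use a 'model' axis of size 2 when possible, as in the report.
    model = 2 if devices.size % 2 == 0 else 1
    data = devices.size // model
    mesh = Mesh(devices.reshape(data, model), ('data', 'model'))
    return NamedSharding(mesh, PartitionSpec('data')), data


def describe_shards(arr):
    """Hypothetical helper: one (device id, replica id, index, values) per shard."""
    return [
        (s.device.id, s.replica_id, s.index, np.asarray(s.data))
        for s in arr.addressable_shards
    ]


sharding, data_size = make_data_sharding()
x = jax.device_put(jnp.ones((data_size,)), sharding)

print(x.sharding)
for device_id, replica_id, index, values in describe_shards(x):
    print(device_id, replica_id, index, values)
```

Running the same loop on the output of `jnp.concatenate` would show which devices hold which slice of the result and whether the per-shard values already differ from the inputs.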

I am not sure whether this behavior is a feature or a bug. If it is a feature, it should be clearly documented in the sharding guide. If it is a bug, it needs to be fixed promptly. I have seen the same behavior with other functions, such as `jax.random.uniform()` and `jax.numpy.linspace()`.
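
Until the behavior is settled, a defensive check against an unsharded NumPy reference may help catch silent mismatches like the one above. This is a minimal sketch under stated assumptions, not a fix: `make_data_sharding` and `check_sharded_concatenate` are hypothetical names, and the mesh adapts to the visible device count (with 8 devices it reproduces the 4x2 mesh and the input lengths 4 and 12 from the report).

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def make_data_sharding():
    """Hypothetical helper: ('data', 'model') mesh over all visible devices."""
    devices = np.array(jax.devices())
    # Assumption: use a 'model' axis of size 2 when possible, as in the report.
    model = 2 if devices.size % 2 == 0 else 1
    data = devices.size // model
    mesh = Mesh(devices.reshape(data, model), ('data', 'model'))
    return NamedSharding(mesh, PartitionSpec('data')), data


def check_sharded_concatenate():
    """Compare jnp.concatenate on sharded inputs with np.concatenate."""
    sharding, data_size = make_data_sharding()
    # Lengths are multiples of the 'data' axis size so they shard evenly.
    a = np.ones((data_size,), dtype=np.float32)
    b = np.ones((3 * data_size,), dtype=np.float32)
    expected = np.concatenate([a, b])
    got = np.asarray(
        jnp.concatenate([jax.device_put(a, sharding), jax.device_put(b, sharding)])
    )
    matches = got.shape == expected.shape and bool(np.array_equal(got, expected))
    return matches, got, expected


matches, got, expected = check_sharded_concatenate()
print('matches NumPy reference:', matches)
```

On a setup that shows the problem described here, `matches` would be expected to come back `False`; the same pattern could be adapted to compare other functions against their unsharded results.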

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.34
jaxlib: 0.4.34
numpy:  2.0.0
python: 3.12.5 | packaged by Anaconda, Inc. | (main, Sep 12 2024, 18:27:27) [GCC 11.2.0]
jax.devices (8 total, 8 local): [CudaDevice(id=0) CudaDevice(id=1) ... CudaDevice(id=6) CudaDevice(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='gpu1', release='5.15.0-107-generic', version='#117-Ubuntu SMP Fri Apr 26 12:26:49 UTC 2024', machine='x86_64')

$ nvidia-smi
Tue Oct 29 19:25:40 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4090        Off |   00000000:01:00.0 Off |                  Off |
| 31%   26C    P2             15W /  450W |   18655MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GeForce RTX 4090        Off |   00000000:21:00.0 Off |                  Off |
| 30%   27C    P2             20W /  450W |   18649MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA GeForce RTX 4090        Off |   00000000:41:00.0 Off |                  Off |
| 30%   26C    P2             23W /  450W |   18649MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA GeForce RTX 4090        Off |   00000000:61:00.0 Off |                  Off |
| 31%   25C    P2             19W /  450W |   18649MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA GeForce RTX 4090        Off |   00000000:81:00.0 Off |                  Off |
| 30%   26C    P2             15W /  450W |   18649MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA GeForce RTX 4090        Off |   00000000:A1:00.0 Off |                  Off |
| 30%   26C    P2             12W /  450W |   18649MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA GeForce RTX 4090        Off |   00000000:C1:00.0 Off |                  Off |
| 30%   27C    P2             16W /  450W |   18649MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA GeForce RTX 4090        Off |   00000000:E1:00.0 Off |                  Off |
| 31%   25C    P2             20W /  450W |   18649MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A   4100405      C   ...conda3/envs/main/bin/python              18644MiB |
|    1   N/A  N/A   4100405      C   ...conda3/envs/main/bin/python              18638MiB |
|    2   N/A  N/A   4100405      C   ...conda3/envs/main/bin/python              18638MiB |
|    3   N/A  N/A   4100405      C   ...conda3/envs/main/bin/python              18638MiB |
|    4   N/A  N/A   4100405      C   ...conda3/envs/main/bin/python              18638MiB |
|    5   N/A  N/A   4100405      C   ...conda3/envs/main/bin/python              18638MiB |
|    6   N/A  N/A   4100405      C   ...conda3/envs/main/bin/python              18638MiB |
|    7   N/A  N/A   4100405      C   ...conda3/envs/main/bin/python              18638MiB |
+-----------------------------------------------------------------------------------------+

rajasekharporeddy commented 4 weeks ago

Related to #19106