I am not sure whether this behavior is a feature or a bug. If it is a feature, it should be clearly documented in the sharding guidelines; if it is a bug, it should be fixed. The same behavior also shows up in other functions, such as jax.random.uniform() and jax.numpy.linspace().
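For context, a minimal sketch of the kind of check involved (this is not the original reproduction, which is omitted here; the mesh axis name "x" and the jit-wrapped helper are illustrative assumptions): create arrays with jax.random.uniform() and jax.numpy.linspace() on a multi-device mesh and inspect the sharding JAX assigns to them, compared with an array placed explicitly via jax.device_put().

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh over all local devices; axis name "x" is an assumption.
mesh = Mesh(jax.devices(), axis_names=("x",))
sharding = NamedSharding(mesh, P("x"))

@jax.jit
def make_arrays():
    # Arrays created inside jit by the functions mentioned above.
    u = jax.random.uniform(jax.random.key(0), (8, 8))
    l = jnp.linspace(0.0, 1.0, 8)
    return u, l

u, l = make_arrays()
# Inspect the shardings JAX actually assigned to the freshly created arrays.
print(u.sharding)
print(l.sharding)

# Compare with an array explicitly placed on the mesh.
x = jax.device_put(jnp.arange(8.0), sharding)
print(x.sharding)
```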
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.34
jaxlib: 0.4.34
numpy: 2.0.0
python: 3.12.5 | packaged by Anaconda, Inc. | (main, Sep 12 2024, 18:27:27) [GCC 11.2.0]
jax.devices (8 total, 8 local): [CudaDevice(id=0) CudaDevice(id=1) ... CudaDevice(id=6) CudaDevice(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='gpu1', release='5.15.0-107-generic', version='#117-Ubuntu SMP Fri Apr 26 12:26:49 UTC 2024', machine='x86_64')
$ nvidia-smi
Tue Oct 29 19:25:40 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 4090 Off | 00000000:01:00.0 Off | Off |
| 31% 26C P2 15W / 450W | 18655MiB / 24564MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA GeForce RTX 4090 Off | 00000000:21:00.0 Off | Off |
| 30% 27C P2 20W / 450W | 18649MiB / 24564MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 2 NVIDIA GeForce RTX 4090 Off | 00000000:41:00.0 Off | Off |
| 30% 26C P2 23W / 450W | 18649MiB / 24564MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 3 NVIDIA GeForce RTX 4090 Off | 00000000:61:00.0 Off | Off |
| 31% 25C P2 19W / 450W | 18649MiB / 24564MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 4 NVIDIA GeForce RTX 4090 Off | 00000000:81:00.0 Off | Off |
| 30% 26C P2 15W / 450W | 18649MiB / 24564MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 5 NVIDIA GeForce RTX 4090 Off | 00000000:A1:00.0 Off | Off |
| 30% 26C P2 12W / 450W | 18649MiB / 24564MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 6 NVIDIA GeForce RTX 4090 Off | 00000000:C1:00.0 Off | Off |
| 30% 27C P2 16W / 450W | 18649MiB / 24564MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 7 NVIDIA GeForce RTX 4090 Off | 00000000:E1:00.0 Off | Off |
| 31% 25C P2 20W / 450W | 18649MiB / 24564MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 4100405 C ...conda3/envs/main/bin/python 18644MiB |
| 1 N/A N/A 4100405 C ...conda3/envs/main/bin/python 18638MiB |
| 2 N/A N/A 4100405 C ...conda3/envs/main/bin/python 18638MiB |
| 3 N/A N/A 4100405 C ...conda3/envs/main/bin/python 18638MiB |
| 4 N/A N/A 4100405 C ...conda3/envs/main/bin/python 18638MiB |
| 5 N/A N/A 4100405 C ...conda3/envs/main/bin/python 18638MiB |
| 6 N/A N/A 4100405 C ...conda3/envs/main/bin/python 18638MiB |
| 7 N/A N/A 4100405 C ...conda3/envs/main/bin/python 18638MiB |
+-----------------------------------------------------------------------------------------+