jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

top_k incorrect results on GPU mesh #19271

Open altimofeev opened 10 months ago

altimofeev commented 10 months ago

Description

It seems that jax.lax.top_k returns incorrect results when the input is sharded along the dimension over which top-k is computed.

In particular, the following code shows that the elements of top-1 and top-4 are inconsistent with each other.

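import numpy as np
import jax
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec

def fn(x):
    # Both calls take the top-k over the last axis, which is the sharded one.
    v1, _ = jax.lax.top_k(x, k=1)
    v4, _ = jax.lax.top_k(x, k=4)
    return v1, v4

# Shard the last (top-k) dimension of the input across the 'data' mesh axis.
pfn = pjit(
    fn,
    in_shardings=(PartitionSpec(None, 'data'),),
)

# 1-D mesh over all available devices (8 GPUs in this report).
with Mesh(np.array(jax.devices()), ('data',)):
    r1, r4 = pfn(jax.random.normal(jax.random.PRNGKey(0), [2, 8*1024]))
    print('top-1', r1[0])
    print('top-4', r4[0])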

Prints:

top-1 [3.671912]
top-4 [3.7298186 3.671912  3.4914806 3.418947 ]

This is incorrect: the top-1 value (3.671912) cannot be smaller than the first element of the top-4 result (3.7298186). Note that the code works correctly for an input shape of [1, 8*1024].
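
The invariant being violated can be written down as a small host-side check. The following is a minimal sketch in plain NumPy, not part of the original report; the helper names reference_top_k and is_consistent are hypothetical and chosen only for illustration. It flags the values printed above as inconsistent and shows that a sort-based reference satisfies the invariant on data of the same shape.

```python
import numpy as np

def reference_top_k(x, k):
    """Top-k values along the last axis, largest first, computed with NumPy."""
    # Sort ascending, reverse the last axis, keep the first k entries.
    return np.sort(x, axis=-1)[..., ::-1][..., :k]

def is_consistent(v1, v4):
    """True if the top-1 values equal the leading column of the top-4 values."""
    return bool(np.array_equal(np.asarray(v1)[..., 0], np.asarray(v4)[..., 0]))

# Values printed in the report above (row 0 only).
reported_v1 = np.array([[3.671912]], dtype=np.float32)
reported_v4 = np.array([[3.7298186, 3.671912, 3.4914806, 3.418947]],
                       dtype=np.float32)
print(is_consistent(reported_v1, reported_v4))  # False

# Host-side reference on random data with the same shape as the repro input.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8 * 1024)).astype(np.float32)
print(is_consistent(reference_top_k(x, 1), reference_top_k(x, 4)))  # True
```

The same is_consistent check could be applied to the r1 and r4 arrays returned by the repro script after converting them with np.asarray.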

Tested on NCCL version 2.19.3+cuda12.3, jax 0.4.23 on a mesh of 8 GPUs.

What jax/jaxlib version are you using?

jax 0.4.23, jaxlib 0.4.23

Which accelerator(s) are you using?

GPU x 8

Additional system info?

uname_result(system='Linux', release='5.4.253-167.359.amzn2.x86_64', version='#1 SMP Tue Aug 15 21:40:23 UTC 2023', machine='x86_64')

NVIDIA GPU info

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.182.03   Driver Version: 470.182.03   CUDA Version: 12.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  On   | 00000000:10:1C.0 Off |                    0 |
| N/A   37C    P0    54W / 400W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM...  On   | 00000000:10:1D.0 Off |                    0 |
| N/A   34C    P0    57W / 400W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA A100-SXM...  On   | 00000000:20:1C.0 Off |                    0 |
| N/A   36C    P0    54W / 400W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   3  NVIDIA A100-SXM...  On   | 00000000:20:1D.0 Off |                    0 |
| N/A   34C    P0    56W / 400W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   4  NVIDIA A100-SXM...  On   | 00000000:90:1C.0 Off |                    0 |
| N/A   37C    P0    57W / 400W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   5  NVIDIA A100-SXM...  On   | 00000000:90:1D.0 Off |                    0 |
| N/A   34C    P0    55W / 400W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   6  NVIDIA A100-SXM...  On   | 00000000:A0:1C.0 Off |                    0 |
| N/A   37C    P0    59W / 400W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   7  NVIDIA A100-SXM...  On   | 00000000:A0:1D.0 Off |                    0 |
| N/A   36C    P0    54W / 400W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

rajasekharporeddy commented 3 weeks ago

Hi @altimofeev

The issue with top_k returning different first values for different k values has been resolved in JAX 0.4.34. I have verified this on a cloud VM with 4xT4 GPUs. The top_k function now consistently returns the same first value for both k=1 and k=4. Please see the attached screenshot for confirmation.

image

Thank you.
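
For anyone who wants to re-check this on their own installation, here is a minimal sketch of a verification script. It is not the script used for the screenshot above, and it makes some assumptions: it uses jax.jit with a NamedSharding applied through jax.device_put instead of the pjit call from the original report, the helper name check_sharded_top_k is hypothetical, and the number of devices must divide 8*1024 evenly. It compares the sharded result against a NumPy sort on the host.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def fn(x):
    # Same computation as in the original report.
    v1, _ = jax.lax.top_k(x, k=1)
    v4, _ = jax.lax.top_k(x, k=4)
    return v1, v4

def check_sharded_top_k(shape=(2, 8 * 1024), seed=0):
    """Return True if sharded top_k matches a host-side NumPy reference."""
    # 1-D mesh over all visible devices; shape[1] must be divisible by
    # the device count for the sharding below to be valid.
    mesh = Mesh(np.array(jax.devices()), ('data',))
    sharding = NamedSharding(mesh, PartitionSpec(None, 'data'))

    x = jax.random.normal(jax.random.PRNGKey(seed), shape)
    x_sharded = jax.device_put(x, sharding)  # shard the top-k dimension
    v1, v4 = jax.jit(fn)(x_sharded)

    # Reference: sort on the host, largest first, keep the first 4 values.
    expected = np.sort(np.asarray(x), axis=-1)[..., ::-1][..., :4]
    ok_v4 = np.array_equal(np.asarray(v4), expected)
    ok_v1 = np.array_equal(np.asarray(v1), expected[..., :1])
    return bool(ok_v1 and ok_v4)

print('jax', jax.__version__, 'devices', jax.device_count())
print('sharded top_k matches host reference:', check_sharded_top_k())
```

On a single device the sharding is trivial, so the check only exercises the reported problem when run on a multi-device mesh.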