Open altimofeev opened 10 months ago
Hi @altimofeev
The issue with `top_k` returning different first values for different `k` values has been resolved in JAX 0.4.34. I have verified this on a cloud VM with 4xT4 GPUs. The `top_k` function now consistently returns the same first value for both `k=1` and `k=4`. Please see the attached screenshot for confirmation.
Thank you.
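For anyone who wants to repeat this kind of check, here is a minimal sketch. It is not the original reproduction script from the issue. The helper name `check_top_k_consistency`, the input shape, the mesh axis name `"x"`, and the random input are all assumptions made for illustration. It shards the last axis across whatever devices are available and compares the first value returned by `jax.lax.top_k` against `jnp.max` for each `k`.

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def check_top_k_consistency(rows=2, cols_per_device=1024, ks=(1, 4), seed=0):
    """Return {k: True/False}, True when top_k's first value equals the row max.

    Hypothetical helper: the shape and axis name are illustrative assumptions,
    not taken from the original report.
    """
    devices = np.array(jax.devices())
    mesh = Mesh(devices, axis_names=("x",))
    # Shard the last axis, the one top_k selects over, across all devices.
    sharding = NamedSharding(mesh, P(None, "x"))

    key = jax.random.PRNGKey(seed)
    x = jax.random.normal(key, (rows, cols_per_device * devices.size))
    x = jax.device_put(x, sharding)

    # Reference: the largest element of each row.
    expected = jnp.max(x, axis=-1)

    results = {}
    for k in ks:
        values, _ = jax.jit(partial(jax.lax.top_k, k=k))(x)
        # top_k returns values in descending order, so column 0 is the top-1.
        results[k] = bool(jnp.array_equal(values[:, 0], expected))
    return results


if __name__ == "__main__":
    print(check_top_k_consistency())
```

On a single device the input is not actually split, so the check only exercises the sharded path when several devices are visible to JAX.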
Description
It seems that `jax.lax.top_k` behaves incorrectly when the input is sharded across the dimension used for computing top-k. In particular, the following code shows that the elements of top-1 and top-4 are inconsistent with each other.
Prints:
This is incorrect, since the top-1 value can't be less than the first element of the top-4 result. Note that the code works for an input shape of `[1, 8*1024]`.
Tested on NCCL version 2.19.3+cuda12.3 and jax 0.4.23 on a mesh of 8 GPUs.
What jax/jaxlib version are you using?
jax 0.4.23, jaxlib 0.4.23
Which accelerator(s) are you using?
GPU x 8
Additional system info?
uname_result(system='Linux', release='5.4.253-167.359.amzn2.x86_64', version='#1 SMP Tue Aug 15 21:40:23 UTC 2023', machine='x86_64')
NVIDIA GPU info