google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jax-metal: non-deterministic behavior of `jnp.take_along_axis` #21757

Open jkarwowski opened 3 weeks ago

jkarwowski commented 3 weeks ago

Description

When running the following program twice:

import jax.numpy as jnp

logits = jnp.array([[-0.63927454,  0.0,         -1.2245101],
 [-0.17820558,  0.0,         -0.7477473]],      dtype=jnp.float32)

labels1 = jnp.array([[0, 1]])
labels2 = jnp.array([[1], [0]], dtype=jnp.int32)

print(jnp.take_along_axis(logits, labels1, axis=1))
print(jnp.take_along_axis(logits, labels2, axis=1))

It produces different outputs on different runs:

(jax-metal)  ~/example > python3.11 err.py
2024-06-10 06:25:40.053596: W pjrt_plugin/src/mps_client.cc:534] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1

systemMemory: 8.00 GB
maxCacheSize: 2.67 GB

[[-0.63927454  0.        ]
 [-0.17820558  0.        ]]
[[ 0.        ]
 [-0.17820558]]
(jax-metal)  ~/example > python3.11 err.py
2024-06-10 06:25:43.420291: W pjrt_plugin/src/mps_client.cc:534] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1

systemMemory: 8.00 GB
maxCacheSize: 2.67 GB

[[-0.63927454  0.        ]
 [-0.17820558  0.        ]]
[[ 0.]
 [nan]]

It usually has to be rerun about 10 times before the output changes.
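Because the wrong result only shows up in some runs, a small harness that launches the script in fresh interpreter processes and tallies the distinct outputs makes the frequency easier to measure. This is a sketch, not part of the original report: the helper name `distinct_outputs` is hypothetical, and dropping lines that contain "WARNING" is an assumption meant to keep timestamped log lines from making every run look unique if they happen to reach stdout.

```python
import subprocess
import sys
from collections import Counter


def distinct_outputs(cmd, runs=20):
    """Run `cmd` in a fresh process `runs` times and tally distinct stdout results.

    Hypothetical helper. Lines containing "WARNING" are dropped so that
    timestamped log lines (if they land on stdout) do not make every run unique.
    """
    tally = Counter()
    for _ in range(runs):
        # Each call starts a new interpreter, mirroring rerunning err.py by hand.
        proc = subprocess.run(cmd, capture_output=True, text=True, check=True)
        kept = [line for line in proc.stdout.splitlines() if "WARNING" not in line]
        tally["\n".join(kept)] += 1
    return tally
```

For example, `distinct_outputs([sys.executable, "err.py"], runs=20)` would return a `Counter` with one key per distinct printout; more than one key means the output is not deterministic across runs.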

The behavior disappears when the program does not first index the array with labels1:

import jax.numpy as jnp

logits = jnp.array([[-0.63927454,  0.0,         -1.2245101],
 [-0.17820558,  0.0,         -0.7477473]],      dtype=jnp.float32)
labels = jnp.array([[1], [0]], dtype=jnp.int32)
print(jnp.take_along_axis(logits, labels, axis=1))
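For comparison, NumPy's `np.take_along_axis` applies the same indexing rule on the CPU, so it gives a reference for what both calls in the first program should print. This sketch assumes only NumPy is installed. Note that `labels1` has shape (1, 2) and is broadcast against the (2, 3) `logits` along the non-indexed axis, so both rows select columns 0 and 1.

```python
import numpy as np

# Same inputs as the first program in this issue, as plain NumPy arrays.
logits = np.array([[-0.63927454, 0.0, -1.2245101],
                   [-0.17820558, 0.0, -0.7477473]], dtype=np.float32)
labels1 = np.array([[0, 1]])                    # shape (1, 2), broadcast over rows
labels2 = np.array([[1], [0]], dtype=np.int32)  # shape (2, 1), one index per row

ref1 = np.take_along_axis(logits, labels1, axis=1)  # shape (2, 2)
ref2 = np.take_along_axis(logits, labels2, axis=1)  # shape (2, 1)
print(ref1)
print(ref2)

# The CPU reference contains no NaN, so a NaN in the second result is incorrect.
assert not np.isnan(ref1).any() and not np.isnan(ref2).any()
```

The reference matches the first run shown in the report; the `nan` in the second run is the incorrect value.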

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.11
jaxlib: 0.4.11
numpy:  1.26.2
python: 3.11.4 (main, Jul 25 2023, 17:36:13) [Clang 14.0.3 (clang-1403.0.22.14.1)]
jax.devices (1 total, 1 local): [MetalDevice(id=0, process_index=0)]
process_count: 1

I'm using macOS 14.1.1 (23B81) with jax-metal==0.0.4.

jakevdp commented 3 weeks ago

Possible duplicate of #17344.

It looks like that was fixed in jax-metal v0.0.5, so you may be able to fix your issue by upgrading to the most recent release.
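To confirm which jax-metal release is installed before and after upgrading, a stdlib-only check like the following can help. It is a sketch: the 0.0.5 threshold comes from the comment about #17344, the helper names `installed_version` and `at_least` are hypothetical, and the version parsing assumes plain numeric `X.Y.Z` strings.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist="jax-metal"):
    """Return the installed version string of `dist`, or None if it is absent."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return None


def at_least(ver, minimum):
    """Compare a plain 'X.Y.Z' version string against a tuple such as (0, 0, 5).

    Assumes purely numeric release segments; pre-release tags such as 'rc1'
    are not handled (packaging.version is the robust choice for those).
    """
    parts = tuple(int(p) for p in ver.split(".")[:len(minimum)])
    return parts >= minimum


current = installed_version()
if current is None:
    print("jax-metal is not installed")
elif at_least(current, (0, 0, 5)):
    print(f"jax-metal {current} is at or above 0.0.5")
else:
    print(f"jax-metal {current} predates 0.0.5; try: pip install -U jax-metal")
```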

kulinseth commented 2 weeks ago

Agree with @jakevdp. Can you try the latest jax-metal and see if the issue still exists?