google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

On jax-metal, updating multidimensional boolean arrays sometimes fails #20675

Open shawwn opened 3 months ago

shawwn commented 3 months ago

Description

I ran into a rather surprising failure when updating 3D boolean arrays, which seems to occur only on jax-metal.

Correct behavior (CPU):

>>> jnp.zeros((2,2,2), dtype=jnp.bool).at[:, :, 0].set(True)[:, :, 0]
Array([[ True,  True],
       [ True,  True]], dtype=bool)

But on jax-metal, I get:

>>> jnp.zeros((2,2,2), dtype=jnp.bool).at[:, :, 0].set(True)[:, :, 0]
Array([[False, False],
       [False, False]], dtype=bool)

After playing around with some inputs, the problem seems to occur for .at[:, i] and .at[:, :, i], but .at[i] works fine. So the scatter update for booleans appears to be broken whenever the indexed axis is anything other than axis 0.
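To make the pattern above easier to check on other machines, here is a minimal sketch of a sweep over axes and dtypes. It compares each `.at[...].set(True)` result against a NumPy reference computed on the host. The helper name `check_axis_updates` is hypothetical (it is not part of JAX), and the dtype list is only an assumption about what is worth testing.

```python
import numpy as np
import jax.numpy as jnp


def check_axis_updates(shape=(2, 2, 2), dtypes=(np.bool_, np.int32, np.float32)):
    """Hypothetical helper: map (dtype name, axis) -> whether JAX matches NumPy."""
    results = {}
    for dtype in dtypes:
        for axis in range(len(shape)):
            # Build an index such as (slice(None), slice(None), 0) for axis=2,
            # i.e. the equivalent of .at[:, :, 0].
            index = (slice(None),) * axis + (0,)
            got = np.asarray(jnp.zeros(shape, dtype=dtype).at[index].set(True))
            # Reference result computed with plain NumPy on the host.
            want = np.zeros(shape, dtype=dtype)
            want[index] = True
            results[(np.dtype(dtype).name, axis)] = bool(np.array_equal(got, want))
    return results


if __name__ == "__main__":
    for (name, axis), ok in check_axis_updates().items():
        print(f"dtype={name:8s} axis={axis}: {'ok' if ok else 'MISMATCH'}")
```

On a correct backend every entry should report a match. Based on the behavior described above, the bool entries for axis 1 and axis 2 would be the ones expected to mismatch on jax-metal.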

Is there some way I can help debug this? Is the jax-metal code open source? If so, and you can point me to build instructions, I can try to track down the bug.

System info (python version, jaxlib version, accelerator, etc.)

On jax-metal:

>>> import jax; jax.print_environment_info()
jax:    0.4.25
jaxlib: 0.4.23
numpy:  1.26.2
python: 3.10.13 (main, Aug 24 2023, 22:36:46) [Clang 14.0.3 (clang-1403.0.22.14.1)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='shawn.local', release='23.4.0', version='Darwin Kernel Version 23.4.0: Fri Mar 15 00:12:49 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T6020', machine='arm64')

On CPU (correct behavior):

>>> import jax; jax.print_environment_info()
jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.2
python: 3.10.13 (main, Aug 24 2023, 22:36:46) [Clang 14.0.3 (clang-1403.0.22.14.1)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='shawn.local', release='23.4.0', version='Darwin Kernel Version 23.4.0: Fri Mar 15 00:12:49 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T6020', machine='arm64')

shawwn commented 3 months ago

(Note that integers and floats seem to work fine; only dtype=jnp.bool_ seems affected.)
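Since integer updates reportedly work, one possible stopgap is to do the update on an integer copy and cast back to bool. The following is only a sketch: `set_bool_via_int32` is a hypothetical helper, and it has not been verified on jax-metal. It assumes that the int32 scatter path and the bool/int32 casts behave correctly there.

```python
import jax.numpy as jnp


def set_bool_via_int32(x, index, value):
    """Hypothetical workaround: update an int32 copy, then cast back to bool."""
    as_int = x.astype(jnp.int32)
    # Convert the update value to int32 so no boolean scatter is involved.
    updated = as_int.at[index].set(jnp.asarray(value, dtype=jnp.int32))
    return updated.astype(jnp.bool_)


if __name__ == "__main__":
    x = jnp.zeros((2, 2, 2), dtype=jnp.bool_)
    # Equivalent to x.at[:, :, 0].set(True), routed through int32.
    y = set_bool_via_int32(x, (slice(None), slice(None), 0), True)
    print(y[:, :, 0])
```

This costs an extra copy and two casts, so it only makes sense as a temporary measure until the backend is fixed.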

shuhand0 commented 3 months ago

jax-metal is not open source at this time. We'll look into the issue and post any updates here.