google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

`.at.set` notation assigns the wrong values to the wrong array elements on a 2x2 TPU #22827

Open · Sohl-Dickstein opened this issue 1 month ago

Sohl-Dickstein commented 1 month ago

Description

When I run the following code block, it produces different output with and without sharding:

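import jax
import jax.numpy as jnp
import numpy as np
# 1-D mesh over all 4 TPU devices, with a single axis named 'x'
mesh = jax.sharding.Mesh(np.array(jax.devices()), ('x', ))
# shard dimension 0 (rows) along mesh axis 'x'; dimension 1 is replicated
spec = jax.sharding.PartitionSpec('x',)
# A = jnp.ones((4,32)) # if we use this shape, then the sharded output is correct
A = jnp.ones((4,1024))
print("without sharding:")
print(A.at[::2].set(0))  # expected: rows 0 and 2 zeroed, rows 1 and 3 unchanged
with mesh:
  B = jax.lax.with_sharding_constraint(A, spec)
print("with sharding:")
print(B.at[::2].set(0))  # should match the unsharded result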

The output of this code block is:

without sharding:
[[0. 0. 0. ... 0. 0. 0.]
 [1. 1. 1. ... 1. 1. 1.]
 [0. 0. 0. ... 0. 0. 0.]
 [1. 1. 1. ... 1. 1. 1.]]
with sharding:
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
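For comparison, the unsharded output above is what NumPy's strided assignment produces. A minimal sketch, run on a single default device with no sharding, contrasting JAX's out-of-place `.at[::2].set(0)` with NumPy's in-place `ref[::2] = 0`:

```python
import jax.numpy as jnp
import numpy as np

A = jnp.ones((4, 1024))

# JAX arrays are immutable: .at[::2].set(0) returns a new array and leaves A as is.
updated = A.at[::2].set(0)

# NumPy reference: in-place strided assignment on a mutable array.
ref = np.ones((4, 1024))
ref[::2] = 0  # zero rows 0 and 2

print(np.array_equal(np.asarray(updated), ref))  # True
print(np.asarray(A)[:, 0])                        # [1. 1. 1. 1.]
print(np.asarray(updated)[:, 0])                  # [0. 1. 0. 1.]
```

Any sharded computation of the same update should produce these same values; only the placement of the rows across devices is allowed to differ.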

I observed similar problems with differently shaped arrays on other TPU topologies, but the problem depends on the array shape: as the commented-out line shows, with shape (4, 32) the sharded output is correct.
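Because the result depends on the array shape, a self-checking version of the reproducer that compares sharded and unsharded results over several shapes may be easier to run than reading printed arrays. This is a sketch under assumptions: it shards with `jax.device_put` and a `NamedSharding` instead of `with_sharding_constraint` inside `with mesh:` (expected to give the same row layout, but not verified against this bug), the helpers `row_sharded` and `strided_set_matches` are hypothetical names, and the extra shapes are arbitrary. The `XLA_FLAGS` setting only emulates 4 devices on a CPU-only host; on a TPU host the real devices are used.

```python
import math
import os

# Assumption: on a CPU-only host, ask XLA for 4 virtual devices so the row
# sharding is actually exercised. Must be set before JAX initializes.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def row_sharded(x):
    """Hypothetical helper: shard x's rows over as many devices as divide x.shape[0]."""
    n = math.gcd(len(jax.devices()), x.shape[0])
    mesh = Mesh(np.array(jax.devices()[:n]), ("x",))
    return jax.device_put(x, NamedSharding(mesh, PartitionSpec("x")))


def strided_set_matches(shape):
    """Hypothetical helper: True if x.at[::2].set(0) agrees with and without sharding."""
    x = jnp.ones(shape)
    expected = np.asarray(x.at[::2].set(0))                # unsharded reference
    actual = np.asarray(row_sharded(x).at[::2].set(0))     # row-sharded input
    return np.array_equal(actual, expected)


# Show which global rows each device holds for the shape from the report.
for shard in row_sharded(jnp.ones((4, 1024))).addressable_shards:
    print(shard.device, shard.index)

# (4, 32) was reported correct and (4, 1024) wrong; sweep a few more shapes.
for shape in [(4, 32), (4, 1024), (8, 1024), (16, 256)]:
    print(shape, "OK" if strided_set_matches(shape) else "MISMATCH")
```

On a 4-device mesh each device should hold one row of the (4, 1024) array, so a correct `[::2]` update only touches the shards on two of the four devices.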

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.31.dev20240628
jaxlib: 0.4.31.dev20240628
numpy:  1.24.4
python: 3.11.6 (main, Jul 27 2024, 03:09:39) [GCC 11.4.0]
jax.devices (4 total, 4 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0) TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0) TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]
process_count: 1
platform: uname_result(system='Linux', node='jascha2x2-0', release='6.1.85+', version='#1 SMP PREEMPT_DYNAMIC Sat Jun  1 14:27:51 UTC 2024', machine='x86_64')
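The system info above appears to be the output of `jax.print_environment_info()`. A short sketch for collecting the same report, plus each device's physical position, since the 1-D mesh in the reproducer simply follows the order of `jax.devices()`. The `coords` attribute is present on TPU devices; the `getattr` fallback is an assumption for other backends such as CPU.

```python
import jax

# Prints jax/jaxlib/numpy/python versions, the device list, process count
# and platform.
jax.print_environment_info()

# On TPU, coords gives each device's position in the physical topology
# (e.g. a 2x2 grid); other backends may not expose it.
for d in jax.devices():
    print(d.id, d.platform, getattr(d, "coords", "n/a"))
```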

froystig commented 1 month ago

@Sohl-Dickstein, thanks for filing.

I believe this works correctly at the latest stable release (jax and jaxlib 0.4.31, released about three days ago). Do you still see the issue after upgrading?
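The reported versions are `0.4.31.dev20240628` nightlies, which under PEP 440 ordering sort before the final `0.4.31` release. A small sketch for checking whether an installed jax or jaxlib predates a given release; `predates` is a hypothetical helper that handles only the `X.Y.Z`, `X.Y.Z.devN` and `+local` forms, not the full PEP 440 grammar.

```python
from importlib.metadata import version


def predates(installed, release):
    """True if `installed` is older than the final `release`.

    Hypothetical helper: handles only X.Y.Z, X.Y.Z.devN and an optional
    +local suffix, not the full PEP 440 grammar (e.g. no rc/post tags).
    """
    base, dev, _ = installed.split("+")[0].partition(".dev")
    have = tuple(int(p) for p in base.split("."))
    want = tuple(int(p) for p in release.split("."))
    # Under PEP 440 a .devN build sorts before the final release it precedes.
    return have < want or (have == want and bool(dev))


# The reporter's nightly predates the 0.4.31 release.
print(predates("0.4.31.dev20240628", "0.4.31"))  # True

for pkg in ("jax", "jaxlib"):
    v = version(pkg)
    print(pkg, v, "-> upgrade" if predates(v, "0.4.31") else "-> ok")
```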