jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Jax is_fully_replicated method issue #15919

Open InbarItayG opened 1 year ago

InbarItayG commented 1 year ago

Description

Seems like the is_fully_replicated method is problematic. See colab: https://colab.corp.google.com/drive/1qIUvfu8OvdqbdEby8DdnzuNYb3dKOkBq?usp=sharing

The code we have there is:

    import jax.numpy as jnp
    import flax.jax_utils

    jax_arr = jnp.array([1, 2, 3])
    jax_arr.is_fully_replicated  # True

    jax_arr = flax.jax_utils.replicate(jax_arr)
    jax_arr.is_fully_replicated  # False

We would expect exactly the opposite outputs, but on both TPU and CPU runtimes we get this surprising result.

What jax/jaxlib version are you using?

Latest

Which accelerator(s) are you using?

CPU + TPU

Additional system info

Colab/Borg

NVIDIA GPU info

No response

yashk2810 commented 1 year ago

It returns False because the data is not actually replicated. If you check the sharding, there is an Unstacked(8) in it, which means that the value is not replicated but sharded along its leading axis.
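To illustrate the distinction, here is a minimal sketch that uses NamedSharding rather than the PmapSharding discussed above, so it is an analogy and not a reproduction of the pmap behavior. It assumes a one-dimensional mesh over all visible devices; the axis name "devices" is arbitrary. A replicated array keeps its original shape on every device, whereas flax.jax_utils.replicate adds a leading device axis that is then split across devices.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional mesh over all visible devices (axis name is an arbitrary choice).
mesh = Mesh(np.array(jax.devices()), ("devices",))
n = mesh.size

x = jnp.array([1, 2, 3])

# Replicated: every device holds the full (3,) array; the shape is unchanged.
replicated = jax.device_put(x, NamedSharding(mesh, P()))
print(replicated.shape, replicated.is_fully_replicated)

# Stacked: a new leading axis of size n, with one row placed on each device.
stacked = jax.device_put(jnp.stack([x] * n), NamedSharding(mesh, P("devices")))
print(stacked.shape, stacked.is_fully_replicated)
```

With more than one device, the stacked array reports that it is not fully replicated even though every row holds the same values, which is the same situation the issue describes.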

I'll figure out how to fix this.

EltayebAhmed commented 1 year ago

Any updates on this? I am also affected.

yashk2810 commented 1 year ago

I would suggest moving to shard_map. pmap is in maintenance mode, and these problems are specific to pmap and PmapSharding.
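For anyone following that suggestion, a rough sketch of the shard_map equivalent of a pmap-style per-device computation might look like the following. The mesh axis name "devices" and the function per_device_sum are made up for illustration, and the import location of shard_map depends on the JAX version, so both are tried.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

# One-dimensional mesh over all visible devices (axis name is an arbitrary choice).
mesh = Mesh(np.array(jax.devices()), ("devices",))
n = mesh.size

def per_device_sum(block):
    # `block` is this device's slice of the input, with shape (1, 3).
    # psum adds the slices from all devices, so every device ends up
    # holding the same result.
    return jax.lax.psum(block, "devices")

# Input is split along its leading axis; the output is declared replicated.
f = shard_map(per_device_sum, mesh=mesh, in_specs=P("devices"), out_specs=P())

x = jnp.arange(n * 3.0).reshape(n, 3)
out = f(x)
print(out.shape)
print(out.sharding.is_fully_replicated)
```

Here the input and output shardings are stated explicitly through in_specs and out_specs, so the replication status of the result follows from the declared out_specs rather than from a stacked device axis.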