InbarItayG opened this issue 1 year ago
It returns False because the data is not actually replicated. If you check the sharding, there is an Unstacked(8) entry in it, which means the value is not replicated but sharded.
I'll figure out how to fix this.
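To make the point above concrete, here is a small sketch. It assumes a JAX version where `jax.pmap` is still available, and the names `x`, `n` and `per_device` are illustrative. It builds roughly the same per-device layout that `flax.jax_utils.replicate` returns: a new leading axis with one copy per device. The sharding therefore splits axis 0 across devices, even though every slice holds identical data.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = jnp.array([1, 2, 3])

# Stack one copy per device along a new leading (device) axis.
stacked = jnp.stack([x] * n)

# An identity pmap places slice i of axis 0 on device i.
per_device = jax.pmap(lambda v: v)(stacked)

print(per_device.shape)                # (n, 3): axis 0 is the device axis
print(per_device.sharding)             # per-device placement of axis 0
print(per_device.is_fully_replicated)  # reflects the split, not the equal contents
```

The contents of each slice are equal, but the sharding only records that axis 0 is split across devices, which matches the Unstacked(...) entry described above.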
Any updates on this? I am also affected.
I would suggest moving to shard_map. pmap is in maintenance mode, and these problems are specific to pmap and PmapSharding.
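A minimal sketch of that migration, assuming a recent JAX with `jax.sharding.Mesh` and `NamedSharding`. The mesh axis name `"devices"` and the function `add_one` are hypothetical. Instead of `flax.jax_utils.replicate`, the array is placed on every device with an empty `PartitionSpec` (no dimension is split, so it is truly replicated), and `shard_map` then runs the function on each device.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

# One mesh axis (hypothetical name "devices") spanning all available devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))

x = jnp.array([1, 2, 3])

# P() with no entries: do not split any dimension, i.e. replicate on every device.
replicated = jax.device_put(x, NamedSharding(mesh, P()))
print(replicated.is_fully_replicated)

def add_one(block):
    # With in_specs=P(), each device receives the full array.
    return block + 1

# Replicated in, replicated out.
y = shard_map(add_one, mesh=mesh, in_specs=P(), out_specs=P())(replicated)
print(np.asarray(y))
```

Unlike the pmap-style value, the replicated array keeps its original shape (3,) and carries no extra device axis.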
Description
It seems like the is_fully_replicated property is problematic. See colab: https://colab.corp.google.com/drive/1qIUvfu8OvdqbdEby8DdnzuNYb3dKOkBq?usp=sharing
The code we have there is:

```python
import jax.numpy as jnp
import flax.jax_utils

jax_arr = jnp.array([1, 2, 3])
jax_arr.is_fully_replicated  # True

jax_arr = flax.jax_utils.replicate(jax_arr)
jax_arr.is_fully_replicated  # False
```
We would expect the outputs to be the other way around, but on both TPU and CPU runtimes we get this puzzling result.
What jax/jaxlib version are you using?
Latest
Which accelerator(s) are you using?
CPU + TPU
Additional system info
Colab/Borg
NVIDIA GPU info
No response