Open Conchylicultor opened 10 months ago
I don't think this is related to `with_sharding_constraint`. In my experience, every operation applied to an object with `NamedSharding` transforms it into an object with `GSPMDSharding`, losing the `PartitionSpec` names (though the array is still sharded correctly).
```python
import os
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

num_gpus = 8
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

devices = mesh_utils.create_device_mesh((num_gpus,))
mesh = Mesh(devices, axis_names=('gpus',))

with mesh:
    arr = jax.numpy.zeros((16,))
    arr = jax.device_put(arr, NamedSharding(mesh, P("gpus")))
    print(arr.sharding)
    arr = arr * 2
    print(arr.sharding)
    print(arr.sharding.devices_indices_map(tuple(arr.shape)))
    jax.debug.visualize_array_sharding(arr, use_color=False)
```
Output:

```
NamedSharding(mesh={'gpus': 8}, spec=PartitionSpec('gpus',))
GSPMDSharding({devices=[8]0,1,2,3,4,5,6,7})
{CpuDevice(id=0): (slice(0, 2, None),), CpuDevice(id=1): (slice(2, 4, None),), CpuDevice(id=2): (slice(4, 6, None),), CpuDevice(id=3): (slice(6, 8, None),), CpuDevice(id=4): (slice(8, 10, None),), CpuDevice(id=5): (slice(10, 12, None),), CpuDevice(id=6): (slice(12, 14, None),), CpuDevice(id=7): (slice(14, 16, None),)}
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│ CPU 0 │ CPU 1 │ CPU 2 │ CPU 3 │ CPU 4 │ CPU 5 │ CPU 6 │ CPU 7 │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘
```
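Until this is fixed, one possible workaround (a sketch, not an official recommendation) is to re-attach the original `NamedSharding` after the operation with `jax.device_put`; when the data is already laid out that way, this should not move anything:

```python
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'  # set before JAX initializes its backend

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = mesh_utils.create_device_mesh((8,))
mesh = Mesh(devices, axis_names=('gpus',))
sharding = NamedSharding(mesh, P('gpus'))

arr = jax.device_put(jax.numpy.zeros((16,)), sharding)
arr = arr * 2                        # may come back as GSPMDSharding on older JAX versions
arr = jax.device_put(arr, sharding)  # re-attach the NamedSharding with its axis names
print(type(arr.sharding).__name__)
```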
Yeah, this should ideally work and return NamedSharding.
This is a rare case (jit with no arguments and no sharding annotations) where it doesn't work.
But thanks for reporting. I'll try to fix it.
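That failing case can be sketched as follows; this is an illustrative reconstruction (the function name and jit pattern are assumptions, not code from the thread) of a jit with no arguments and no `out_shardings`, where only an internal `with_sharding_constraint` describes the layout:

```python
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'  # set before JAX initializes its backend

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = mesh_utils.create_device_mesh((8,))
mesh = Mesh(devices, axis_names=('gpus',))
sharding = NamedSharding(mesh, P('gpus'))

@jax.jit
def init():
    # No arguments and no out_shardings: jit has no input sharding to
    # propagate, only this internal constraint.
    return jax.lax.with_sharding_constraint(jax.numpy.zeros((16,)), sharding)

out = init()
# The data is sharded 8 ways either way; whether .sharding is reported as
# NamedSharding or GSPMDSharding depends on the JAX version.
print(type(out.sharding).__name__)
```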
Hi @Findus23
It looks like the issue you mentioned has been resolved. I tried to reproduce it on Colab with JAX version 0.4.23, and operations applied to an object with `NamedSharding` no longer transform it into an object with `GSPMDSharding`.
However, the issue mentioned by @Conchylicultor still exists.
```python
import os
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

num_gpus = 8
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

devices = mesh_utils.create_device_mesh((num_gpus,))
mesh = Mesh(devices, axis_names=('gpus',))

with mesh:
    arr = jax.numpy.zeros((16,))
    arr = jax.device_put(arr, NamedSharding(mesh, P("gpus")))
    print(arr.sharding)
    arr = arr * 2
    print(arr.sharding)
    arr = arr + 4
    print(arr.sharding)
    arr = arr - 1
    print(arr.sharding)
    arr = arr / 2
    print(arr.sharding)
    print(arr.sharding.devices_indices_map(tuple(arr.shape)))
    jax.debug.visualize_array_sharding(arr, use_color=False)
```
Output:
```
NamedSharding(mesh=Mesh('gpus': 8), spec=PartitionSpec('gpus',))
NamedSharding(mesh=Mesh('gpus': 8), spec=PartitionSpec('gpus',))
NamedSharding(mesh=Mesh('gpus': 8), spec=PartitionSpec('gpus',))
NamedSharding(mesh=Mesh('gpus': 8), spec=PartitionSpec('gpus',))
NamedSharding(mesh=Mesh('gpus': 8), spec=PartitionSpec('gpus',))
{CpuDevice(id=0): (slice(0, 2, None),), CpuDevice(id=1): (slice(2, 4, None),), CpuDevice(id=2): (slice(4, 6, None),), CpuDevice(id=3): (slice(6, 8, None),), CpuDevice(id=4): (slice(8, 10, None),), CpuDevice(id=5): (slice(10, 12, None),), CpuDevice(id=6): (slice(12, 14, None),), CpuDevice(id=7): (slice(14, 16, None),)}
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│ CPU 0 │ CPU 1 │ CPU 2 │ CPU 3 │ CPU 4 │ CPU 5 │ CPU 6 │ CPU 7 │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘
```
Please find the gist for reference.
Thank you
Indeed, it seems like since 0.4.19 most operations are returning proper NamedSharding: https://github.com/Findus23/jax-array-info/commit/fd641005656c3c23f9b854fe0e7992e9a5937864
I would expect the two functions `init0` and `init1` to be identical:

However, using `jax.lax.with_sharding_constraint` loses the original sharding information (e.g. the sharding names) outside of jit:

The reason `jax.lax.with_sharding_constraint` is preferred here is that using `out_shardings=` fails: `absl.app.run` has not been called yet when the function is defined, so calling `jax.devices()` and constructing the sharding both fail.
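One way to square these constraints (a sketch under the assumption that sharding construction can be deferred; `make_init` and `main` are illustrative names, not from the original code) is to build the mesh and the jitted function at runtime rather than at module-definition time, so `out_shardings=` becomes usable:

```python
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'  # set before JAX initializes its backend

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def make_init(mesh):
    # The jitted function (and its out_shardings) is constructed here,
    # after devices exist, not at module import time.
    return jax.jit(
        lambda: jax.numpy.zeros((16,)),
        out_shardings=NamedSharding(mesh, P('gpus')),
    )

def main():
    # Runs after framework startup (e.g. inside absl.app.run's main),
    # so querying devices is safe here.
    devices = mesh_utils.create_device_mesh((8,))
    mesh = Mesh(devices, axis_names=('gpus',))
    init = make_init(mesh)
    arr = init()
    print(arr.sharding)
    return arr

arr = main()
```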