google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

`jax.lax.with_sharding_constraint` is not propagated outside jit #17422

Open · Conchylicultor opened this issue 10 months ago

Conchylicultor commented 10 months ago

I would expect the two functions init0 and init1 to behave identically:

import functools

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A 1-D mesh over all local devices (8 in this report), fully replicated.
mesh = Mesh(jax.devices(), axis_names=('devices',))
REPLICATED = NamedSharding(mesh, PartitionSpec())

@functools.partial(
    jax.jit,
    out_shardings=REPLICATED,
)
def init0():
  return jnp.zeros((3,))

@jax.jit
def init1():
  return jax.lax.with_sharding_constraint(jnp.zeros((3,)), REPLICATED)

However, using jax.lax.with_sharding_constraint loses the original sharding information (e.g. the axis names) outside of jit:

init0().sharding  # NamedSharding(mesh=Mesh('devices': 8), spec=PartitionSpec())
init1().sharding  # GSPMDSharding({replicated})
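
Note that both results still place the data identically; what is lost is only the NamedSharding metadata (mesh axis names and PartitionSpec). If I remember the API correctly, this can be checked with Sharding.is_equivalent_to, which compares device placement rather than the repr:

a = init0()
b = init1()
# Placement comparison, not metadata comparison, so this should be True even
# though one sharding is a NamedSharding and the other a GSPMDSharding.
print(a.sharding.is_equivalent_to(b.sharding, ndim=a.ndim))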

The reason jax.lax.with_sharding_constraint is preferred is that out_shardings= is evaluated when the function is defined, before absl.app.run has been called, so jax.devices() and the sharding construction fail at that point.
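
A common workaround (my own sketch, not from this thread; make_init and main are my own scaffolding) is to defer building the sharding until runtime, so that jax.devices() only runs after absl.app.run has started the program:

import functools

import jax
import jax.numpy as jnp
from absl import app
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def make_init():
  # Built at call time, after absl.app.run, so jax.devices() is safe here.
  mesh = Mesh(jax.devices(), axis_names=('devices',))
  replicated = NamedSharding(mesh, PartitionSpec())

  @functools.partial(jax.jit, out_shardings=replicated)
  def init():
    return jnp.zeros((3,))

  return init

def main(argv):
  init = make_init()
  print(init().sharding)  # NamedSharding, axis names preserved

if __name__ == '__main__':
  app.run(main)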

Findus23 commented 10 months ago

I don't think this is related to with_sharding_constraint. In my experience, every operation applied to an object with NamedSharding transforms it into an object with GSPMDSharding, losing the PartitionSpec names (though it is still sharded correctly).

import os

# Set before JAX initializes its backend so 8 host-platform devices exist.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

num_gpus = 8

devices = mesh_utils.create_device_mesh((num_gpus,))
mesh = Mesh(devices, axis_names=('gpus',))

with mesh:
    arr = jax.numpy.zeros((16,))
    arr = jax.device_put(arr, NamedSharding(mesh, P("gpus")))
    print(arr.sharding)
    arr = arr * 2
    print(arr.sharding)
    print(arr.sharding.devices_indices_map(tuple(arr.shape)))
    jax.debug.visualize_array_sharding(arr, use_color=False)

Output:
NamedSharding(mesh={'gpus': 8}, spec=PartitionSpec('gpus',))
GSPMDSharding({devices=[8]0,1,2,3,4,5,6,7})
{CpuDevice(id=0): (slice(0, 2, None),), CpuDevice(id=1): (slice(2, 4, None),), CpuDevice(id=2): (slice(4, 6, None),), CpuDevice(id=3): (slice(6, 8, None),), CpuDevice(id=4): (slice(8, 10, None),), CpuDevice(id=5): (slice(10, 12, None),), CpuDevice(id=6): (slice(12, 14, None),), CpuDevice(id=7): (slice(14, 16, None),)}
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│ CPU 0 │ CPU 1 │ CPU 2 │ CPU 3 │ CPU 4 │ CPU 5 │ CPU 6 │ CPU 7 │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘
yashk2810 commented 10 months ago

Yeah, this should ideally work and return NamedSharding.

This is a rare case where it doesn't work: when jit has no arguments and no sharding annotations.

But thanks for reporting. I'll try to fix it.
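
If that description is accurate, giving jit an input that already carries a NamedSharding should sidestep the problem, since the mesh can then be propagated from the argument. An untested sketch (init2 and its body are my own illustration, reusing REPLICATED from the original report):

@jax.jit
def init2(x):
  # x carries a NamedSharding, so jit has a mesh to propagate to the output.
  return jax.lax.with_sharding_constraint(jnp.zeros_like(x), REPLICATED)

x = jax.device_put(jnp.zeros((3,)), REPLICATED)
print(init2(x).sharding)  # expected: NamedSharding rather than GSPMDSharding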

rajasekharporeddy commented 4 months ago

Hi @Findus23

It looks like the issue you mentioned has been resolved. I tried to reproduce it on Colab with JAX version 0.4.23, and operations applied to an object with NamedSharding no longer transform it into an object with GSPMDSharding.

But the issue mentioned by @Conchylicultor still exists (a minimal sketch of that case is at the end of this comment).

import os

# Set before JAX initializes its backend so 8 host-platform devices exist.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

num_gpus = 8

devices = mesh_utils.create_device_mesh((num_gpus,))
mesh = Mesh(devices, axis_names=('gpus',))

with mesh:
    arr = jax.numpy.zeros((16,))
    arr = jax.device_put(arr, NamedSharding(mesh, P("gpus")))
    print(arr.sharding)
    arr = arr * 2
    print(arr.sharding)
    arr = arr + 4
    print(arr.sharding)
    arr = arr - 1
    print(arr.sharding)
    arr = arr / 2
    print(arr.sharding)
    print(arr.sharding.devices_indices_map(tuple(arr.shape)))
    jax.debug.visualize_array_sharding(arr, use_color=False)

Output:

NamedSharding(mesh=Mesh('gpus': 8), spec=PartitionSpec('gpus',))
NamedSharding(mesh=Mesh('gpus': 8), spec=PartitionSpec('gpus',))
NamedSharding(mesh=Mesh('gpus': 8), spec=PartitionSpec('gpus',))
NamedSharding(mesh=Mesh('gpus': 8), spec=PartitionSpec('gpus',))
NamedSharding(mesh=Mesh('gpus': 8), spec=PartitionSpec('gpus',))
{CpuDevice(id=0): (slice(0, 2, None),), CpuDevice(id=1): (slice(2, 4, None),), CpuDevice(id=2): (slice(4, 6, None),), CpuDevice(id=3): (slice(6, 8, None),), CpuDevice(id=4): (slice(8, 10, None),), CpuDevice(id=5): (slice(10, 12, None),), CpuDevice(id=6): (slice(12, 14, None),), CpuDevice(id=7): (slice(14, 16, None),)}
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│ CPU 0 │ CPU 1 │ CPU 2 │ CPU 3 │ CPU 4 │ CPU 5 │ CPU 6 │ CPU 7 │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘

Please find the gist for reference.
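
For completeness, a minimal sketch of the still-open case (adapted from the original report; the mesh construction here is my assumption):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(jax.devices(), axis_names=('devices',))
REPLICATED = NamedSharding(mesh, PartitionSpec())

@jax.jit
def init1():
  return jax.lax.with_sharding_constraint(jnp.zeros((3,)), REPLICATED)

print(init1().sharding)  # still prints GSPMDSharding({replicated}) on 0.4.23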

Thank you

Findus23 commented 4 months ago

Indeed, it seems that since 0.4.19 most operations return a proper NamedSharding: https://github.com/Findus23/jax-array-info/commit/fd641005656c3c23f9b854fe0e7992e9a5937864
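
A quick way to spot-check this on a local install (my own sketch, not taken from the linked commit):

import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=('d',))
x = jax.device_put(jnp.zeros((16,)), NamedSharding(mesh, P('d')))
# On jax >= 0.4.19 an elementwise op should keep the NamedSharding type.
print(type((x * 2).sharding).__name__)  # expected: NamedSharding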