jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

shard_map - NotImplementedError: No replication rule for cond #24418

Open giovannicemin opened 3 weeks ago

giovannicemin commented 3 weeks ago

Description

The following (very minimal) example gives this error:

NotImplementedError: No replication rule for cond. As a workaround, pass the check_rep=False argument to shard_map. To get this fixed, open an issue at https://github.com/google/jax/issues

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=1'
from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), axis_names=('i',))

a = jnp.array([True, False])

@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
def f(a):
  # lax.cond inside shard_map is what triggers the missing replication rule.
  c = jax.lax.cond(a[0], lambda: 0, lambda: 1)
  return c

f(a)

Note that I set --xla_force_host_platform_device_count=1 for reproducibility, but the same error occurs for other device counts.
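
For reference, the missing rule seems specific to cond: expressing the same per-shard selection with jnp.where, which lowers to an elementwise select rather than a cond, runs with check_rep left enabled. This is just a sketch of a cond-free variant (it assumes both branch values are cheap to compute unconditionally), not a fix for the underlying issue:

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=1'
from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), axis_names=('i',))

a = jnp.array([True, False])

@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'))
def g(a):
  # Both candidate values (0 and 1) are computed; where() merely selects
  # between them, so no cond primitive is traced and check_rep can stay on.
  c = jnp.where(a[0], 0, 1)
  return jnp.atleast_1d(c)

g(a)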

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.28
jaxlib: 0.4.28.dev20240710
numpy:  1.26.4
python: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:36:13) [GCC 12.3.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1

ASKabalan commented 3 weeks ago

Hello,

This will give you the desired behaviour:

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=1'
from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), axis_names=('i',))

a = jnp.array([False, False])
b = jnp.array([True, False])

# check_rep=False disables the replication analysis that has no rule for cond.
@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'), check_rep=False)
def f(a):
  c = jax.lax.cond(a[0], lambda: 0, lambda: 1)
  # out_specs=P('i') expects a rank >= 1 output per shard, hence atleast_1d.
  return jnp.atleast_1d(c)

f(a)
# Out[7]: Array([1], dtype=int32, weak_type=True)

f(b)
# Out[8]: Array([0], dtype=int32, weak_type=True)
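
For context: with check_rep=True (the default), shard_map statically tracks which mesh axes each intermediate value is replicated over so that it can validate out_specs. As far as I understand, lax.cond simply has no rule registered for that analysis yet, which is exactly what the NotImplementedError says. Passing check_rep=False skips the analysis without changing what is computed, at the cost of shard_map no longer verifying that your outputs match the declared out_specs.
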
giovannicemin commented 2 weeks ago

Thanks @ASKabalan for the answer. Yes, I know this gets me the desired behaviour; however, as the error message itself says, check_rep=False is only a workaround. I opened this issue so that the missing replication rule for cond can be fixed.