giovannicemin opened this issue 3 weeks ago
Hello,
This will give you the desired behaviour:
```python
import os
# XLA_FLAGS must be set before jax is imported to take effect
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=1'

from functools import partial
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), axis_names=('i',))
a = jnp.array([False, False])
b = jnp.array([True, False])

# check_rep=False skips the replication check, which has no rule for cond
@partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'), check_rep=False)
def f(a):
    c = jax.lax.cond(a[0], lambda: 0, lambda: 1)
    return jnp.atleast_1d(c)

f(a)
# Out[7]: Array([1], dtype=int32, weak_type=True)
f(b)
# Out[8]: Array([0], dtype=int32, weak_type=True)
```
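If both branches are cheap, a variant that may avoid disabling the check altogether is to replace `jax.lax.cond` with `jnp.where`, which computes both values and then selects one. This is a sketch under the assumption that the select primitive behind `jnp.where` is covered by shard_map's standard replication rules, as ordinary `lax` primitives are; it does not fix the missing `cond` rule, and it changes cost and semantics when a branch is expensive or has side effects. The function name `g` and the import fallback to `jax.shard_map` (for newer releases) are assumptions for illustration.

```python
import os
# XLA_FLAGS must be set before jax is imported to take effect
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=1"

from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    # Location used in jax 0.4.28
    from jax.experimental.shard_map import shard_map
except ImportError:
    # Assumption: newer releases expose shard_map at the top level
    from jax import shard_map

mesh = Mesh(jax.devices(), axis_names=("i",))

# The replication check is left at its default (enabled).
# jnp.where evaluates both constants and selects one, so no cond is traced.
@partial(shard_map, mesh=mesh, in_specs=P("i"), out_specs=P("i"))
def g(a):
    c = jnp.where(a[0], 0, 1)
    return jnp.atleast_1d(c)

print(g(jnp.array([False, False])))  # -> [1]
print(g(jnp.array([True, False])))   # -> [0]
```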
Thanks @ASKabalan for the answer. Yes, I know this gives me the desired behaviour; however, as the error message itself says, this is only a workaround. I opened the issue so that someone can fix the underlying error.
Description
The following (very minimal) example gives this error:
NotImplementedError: No replication rule for cond. As a workaround, pass the check_rep=False argument to shard_map. To get this fixed, open an issue at https://github.com/google/jax/issues
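For reference, a minimal sketch of a reproduction, assuming the failing example is the workaround code above with `check_rep` left at its default (`True`). On jax 0.4.28, calling `f` should raise the `NotImplementedError` quoted above; if a later release adds a replication rule for `cond`, the same call should instead return an array holding `1`. The import fallback to `jax.shard_map` is an assumption about newer releases.

```python
import os
# XLA_FLAGS must be set before jax is imported to take effect
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=1"

from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    # Location used in jax 0.4.28, the version reported in this issue
    from jax.experimental.shard_map import shard_map
except ImportError:
    # Assumption: newer releases expose shard_map at the top level
    from jax import shard_map

mesh = Mesh(jax.devices(), axis_names=("i",))

# Same function as the workaround, but with the replication check enabled (default)
@partial(shard_map, mesh=mesh, in_specs=P("i"), out_specs=P("i"))
def f(a):
    c = jax.lax.cond(a[0], lambda: 0, lambda: 1)
    return jnp.atleast_1d(c)

try:
    result = f(jnp.array([False, False]))
    error = None
except NotImplementedError as e:
    # Expected on jax 0.4.28: "No replication rule for cond. ..."
    result = None
    error = str(e)

print("result:", result)
print("error:", error)
```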
Note: I set --xla_force_host_platform_device_count=1 for reproducibility, but the same error occurs with other device counts.
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.28
jaxlib: 0.4.28.dev20240710
numpy: 1.26.4
python: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:36:13) [GCC 12.3.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
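The summary above appears to be the output of `jax.print_environment_info()`, which the JAX issue template requests. A minimal way to regenerate it on another machine, for example when checking whether other versions are affected:

```python
import jax

# return_string=True returns the summary as a string instead of printing it
info = jax.print_environment_info(return_string=True)
print(info)
```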