Closed: ASEM000 closed this issue 1 year ago.
Enable `jax.lax.map`/`jax.pmap` in the `kmap`/`smap` interface.

Example:
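```python
import os

# Fake 200 CPU devices so jax.pmap has enough devices to map over.
# This must be set before jax is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=200"

import jax
import kernex as kex


@kex.kmap(
    kernel_size=(2,),
    map_kind="pmap",               # map over kernel windows with jax.pmap
    map_kwargs={"axis_name": "i"}, # forwarded to jax.pmap
)
def f(x):
    return x


print(f(jax.numpy.arange(5)))
# [[0 1]
#  [1 2]
#  [2 3]
#  [3 4]]
```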
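To illustrate what the mapping strategies would mean for a kernel, here is a plain-JAX sketch (not kernex's implementation) that builds the four length-2 windows of `arange(5)` explicitly and maps an identity kernel over them with `jax.vmap`, `jax.lax.map`, and `jax.pmap`. The helpers `sliding_windows` and `kernel` are hypothetical, and the 8-device `XLA_FLAGS` value is an assumption chosen so `pmap` can run on a CPU-only machine.

```python
import os

# Assumption: fake 8 CPU devices so jax.pmap has one device per window.
# This must be set before JAX initializes its backends.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp


def sliding_windows(x, size):
    """Hypothetical helper: stack all 'valid' 1D windows of length `size`."""
    n = x.shape[0] - size + 1
    # Row i holds the indices [i, i + 1, ..., i + size - 1]
    idx = jnp.arange(n)[:, None] + jnp.arange(size)[None, :]
    return x[idx]


def kernel(w):
    # Identity kernel, like `f` in the issue's example
    return w


x = jnp.arange(5)
windows = sliding_windows(x, 2)  # shape (4, 2)

out_vmap = jax.vmap(kernel)(windows)                 # one vectorized computation
out_lax_map = jax.lax.map(kernel, windows)           # sequential loop (built on lax.scan)
out_pmap = jax.pmap(kernel, axis_name="i")(windows)  # one window per device

print(out_pmap)
```

All three calls should produce the same `(4, 2)` array; they differ only in how the work is executed. `vmap` batches every window into a single computation, `lax.map` processes windows one at a time, and `pmap` places each window on its own device, which is why it needs at least as many devices as there are windows.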