ASEM000 / kernex

Stencil computations in JAX
MIT License

add pmap/laxmap #10

Closed: ASEM000 closed this 1 year ago

ASEM000 commented 1 year ago

Enable `jax.lax.map` and `jax.pmap` as mapping functions in the `kmap`/`smap` interface. Example:


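```python
import os

# Expose 200 host (CPU) devices so `pmap` has enough devices, one per window.
# This must be set before JAX is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=200"

import jax
import kernex as kex

@kex.kmap(
    kernel_size=(2,),
    map_kind="pmap",                # map the windows with jax.pmap
    map_kwargs={"axis_name": "i"},  # keyword arguments passed to the map function
)
def f(x):
    # Identity kernel: return each length-2 window unchanged.
    return x

print(f(jax.numpy.arange(5)))

# [[0 1]
#  [1 2]
#  [2 3]
#  [3 4]]
```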
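To see roughly what the two map kinds do, the sketch below skips kernex entirely: it extracts the stride-1, unpadded windows by hand and maps the kernel over the leading axis with `jax.lax.map` (sequential, single device) or `jax.pmap` (one window per device). This is an assumption about the semantics, not kernex's actual implementation; `sliding_windows` and `kernel` are hypothetical helpers, and the script assumes a CPU-only setup where the XLA flag controls the device count.

```python
import os

# Must be set before JAX is imported so the CPU backend exposes several devices.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import jax.numpy as jnp


def sliding_windows(x, size):
    """Hypothetical helper: stack every length-`size` window of a 1D array
    (stride 1, no padding)."""
    n_windows = x.shape[0] - size + 1
    # idx[i, j] = i + j selects element j of window i
    idx = jnp.arange(n_windows)[:, None] + jnp.arange(size)[None, :]
    return x[idx]


def kernel(window):
    # Identity kernel, matching `f` in the example above.
    return window


x = jnp.arange(5)
windows = sliding_windows(x, 2)  # shape (4, 2)

# Sequential map: one window at a time, no extra devices needed.
out_lax = jax.lax.map(kernel, windows)

# Parallel map: one window per device; needs at least windows.shape[0] devices.
out_pmap = jax.pmap(kernel, axis_name="i")(windows)

print(out_lax)
print(out_pmap)
```

Both calls compute the same values; the difference is where the work runs. `jax.lax.map` loops over windows on one device and keeps memory flat, while `jax.pmap` requires as many devices as there are windows, which is why the example above forces a large host device count.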