ASEM000 / kernex

Stencil computations in JAX
MIT License
66 stars, 3 forks

Add padding_kwargs #13

Closed: ASEM000 closed this pull request 1 year ago

ASEM000 commented 1 year ago

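Closes #12. Adds a padding_kwargs argument to kex.kmap for customizing how the input is padded. In the example below, constant_values=10 fills the "same" padding with 10 at both borders instead of the default value.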

Example

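    import jax.numpy as jnp
    import kernex as kex
    import numpy as np

    @kex.kmap(
        kernel_size=(3,),
        padding="same",
        relative=False,
        padding_kwargs=dict(constant_values=10),  # fill the padded border with 10
    )
    def f(x):
        # With relative=False, x is the whole 3-element window, returned unchanged
        return x

    x = jnp.array([1, 2, 3, 4, 5])

    # One row per window; out-of-bounds positions hold the padding value 10
    np.testing.assert_allclose(
        f(x),
        np.array([[10, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 10]]),
    )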
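To make the expected output easier to read, here is a plain-JAX sketch of what "same" padding with a 3-wide window produces. It assumes, without confirmation from this PR, that kmap forwards padding_kwargs to jnp.pad. The helper same_pad_windows is hypothetical and not part of kernex; it also assumes a simple left/right split of the padding for "same".

```python
import jax.numpy as jnp


def same_pad_windows(x, kernel_size=3, **padding_kwargs):
    """Hypothetical helper: 'same'-pad a 1D array, then return every window.

    Assumption: padding_kwargs are passed straight to jnp.pad, which is how
    the kmap example appears to behave; this is not kernex's implementation.
    """
    # Split the k - 1 padded cells between the left and right borders.
    left = (kernel_size - 1) // 2
    right = kernel_size - 1 - left
    padded = jnp.pad(x, (left, right), **padding_kwargs)
    # Row i holds indices i, i+1, ..., i+k-1 into the padded array.
    idx = jnp.arange(x.shape[0])[:, None] + jnp.arange(kernel_size)[None, :]
    return padded[idx]


x = jnp.array([1, 2, 3, 4, 5])

# Constant padding with 10 reproduces the rows asserted in the kmap example.
windows = same_pad_windows(x, kernel_size=3, constant_values=10)

# Other jnp.pad options work the same way in this sketch, e.g. edge padding.
edge = same_pad_windows(x, kernel_size=3, mode="edge")
```

In this sketch the extra keyword arguments reach jnp.pad unchanged, so any of its mode or value options can shape the border without touching the per-window function.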