google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
29.66k stars, 2.71k forks

Implement jax.stencil similar to numba.stencil #11617

Open ASEM000 opened 2 years ago

ASEM000 commented 2 years ago

Stencil computations are quite common in computational sciences and image processing. Similar to numba.stencil*,**, I propose adding a jax.stencil that uses jax.vmap under the hood.
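To make the proposal concrete: a stencil computes each output element from a fixed neighborhood of the input. With numba.stencil the kernel is written using relative indices (e.g. a[-1] + a[0] + a[1]). The jax.stencil decorator proposed here does not exist, so the sketch below writes a 1D three-point average by hand with shifted slices. The function name is illustrative, and this version returns only the interior points.

```python
import jax.numpy as jnp

def three_point_average(a):
    # out[i] = (a[i-1] + a[i] + a[i+1]) / 3 for every interior index i.
    # Each shifted slice plays the role of one relative offset (-1, 0, +1).
    return (a[:-2] + a[1:-1] + a[2:]) / 3.0

a = jnp.array([0.0, 3.0, 6.0, 3.0, 0.0])
print(three_point_average(a))  # [3. 4. 3.]
```

A stencil decorator would let the user write only the body a[-1] + a[0] + a[1] and generate the slicing (or an equivalent vectorized form) automatically.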

jakevdp commented 2 years ago

Thanks for the request, this sounds like a useful API. If I understand correctly, it's a bit more involved than just wrapping vmap, because the mapping would be over indices, which are not passed as arguments to the function but are instead defined implicitly by the indexing statements inside it. For that reason, the best approach to implementing this in JAX may be a custom JAX transformation that traces the function and then rewrites all indexing statements at the jaxpr level. That might still be somewhat challenging given how indexing is implemented in JAX (it lowers to XLA scatter/gather, which have fairly complicated semantics). All that said, it could be a fun experiment to try.
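A first step toward such a transformation is tracing a function and looking at which primitives its indexing statements became. The sketch below only inspects the jaxpr and does not rewrite anything. Which primitives integer indexing produces (gather, or slice/dynamic_slice followed by squeeze) varies across JAX versions, so no exact output is shown.

```python
import jax
import jax.numpy as jnp

def second_difference(x):
    # A centered second difference written with explicit integer indexing.
    return x[0] - 2 * x[1] + x[2]

closed = jax.make_jaxpr(second_difference)(jnp.arange(5.0))
# Each equation in the jaxpr records the primitive that was traced;
# a stencil transformation would have to find and rewrite the indexing ones.
primitive_names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(primitive_names)
```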

jakevdp commented 2 years ago

Well, now you've gone and nerd-sniped me and I can't stop thinking about how I'd implement this. I'll give it a shot.

ASEM000 commented 2 years ago

Hi Jake,

I have implemented something similar to numba.stencil here*. However, I opened this issue because I believe my implementation is not optimal compared with a dedicated transformation.

In my implementation, I generate a sampling matrix of array indices from the array shape and kernel size, then use vmap to vectorize over those indices.
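A 1D sketch of that idea, assuming stride 1 and no padding: build an index matrix whose rows list each window's positions, then vmap a function over its rows. window_indices is an illustrative helper, not kernex's API.

```python
import jax
import jax.numpy as jnp

def window_indices(n, k, stride=1):
    # Row w holds the k input indices covered by window w (no padding).
    starts = jnp.arange(0, n - k + 1, stride)
    return starts[:, None] + jnp.arange(k)[None, :]

x = jnp.array([1.0, 4.0, 9.0, 16.0, 25.0])
idx = window_indices(x.shape[0], 3)  # [[0, 1, 2], [1, 2, 3], [2, 3, 4]]
# vmap over the rows of the index matrix; each call sees one window of x.
window_sums = jax.vmap(lambda i: jnp.sum(x[i]))(idx)
print(window_sums)  # [14. 29. 50.]
```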

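The critical part of my code is something like this: jax.vmap(lambda view: F(jnp.roll(array[jnp.ix_(*view)])))(sampling_matrix)

Here, F is the function decorated with @stencil, view holds the indices of the current window, and jnp.roll (with a shift that moves the window center to index 0) provides relative indexing: inside F, index 0 refers to the window center, 1 to the next element, and -1 wraps around to the previous one.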
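A 1D sketch of how the roll gives relative indexing, assuming odd window sizes and no padding. centered and second_difference are illustrative names: after rolling each window left by (k - 1) // 2, the kernel can be written with offsets -1, 0, +1 just like a numba-style stencil.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 4.0, 9.0, 16.0, 25.0])
idx = jnp.arange(3)[:, None] + jnp.arange(3)[None, :]  # three 3-wide windows

def centered(window):
    # Shift left by (k - 1) // 2 so the window center lands at index 0;
    # then w[1] is the right neighbor and w[-1] wraps to the left neighbor.
    k = window.shape[0]
    return jnp.roll(window, shift=-((k - 1) // 2))

def second_difference(w):
    # Written with relative offsets, as inside a stencil kernel.
    return w[-1] - 2 * w[0] + w[1]

result = jax.vmap(lambda i: second_difference(centered(x[i])))(idx)
print(result)  # [2. 2. 2.]
```

Since x holds consecutive squares, every second difference equals 2.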

The same thing can be achieved by vmapping over the patches produced by jax.lax.conv_general_dilated_patches, something like this:

patches = jax.lax.conv_general_dilated_patches( ... )
jax.vmap(lambda patch: F(jnp.roll(patch)))(patches)
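A runnable version of the patches-based approach, assuming a single-channel 2D input, a 3x3 kernel, stride 1 and "VALID" padding. For one channel, the second axis of the result holds each window flattened in row-major order, so the patches must be reshaped before vmapping over them. jnp.mean stands in for the decorated function F, and the roll is omitted here.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(1.0, 26.0).reshape(5, 5)

# The default layout is NCHW, so add batch and channel axes of size 1.
patches = jax.lax.conv_general_dilated_patches(
    x[None, None], filter_shape=(3, 3), window_strides=(1, 1), padding="VALID"
)
# patches.shape == (1, 9, 3, 3): 9 = channels * 3 * 3 flattened window entries,
# followed by the 3x3 grid of output positions.
windows = jnp.moveaxis(patches[0], 0, -1).reshape(-1, 3, 3)  # (9 positions, 3, 3)

# vmap a per-window function over the leading (position) axis.
means = jax.vmap(jnp.mean)(windows).reshape(3, 3)
print(windows[0])
print(means)
```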

I tried to use as many lax primitives as possible in my implementation, but I believe a custom transformation would be optimal here, as you mentioned.

*https://github.com/ASEM000/kernex

jakevdp commented 2 years ago

That's cool - it's an interesting idea to use roll along with vmap, but I suspect it ends up being somewhat inefficient.

ASEM000 commented 2 years ago

Hey, Jake

Here is a minimal working example of what I implemented:

# 5x5 array with a kernel size of 3x3, strides of 1x1, and no padding
# output shape is 9x3x3 for the nine patches (there would be 25 patches with 1x1 padding)
# the first element along the leading axis is the patch centered at index (1, 1)

import jax
import jax.numpy as jnp
import functools

@functools.partial(jax.profiler.annotate_function, name="roll_view")
def roll_view(array: jnp.ndarray) -> jnp.ndarray:
    """Roll the array along all axes so that its center element lands at index 0.

    Example:
    >>> x = jnp.arange(1,26).reshape(5,5)
    >>> print(roll_view(x))
    [[13 14 15 11 12]
     [18 19 20 16 17]
     [23 24 25 21 22]
     [ 3  4  5  1  2]
     [ 8  9 10  6  7]]
    """
    axes = tuple(range(array.ndim))  # roll along every axis
    # The center of an axis of size n sits at (n - 1) // 2; for even n the
    # extra element falls on the right side of the center.
    shift = tuple(-((n - 1) // 2) for n in array.shape)
    return jnp.roll(array, shift=shift, axis=axes)

@functools.partial(jax.profiler.annotate_function, name="general_arange")
def general_arange(di: int, ki: int, si: int, x0: int, xf: int) -> jnp.ndarray:
    """Calculate the window indices along a single dimension.

    Args:
        di (int): size of the dimension
        ki (int): kernel size
        si (int): stride
        x0 (int): left padding
        xf (int): right padding

    Returns:
        jnp.ndarray: array of window indices, one row per window

    Example:
        >>> di = 5  # array of shape (5,)
        >>> ki = 3  # kernel_size = (3,)
        >>> si = 1  # stride = 1 
        >>> x0 = 0  # left padding = 0
        >>> xf = 0  # right padding = 0

        >>> print(general_arange(di, ki, si, x0, xf))
            [[0 1 2]
            [1 2 3]
            [2 3 4]]
    """
    start, end = -x0 + ((ki - 1) // 2), di + xf - (ki // 2)
    size = end - start
    lhs = jax.lax.broadcasted_iota(dtype=jnp.int32, shape=(size, ki), dimension=0) + (start)  # fmt: skip
    rhs = jax.lax.broadcasted_iota(dtype=jnp.int32, shape=(ki, size), dimension=0).T - ((ki - 1) // 2)  # fmt: skip
    res = lhs + rhs

    # Skip the strided slice when si == 1; res[::si] is slightly slower.
    return res if si == 1 else res[::si]

@functools.partial(jax.profiler.annotate_function, name="general_product")
def general_product(*args):
    """Equivalent to tuple(zip(*itertools.product(*args))) for arrays.

    Example:
    >>> general_product(
    ... jnp.array([[1,2],[3,4]]),
    ... jnp.array([[5,6],[7,8]]))
    (
        DeviceArray([[[1, 2],[1, 2]],[[3, 4],[3, 4]]], dtype=int32),
        DeviceArray([[[5, 6],[7, 8]],[[5, 6],[7, 8]]], dtype=int32)
    )

    >>> tuple(zip(*(itertools.product([[1,2],[3,4]],[[5,6],[7,8]]))))
    (
        ([1, 2], [1, 2], [3, 4], [3, 4]), 
        ([5, 6], [7, 8], [5, 6], [7, 8])
    )

    """
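    def nvmap(n):
        # Nest one vmap per argument: the level for n maps over the leading
        # axis of args[-n] and keeps the other arguments unmapped, which
        # yields the Cartesian product of the leading axes.
        in_axes = [None] * len(args)
        in_axes[-n] = 0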
        return (
            jax.vmap(lambda *x: x, in_axes=in_axes)
            if n == 1
            else jax.vmap(nvmap(n - 1), in_axes=in_axes)
        )

    return nvmap(len(args))(*args)

kernel_size = (3,3)

row_window_indices = general_arange(di=5, ki=3, si=1, x0=0, xf=0)
col_window_indices = general_arange(di=5, ki=3, si=1, x0=0, xf=0)

patches_indices = general_product(row_window_indices,col_window_indices)
patches_indices = tuple(map(lambda xi, wi: xi.reshape(-1, wi), patches_indices, kernel_size))

x = jnp.arange(1,26).reshape((5,5))

print(jax.vmap(lambda view: roll_view(x[jnp.ix_(*view)]))(patches_indices))
# [[[ 7  8  6]
#   [12 13 11]
#   [ 2  3  1]]
#
#  [[ 8  9  7]
#   [13 14 12]
#   [ 3  4  2]]
#
#  [[ 9 10  8]
#   [14 15 13]
#   [ 4  5  3]]
#
#  [[12 13 11]
#   [17 18 16]
#   [ 7  8  6]]
#
#  [[13 14 12]
#   [18 19 17]
#   [ 8  9  7]]
#
#  [[14 15 13]
#   [19 20 18]
#   [ 9 10  8]]
#
#  [[17 18 16]
#   [22 23 21]
#   [12 13 11]]
#
#  [[18 19 17]
#   [23 24 22]
#   [13 14 12]]
#
#  [[19 20 18]
#   [24 25 23]
#   [14 15 13]]]
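Putting the pieces together, here is a compact sketch of what a stencil decorator built on the same vmap + roll idea might look like. stencil, laplacian and the kernel_size argument are hypothetical names for illustration, not a JAX or kernex API. The sketch assumes 2D inputs, stride 1 and no padding, and it builds window indices by broadcasting instead of general_arange/general_product.

```python
import functools

import jax
import jax.numpy as jnp


def stencil(kernel_size):
    """Hypothetical decorator: apply `fn` to every VALID window of a 2D array,
    with index 0 referring to the window center inside `fn`."""
    kh, kw = kernel_size
    # Shift that moves the window center to index 0 (extra element on the right for even sizes).
    shift = (-((kh - 1) // 2), -((kw - 1) // 2))

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(x):
            h, w = x.shape
            rows = jnp.arange(h - kh + 1)[:, None] + jnp.arange(kh)  # (out_h, kh)
            cols = jnp.arange(w - kw + 1)[:, None] + jnp.arange(kw)  # (out_w, kw)

            def apply_at(r, c):
                window = x[r[:, None], c[None, :]]  # (kh, kw) window via broadcasting
                return fn(jnp.roll(window, shift=shift, axis=(0, 1)))

            # Inner vmap over window columns, outer vmap over window rows.
            return jax.vmap(lambda r: jax.vmap(lambda c: apply_at(r, c))(cols))(rows)

        return wrapper

    return decorator


@stencil((3, 3))
def laplacian(v):
    # Five-point Laplacian written with relative offsets.
    return v[-1, 0] + v[1, 0] + v[0, -1] + v[0, 1] - 4 * v[0, 0]


x = jnp.arange(25.0).reshape(5, 5) ** 2
print(laplacian(x))  # every entry is 52.0 for this quadratic input
```

For x[i, j] = (5i + j)^2 the four neighbors contribute 2 * 25 + 2 * 1 = 52 beyond 4x[i, j], so the result can be checked against plain slicing of the interior.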