ASEM000 / kernex

Stencil computations in JAX
MIT License

Differentiable Stencil computations in JAX

[**Installation**](#Installation) | [**Description**](#Description) | [**Quick example**](#QuickExample) | [**More Examples**](#MoreExamples) | [**Benchmarking**](#Benchmarking)

![Tests](https://github.com/ASEM000/kernex/actions/workflows/tests.yml/badge.svg) ![pyver](https://img.shields.io/badge/python-3.8%203.8%203.9%203.11-red) ![codestyle](https://img.shields.io/badge/codestyle-black-black) [![Downloads](https://static.pepy.tech/badge/kernex)](https://pepy.tech/project/kernex) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14UEqKzIyZsDzQ9IMeanvztXxbbbatTYV?usp=sharing) [![codecov](https://codecov.io/gh/ASEM000/kernex/branch/main/graph/badge.svg?token=3KLL24Z94I)](https://codecov.io/gh/ASEM000/kernex) [![DOI](https://zenodo.org/badge/512400616.svg)](https://zenodo.org/badge/latestdoi/512400616)

🛠️ Installation

```shell
pip install kernex
```

📖 Description

Kernex extends `jax.vmap`/`jax.lax.map`/`jax.pmap` with `kmap`, and `jax.lax.scan` with `kscan`, for general stencil computations.

⏩ Quick Example

kmap:

```python
import kernex as kex
import jax.numpy as jnp

@kex.kmap(kernel_size=(3,))
def sum_all(x):
    return jnp.sum(x)

x = jnp.array([1,2,3,4,5])
print(sum_all(x))
# [ 6  9 12]
```

kscan:

```python
import kernex as kex
import jax.numpy as jnp

@kex.kscan(kernel_size=(3,))
def sum_all(x):
    return jnp.sum(x)

x = jnp.array([1,2,3,4,5])
print(sum_all(x))
# [ 6 13 22]
```
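With `kmap`, `jax.vmap` sums the contents of each window independently, so every window sees the original array. With `kscan`, `lax.scan` updates the array in place and the window sums are computed sequentially: the first window `[1, 2, 3]` sums to `6`, which overwrites index 1; the next window `[6, 3, 4]` then sums to `13`, which overwrites index 2; and the last window `[13, 4, 5]` sums to `22`.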
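To make the difference concrete, here is a minimal pure-JAX sketch of what the two quick examples compute, written with `jax.vmap` and `jax.lax.scan` directly. It is not kernex's internal implementation: the helper names `sliding_sum_vmap` and `sliding_sum_scan` are hypothetical, and the scan version assumes (as the example output implies) that each window sum is written back to the window's center cell.

```python
import jax
import jax.numpy as jnp

def sliding_sum_vmap(x, k=3):
    # Every length-k "valid" window is summed independently (the kmap case).
    starts = jnp.arange(x.shape[0] - k + 1)
    window_sum = lambda i: jnp.sum(jax.lax.dynamic_slice(x, (i,), (k,)))
    return jax.vmap(window_sum)(starts)

def sliding_sum_scan(x, k=3):
    # Windows are processed one after another (the kscan case): each sum is
    # written into the window's center cell before the next window is read.
    starts = jnp.arange(x.shape[0] - k + 1)

    def step(arr, i):
        s = jnp.sum(jax.lax.dynamic_slice(arr, (i,), (k,)))
        return arr.at[i + k // 2].set(s), s

    _, sums = jax.lax.scan(step, x, starts)
    return sums

x = jnp.array([1, 2, 3, 4, 5])
print(sliding_sum_vmap(x))  # [ 6  9 12]
print(sliding_sum_scan(x))  # [ 6 13 22]
```

The only difference between the two is whether the step reads from the original array or from the array carried (and modified) by the scan.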

🔢 More examples

1️⃣ Convolution operation

```python
import jax
import jax.numpy as jnp
import kernex as kex

@jax.jit
@kex.kmap(
    kernel_size= (3,3,3),
    padding = ('valid','same','same'))
def kernex_conv2d(x,w):
    # channel-first 2D convolution: the 3x3x3 kernel spans 3 input channels
    # ('valid' on the channel axis) and a 3x3 spatial window ('same' zero padding)
    return jnp.sum(x*w)
```
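For cross-checking `kernex_conv2d`, the same channel-first operation can be sketched with `jax.lax.conv_general_dilated`. This is a sketch under assumptions: `lax_conv2d` is a hypothetical name, the input is assumed to have shape `(3, H, W)` and the kernel `(3, 3, 3)`, and `"SAME"` zero padding is assumed to match kernex's `'same'` on the spatial axes. Like `jnp.sum(x*w)`, `conv_general_dilated` computes a cross-correlation (no kernel flip), so both should produce an output of shape `(1, H, W)`.

```python
import jax
import jax.numpy as jnp

def lax_conv2d(x, w):
    # x: (C, H, W) channel-first input; w: (C, kh, kw) kernel with one output channel.
    # conv_general_dilated defaults to NCHW inputs and OIHW kernels,
    # so add a batch axis to x and an output-channel axis to w.
    out = jax.lax.conv_general_dilated(
        x[None], w[None], window_strides=(1, 1), padding="SAME"
    )
    return out[0]  # (1, H, W)

x = jnp.ones((3, 5, 5))
w = jnp.ones((3, 3, 3))
out = lax_conv2d(x, w)
print(out.shape)  # (1, 5, 5)
```

With an all-ones input and kernel, the center value is 27 (3 channels × 9 cells) and a corner value is 12, because zero padding leaves only a 2x2 spatial window there.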
2️⃣ Laplacian operation

```python
# see also
# https://numba.pydata.org/numba-doc/latest/user/stencil.html#basic-usage
import jax
import jax.numpy as jnp
import kernex as kex

@kex.kmap(
    kernel_size=(3,3),
    padding= 'valid',
    relative=True) # `relative`= True enables relative indexing
def laplacian(x):
    return ( 0*x[1,-1]  + 1*x[1,0]  + 0*x[1,1] +
             1*x[0,-1]  +-4*x[0,0]  + 1*x[0,1] +
             0*x[-1,-1] + 1*x[-1,0] + 0*x[-1,1] )

print(laplacian(jnp.ones([10,10])))
# [[0., 0., 0., 0., 0., 0., 0., 0.],
#  [0., 0., 0., 0., 0., 0., 0., 0.],
#  [0., 0., 0., 0., 0., 0., 0., 0.],
#  [0., 0., 0., 0., 0., 0., 0., 0.],
#  [0., 0., 0., 0., 0., 0., 0., 0.],
#  [0., 0., 0., 0., 0., 0., 0., 0.],
#  [0., 0., 0., 0., 0., 0., 0., 0.],
#  [0., 0., 0., 0., 0., 0., 0., 0.]]
```
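The stencil above is the standard 5-point Laplacian (the zero-weighted corners drop out). As a sketch, the same `'valid'` computation can be written with array slicing alone; `laplacian_slices` is a hypothetical name. A useful sanity check beyond the constant input: for u = i² + j² the discrete Laplacian is 4 everywhere, since (i+1)² + (i-1)² - 2i² = 2 along each axis.

```python
import jax.numpy as jnp

def laplacian_slices(x):
    # 5-point Laplacian on interior points only ("valid"): output is (H-2, W-2).
    center = x[1:-1, 1:-1]
    return x[2:, 1:-1] + x[:-2, 1:-1] + x[1:-1, 2:] + x[1:-1, :-2] - 4 * center

print(laplacian_slices(jnp.ones([10, 10])).shape)  # (8, 8)

i, j = jnp.meshgrid(jnp.arange(10.0), jnp.arange(10.0), indexing="ij")
print(laplacian_slices(i**2 + j**2))  # every entry is 4.0
```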
3️⃣ Get Patches of an array

```python
import jax
import jax.numpy as jnp
import kernex as kex

@kex.kmap(kernel_size=(3,3),relative=True)
def identity(x):
    # similar to numba.stencil
    # this function returns the top left cell in the padded/unpadded kernel view
    # or the center cell if `relative`=True
    return x[0,0]

# unlike numba.stencil, vector output is allowed in kernex
# this function is similar (up to a reshape/transpose) to
# `jax.lax.conv_general_dilated_patches(x[None, None], (3, 3), (1, 1), padding='SAME')`
@jax.jit
@kex.kmap(kernel_size=(3,3),padding='same')
def get_3x3_patches(x):
    # returns 5x5x3x3 array
    return x

mat = jnp.arange(1,26).reshape(5,5)
print(mat)
# [[ 1  2  3  4  5]
#  [ 6  7  8  9 10]
#  [11 12 13 14 15]
#  [16 17 18 19 20]
#  [21 22 23 24 25]]

# get the view at array index = (0,0)
print(get_3x3_patches(mat)[0,0])
# [[0 0 0]
#  [0 1 2]
#  [0 6 7]]
```
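Because `padding='same'` zero-pads the borders, each patch is centered on one input cell. The sketch below rebuilds the `get_3x3_patches` output with `jnp.pad`, `jax.lax.dynamic_slice`, and two nested `jax.vmap` calls. `get_patches_same` is a hypothetical helper, and zero padding of `k // 2` cells per side (odd kernel sizes only) is an assumption chosen to reproduce the output shown above.

```python
import jax
import jax.numpy as jnp

def get_patches_same(x, kh=3, kw=3):
    # Zero-pad so every cell (including the borders) is the center of a full window.
    xp = jnp.pad(x, ((kh // 2, kh // 2), (kw // 2, kw // 2)))
    rows = jnp.arange(x.shape[0])
    cols = jnp.arange(x.shape[1])
    slice_at = lambda i, j: jax.lax.dynamic_slice(xp, (i, j), (kh, kw))
    # inner vmap over columns, outer vmap over rows -> shape (H, W, kh, kw)
    return jax.vmap(lambda i: jax.vmap(lambda j: slice_at(i, j))(cols))(rows)

mat = jnp.arange(1, 26).reshape(5, 5)
patches = get_patches_same(mat)
print(patches.shape)  # (5, 5, 3, 3)
print(patches[0, 0])
# [[0 0 0]
#  [0 1 2]
#  [0 6 7]]
```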
4️⃣ Linear convection

[Figures: problem setup; stencil view]

```python
import jax
import jax.numpy as jnp
import kernex as kex
import matplotlib.pyplot as plt

# see https://nbviewer.org/github/barbagroup/CFDPython/blob/master/lessons/01_Step_1.ipynb

tmax,xmax = 0.5,2.0
nt,nx = 151,51
dt,dx = tmax/(nt-1) , xmax/(nx-1)
u = jnp.ones([nt,nx])
c = 0.5

# kscan moves sequentially in row-major order and updates in-place using lax.scan.
F = kex.kscan(
    kernel_size = (3,3),
    padding = ((1,1),(1,1)),
    # n for time axis, i for spatial axis (optional naming)
    named_axis={0:'n',1:'i'},
    relative=True
)

# boundary condition as a function
def bc(u):
    return 1

# initial condition as a function
def ic1(u):
    return 1

def ic2(u):
    return 2

def linear_convection(u):
    return ( u['i','n-1'] - (c*dt/dx) * (u['i','n-1'] - u['i-1','n-1']) )

F[:,0] = F[:,-1] = bc # assign 1 for left and right boundary for all t

# square wave initial condition
F[:,:int((nx-1)/4)+1] = F[:,int((nx-1)/2):] = ic1
F[0:1, int((nx-1)/4)+1 : int((nx-1)/2)] = ic2

# assign linear convection function for
# interior spatial location [1:-1]
# and start from t>0 [1:]
F[1:,1:-1] = linear_convection

kx_solution = F(jnp.array(u))

plt.figure(figsize=(20,7))
for line in kx_solution[::20]:
    plt.plot(jnp.linspace(0,xmax,nx),line)
```
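As a point of comparison, the same first-order upwind scheme, u_i^{n+1} = u_i^n - (c dt/dx)(u_i^n - u_{i-1}^n), can be written as a plain `jax.lax.scan` over time steps. This sketch is meant to follow the same update as `linear_convection`, not to reproduce kernex's code path. It assumes boundaries fixed at 1 and the square wave set to 2 on the index range the `ic2` assignment uses; the names `u0`, `step`, and `solution` are illustrative.

```python
import jax
import jax.numpy as jnp

tmax, xmax = 0.5, 2.0
nt, nx = 151, 51
dt, dx = tmax / (nt - 1), xmax / (nx - 1)
c = 0.5

# Square-wave initial condition: u = 2 on the same index range as ic2, 1 elsewhere.
u0 = jnp.ones(nx).at[int((nx - 1) / 4) + 1 : int((nx - 1) / 2)].set(2.0)

def step(u, _):
    # Upwind update of the interior points from the previous time row;
    # the boundary values u[0] and u[-1] are never touched, so they stay at 1.
    interior = u[1:-1] - (c * dt / dx) * (u[1:-1] - u[:-2])
    u_next = u.at[1:-1].set(interior)
    return u_next, u_next

_, history = jax.lax.scan(step, u0, None, length=nt - 1)
solution = jnp.concatenate([u0[None], history], axis=0)  # shape (nt, nx)
print(solution.shape)  # (151, 51)
```

Here c·dt/dx ≈ 0.042, so each interior update is a convex combination of two neighbors and the solution stays between 1 and 2.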
5️⃣ Gaussian blur

```python
import jax
import jax.numpy as jnp
import kernex as kex

def gaussian_blur(image, sigma, kernel_size):
    x = jnp.linspace(-(kernel_size - 1) / 2.0, (kernel_size - 1) / 2.0, kernel_size)
    # 1D Gaussian weights exp(-x^2 / (2 sigma^2)); the 2D kernel is their outer product
    w = jnp.exp(-0.5 * jnp.square(x / sigma))
    w = jnp.outer(w, w)
    w = w / w.sum()  # normalize so the weights sum to 1

    @kex.kmap(kernel_size=(kernel_size, kernel_size), padding="same")
    def conv(x):
        return jnp.sum(x * w)

    return conv(image)
```
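Because the 2D Gaussian kernel is an outer product of two 1D kernels, the blur can also be applied as two 1D passes. The sketch below does this with `jnp.convolve` (mode `"same"`, which zero-pads at the edges) mapped over rows and then columns with `jax.vmap`. The helper names `gaussian_kernel_1d` and `separable_gaussian_blur` are hypothetical, and an odd `kernel_size` is assumed so the kernel has a center cell. This is a swapped-in separable formulation, not how the kmap version above works.

```python
import jax
import jax.numpy as jnp

def gaussian_kernel_1d(sigma, kernel_size):
    x = jnp.linspace(-(kernel_size - 1) / 2.0, (kernel_size - 1) / 2.0, kernel_size)
    w = jnp.exp(-0.5 * jnp.square(x / sigma))
    return w / w.sum()

def separable_gaussian_blur(image, sigma, kernel_size):
    # outer(w, w) factorizes, so a 1D pass along rows followed by a 1D pass
    # along columns matches one 2D pass with the same zero padding.
    w = gaussian_kernel_1d(sigma, kernel_size)
    conv1d = lambda v: jnp.convolve(v, w, mode="same")  # w is symmetric, so no flip issue
    rows_blurred = jax.vmap(conv1d)(image)
    return jax.vmap(conv1d)(rows_blurred.T).T

image = jnp.ones((7, 7))
blurred = separable_gaussian_blur(image, sigma=1.0, kernel_size=3)
print(blurred[1:-1, 1:-1])  # interior stays 1.0 (up to rounding): the weights sum to 1
```

The separable form costs about 2k multiplications per pixel instead of k², which matters for large kernels.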
6️⃣ Depthwise convolution

```python
import jax
import jax.numpy as jnp
import kernex as kex

@jax.jit
@jax.vmap  # map over the channel axis of both the input and the kernel
@kex.kmap(
    kernel_size= (3,3),
    padding = ('same','same'))
def kernex_depthwise_conv2d(x,w):
    return jnp.sum(x*w)

h,w,c = 5,5,2
k=3
x = jnp.arange(1,h*w*c+1).reshape(c,h,w)
kernel = jnp.arange(1,k*k*c+1).reshape(c,k,k)  # one 3x3 filter per channel
print(kernex_depthwise_conv2d(x,kernel))
```
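In `jax.lax` terms, a depthwise convolution is a grouped convolution with `feature_group_count` equal to the number of channels, so each channel is convolved only with its own filter. The sketch below uses the hypothetical name `depthwise_conv2d_lax` and assumes float inputs and `"SAME"` zero padding to mirror `padding=('same','same')`.

```python
import jax
import jax.numpy as jnp

def depthwise_conv2d_lax(x, kernel):
    # x: (C, H, W); kernel: (C, kh, kw). One filter per channel, no channel mixing.
    c = x.shape[0]
    out = jax.lax.conv_general_dilated(
        x[None],            # (1, C, H, W): NCHW
        kernel[:, None],    # (C, 1, kh, kw): OIHW with I = C / feature_group_count = 1
        window_strides=(1, 1),
        padding="SAME",
        feature_group_count=c,
    )
    return out[0]           # (C, H, W)

h, wd, c, k = 5, 5, 2, 3
x = jnp.arange(1, h * wd * c + 1, dtype=jnp.float32).reshape(c, h, wd)
kernel = jnp.arange(1, k * k * c + 1, dtype=jnp.float32).reshape(c, k, k)
print(depthwise_conv2d_lax(x, kernel).shape)  # (2, 5, 5)
```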
7️⃣ Average pooling 2D

```python
import jax
import jax.numpy as jnp
import kernex as kex

@jax.vmap  # vectorize over the channel dimension
@kex.kmap(kernel_size=(3,3), strides=(2,2))
def avgpool_2d(x):
    # define the kernel for the Average pool operation over the spatial dimensions
    return jnp.mean(x)
```
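The same pooling can be sketched with `jax.lax.reduce_window`, which sums each window directly. Here a window size of 1 on the channel axis plays the role of `jax.vmap`, and `"VALID"` padding is assumed to match kmap's default. `avgpool_2d_lax` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def avgpool_2d_lax(x):
    # x: (C, H, W). Sum each 3x3 spatial window with stride 2 and no padding
    # (window size 1 along the channel axis keeps channels separate), then divide by 9.
    summed = jax.lax.reduce_window(
        x, 0.0, jax.lax.add,
        window_dimensions=(1, 3, 3),
        window_strides=(1, 2, 2),
        padding="VALID",
    )
    return summed / 9.0

x = jnp.arange(2 * 5 * 5, dtype=jnp.float32).reshape(2, 5, 5)
print(avgpool_2d_lax(x).shape)  # (2, 2, 2)
```

Each spatial axis shrinks from 5 to (5 - 3) // 2 + 1 = 2.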
8️⃣ Runge-Kutta integration

```python
# let's solve dy/dt = y, where y(0) = 1 and the exact solution is y(t) = e^t,
# using the 4th-order Runge-Kutta method
# f(t, y) = y

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import kernex as kex

t = jnp.linspace(0, 1, 5)
y = jnp.zeros(5)
x = jnp.stack([y, t], axis=0)
dt = t[1] - t[0]  # 0.25
f = lambda tn, yn: yn

def ic(x):
    """ initial condition y0 = 1 """
    return 1.

def rk4(x):
    """ runge kutta 4th order integration step """
    # ┌────┬────┬────┐      ┌──────┬──────┬──────┐
    # │ y0 │*y1*│ y2 │      │[0,-1]│[0, 0]│[0, 1]│
    # ├────┼────┼────┤ ==>  ├──────┼──────┼──────┤
    # │ t0 │ t1 │ t2 │      │[1,-1]│[1, 0]│[1, 1]│
    # └────┴────┴────┘      └──────┴──────┴──────┘
    t0 = x[1, -1]
    y0 = x[0, -1]
    k1 = dt * f(t0, y0)
    k2 = dt * f(t0 + dt / 2, y0 + 1 / 2 * k1)
    k3 = dt * f(t0 + dt / 2, y0 + 1 / 2 * k2)
    k4 = dt * f(t0 + dt, y0 + k3)
    yn_1 = y0 + 1 / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return yn_1

F = kex.kscan(kernel_size=(2, 3), relative=True, padding=((0, 1)))  # kernel size = 3
F[0:1, 1:] = rk4
F[0, 0] = ic

# compile the solver
solver = jax.jit(F.__call__)
y = solver(x)[0, :]

plt.plot(t, y, '-o', label='rk4')
plt.plot(t, jnp.exp(t), '-o', label='analytical')
plt.legend()
```

![img](assets/rk4.svg)
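For reference, the same RK4 integration can be sketched without kernex as a `jax.lax.scan` that carries y forward one step per time point. The helper names `rk4_step` and `scan_body` are hypothetical; the step size and grid match the example above (5 points on [0, 1], dt = 0.25).

```python
import jax
import jax.numpy as jnp

def f(t, y):
    # right-hand side of dy/dt = y
    return y

def rk4_step(y, t, dt):
    # one classical 4th-order Runge-Kutta step from (t, y) to t + dt
    k1 = dt * f(t, y)
    k2 = dt * f(t + dt / 2, y + k1 / 2)
    k3 = dt * f(t + dt / 2, y + k2 / 2)
    k4 = dt * f(t + dt, y + k3)
    return y + (k1 + 2 * k2 + 2 * k3 + k4) / 6

t = jnp.linspace(0, 1, 5)
dt = t[1] - t[0]  # 0.25

def scan_body(y, tn):
    y_next = rk4_step(y, tn, dt)
    return y_next, y_next

y0 = jnp.array(1.0)  # initial condition y(0) = 1
_, ys = jax.lax.scan(scan_body, y0, t[:-1])
y = jnp.concatenate([y0[None], ys])  # y at every point of t

print(y[-1], jnp.exp(1.0))  # RK4 estimate of y(1) vs. e
```

For dy/dt = y, each RK4 step multiplies y by 1 + h + h²/2 + h³/6 + h⁴/24, so even with only four steps of h = 0.25 the estimate of y(1) agrees with e to about four decimal places.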