ASEM000 / kernex

Stencil computations in JAX
MIT License

Implementation strategy #2

Open shoyer opened 1 year ago

shoyer commented 1 year ago

This project looks really cool!

I would love to understand at a high level how this package works -- how do you actually implement stencil computations in JAX? Do you reuse jax.lax.scan or something else? Does it support auto-diff? How does performance compare on CPU/GPU/TPU (or whichever configs you've tried)?

ASEM000 commented 1 year ago

Hello Stephan,

> I would love to understand at a high level how this package works

For kmap: https://github.com/ASEM000/kernex/blob/f5dd7f2bf7de70082da3ea1b402af14f1fd362ee/kernex/_src/map.py#L27-L35

For kscan: https://github.com/ASEM000/kernex/blob/f5dd7f2bf7de70082da3ea1b402af14f1fd362ee/kernex/_src/scan.py#L23-L34


For kmap, I use jax.vmap to vectorize the new function, which accepts view indices, over the array of all possible view indices: https://github.com/ASEM000/kernex/blob/f5dd7f2bf7de70082da3ea1b402af14f1fd362ee/kernex/_src/map.py#L37-L51
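A minimal sketch of the vmap-over-view-indices idea, restricted to a 1D array, stride 1, and no padding. The helper name sliding_map is hypothetical and this is not kernex's actual implementation; it only illustrates how a function that accepts a window start index can be vectorized with jax.vmap, using jax.lax.dynamic_slice to cut out a fixed-size view at a traced index.

```python
import jax
import jax.numpy as jnp


def sliding_map(kernel_fn, array, kernel_size):
    """Apply `kernel_fn` to every length-`kernel_size` window of a 1D array.

    Hypothetical helper illustrating the vmap-over-indices idea; not kernex's
    code. Handles only 1D input, stride 1, and 'valid' (unpadded) windows.
    """
    num_windows = array.shape[0] - kernel_size + 1
    # One start index per window: [0, 1, ..., num_windows - 1].
    start_indices = jnp.arange(num_windows)

    def apply_at(start):
        # Extract a static-size window beginning at the traced index `start`.
        window = jax.lax.dynamic_slice(array, (start,), (kernel_size,))
        return kernel_fn(window)

    # Vectorize the index-accepting function over all start indices.
    return jax.vmap(apply_at)(start_indices)


x = jnp.arange(1.0, 7.0)  # [1, 2, 3, 4, 5, 6]
# 3-point moving sum: each output is x[i] + x[i+1] + x[i+2].
print(sliding_map(jnp.sum, x, kernel_size=3))  # values 6, 9, 12, 15
```

Because every window is independent here, jax.vmap is free to evaluate them all in parallel.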

For kscan (my prime motivation), I use jax.lax.scan to scan over the indices array: https://github.com/ASEM000/kernex/blob/f5dd7f2bf7de70082da3ea1b402af14f1fd362ee/kernex/_src/scan.py#L36-L47
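The scan variant can be sketched the same way, with jax.lax.scan carrying the array itself so that each step sees the values written by earlier steps. The helper sliding_scan and its write_offset parameter are hypothetical, chosen for illustration; they are not kernex's API, and where kscan actually writes each result may differ.

```python
import jax
import jax.numpy as jnp


def sliding_scan(kernel_fn, array, kernel_size, write_offset):
    """Sequentially apply `kernel_fn` to each window, writing results back.

    Hypothetical helper illustrating the lax.scan-over-indices idea; not
    kernex's code. Window i covers array[i : i + kernel_size], and its result
    overwrites array[i + write_offset], so later windows see earlier updates.
    """
    num_windows = array.shape[0] - kernel_size + 1

    def step(carry, start):
        window = jax.lax.dynamic_slice(carry, (start,), (kernel_size,))
        value = kernel_fn(window)
        # Functional "in-place" update of the carried array.
        carry = carry.at[start + write_offset].set(value)
        return carry, None

    final, _ = jax.lax.scan(step, array, jnp.arange(num_windows))
    return final


x = jnp.array([1.0, 2.0, 3.0, 4.0])
# Width-2 sum written into the right element: x[i+1] <- x[i] + x[i+1].
# Each step sees the previous update, so this produces a cumulative sum.
print(sliding_scan(jnp.sum, x, kernel_size=2, write_offset=1))  # values 1, 3, 6, 10
```

With vmap, the same width-2 sum would instead give the independent pairwise sums 3, 5, 7; the sequential carry is what turns the stencil into a recurrence, as needed for explicit or in-place update schemes.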

> Does it support auto-diff?

Yes, definitely. The library relies on jax.numpy, jax.vmap, jax.lax.scan, and jax.lax.switch for its internals, all of which are differentiable.
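As a hedged check of that claim, the sketch below differentiates a hypothetical pairwise-product stencil, sum over i of x[i] * x[i+1], written once with jax.vmap and once with jax.lax.scan. Both gradients should equal the analytic result x[j-1] + x[j+1], with missing neighbours treated as zero. The stencil and names are illustrative only.

```python
import jax
import jax.numpy as jnp

KERNEL_SIZE = 2


def window_products_vmap(x):
    # Product of each adjacent pair, computed with vmap over start indices.
    starts = jnp.arange(x.shape[0] - KERNEL_SIZE + 1)

    def at(i):
        window = jax.lax.dynamic_slice(x, (i,), (KERNEL_SIZE,))
        return jnp.prod(window)

    return jax.vmap(at)(starts)


def window_products_scan(x):
    # Same computation, stepping through start indices with lax.scan.
    def step(total, i):
        window = jax.lax.dynamic_slice(x, (i,), (KERNEL_SIZE,))
        return total + jnp.prod(window), None

    starts = jnp.arange(x.shape[0] - KERNEL_SIZE + 1)
    total, _ = jax.lax.scan(step, jnp.zeros((), x.dtype), starts)
    return total


x = jnp.array([1.0, 2.0, 3.0, 4.0])
grad_vmap = jax.grad(lambda x: jnp.sum(window_products_vmap(x)))(x)
grad_scan = jax.grad(window_products_scan)(x)
# Analytic gradient: x[j-1] + x[j+1]  ->  2, 4, 6, 3
print(grad_vmap, grad_scan)
```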

> How does performance compare on CPU/GPU/TPU (or whichever configs you've tried)?

I benchmarked kmap-based implementations against jax.lax.conv_general_dilated_patches and jax.lax.conv_general_dilated; the code is under tests_and_benchmarks. In general, kmap seems faster in many scenarios, especially on CPU*. However, it needs more rigorous benchmarking, especially on TPU.
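The benchmark code itself lives in tests_and_benchmarks and is not reproduced here. As an illustration of what is being compared (with no timings), this sketch computes the same 3-point second-difference stencil three ways: jax.vmap over window start indices (kmap-style), jax.lax.conv_general_dilated, and jax.lax.conv_general_dilated_patches followed by a contraction with the weights. The data and the stencil are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(8.0) ** 2               # sample 1D field: 0, 1, 4, ..., 49
stencil = jnp.array([1.0, -2.0, 1.0])  # 3-point second-difference weights
k = stencil.shape[0]
n_out = x.shape[0] - k + 1             # 'VALID' output length


# (a) vmap over window start indices (the kmap-style formulation).
def at(i):
    return jnp.dot(jax.lax.dynamic_slice(x, (i,), (k,)), stencil)


vmap_out = jax.vmap(at)(jnp.arange(n_out))

# (b) XLA convolution. Default layouts: lhs (batch, channel, width),
# rhs (out_channel, in_channel, width). XLA convolution does not flip
# the kernel (it is a cross-correlation).
conv_out = jax.lax.conv_general_dilated(
    x[None, None, :], stencil[None, None, :],
    window_strides=(1,), padding="VALID",
)[0, 0]

# (c) Extract patches, then contract them with the stencil weights.
patches = jax.lax.conv_general_dilated_patches(
    x[None, None, :], filter_shape=(k,), window_strides=(1,), padding="VALID",
)[0]  # shape (k, n_out): patches[j, i] == x[i + j]
patch_out = stencil @ patches

print(vmap_out, conv_out, patch_out)  # each is 2.0 at all 6 points
```

Both convolution routes need extra batch and channel axes (here each of size 1), which reflects the layout XLA's convolution operation is designed around.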

My main motivation is to solve PDEs using a stencil definition, which may require applying different functions at different locations of the array (e.g., at the boundary). This is why kernex lets you combine kmap and kscan with jax.lax.switch to apply different functions to different portions of the array. The following example introduces the function mesh concept, where different stencils are assigned to different regions of the array by indexing. The backbone of this feature is jax.lax.switch.

Function mesh:

```python
import jax
import jax.numpy as jnp
import kernex as kex

F = kex.kmap(kernel_size=(1,))
F[0] = lambda x: x[0] ** 2   # square the first element
F[1:] = lambda x: x[0] ** 3  # cube every other element

array = jnp.arange(1, 11).astype('float32')
print(F(array))
# [1., 8., 27., 64., 125., ... 216., 343., 512., 729., 1000.]
print(jax.grad(lambda x: jnp.sum(F(x)))(array))
# [2., 12., 27., 48., 75., ... 108., 147., 192., 243., 300.]
```

Array equivalent:

```python
import jax
import jax.numpy as jnp

def F(x):
    f1 = lambda x: x ** 2
    f2 = lambda x: x ** 3
    x = x.at[0].set(f1(x[0]))    # square the first element
    x = x.at[1:].set(f2(x[1:]))  # cube every other element
    return x

array = jnp.arange(1, 11).astype('float32')
print(F(array))
# [1., 8., 27., 64., 125., ... 216., 343., 512., 729., 1000.]
print(jax.grad(lambda x: jnp.sum(F(x)))(array))
# [2., 12., 27., 48., 75., ... 108., 147., 192., 243., 300.]
```
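To make the role of jax.lax.switch in the function mesh more concrete, here is a hedged sketch of selecting a different kernel per array location. It is not how kernex implements F[0] = ... / F[1:] = ...; the names interior, boundary, and mesh_apply, the zero padding, and the boundary rule are all assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def interior(w):
    # Hypothetical interior kernel on a length-3 window: second difference.
    return w[0] - 2.0 * w[1] + w[2]


def boundary(w):
    # Hypothetical boundary kernel: a fixed value of zero.
    return jnp.zeros((), w.dtype)


def mesh_apply(x):
    padded = jnp.pad(x, 1)  # zero-pad so every point has a centered window
    n = x.shape[0]
    idx = jnp.arange(n)
    # Branch 0 (boundary) at both end points, branch 1 (interior) elsewhere.
    branch = jnp.where((idx == 0) | (idx == n - 1), 0, 1)

    def at(i, b):
        # Window of padded centered on original index i.
        w = jax.lax.dynamic_slice(padded, (i,), (3,))
        return jax.lax.switch(b, (boundary, interior), w)

    return jax.vmap(at)(idx, branch)


print(mesh_apply(jnp.arange(6.0) ** 2))  # values 0, 2, 2, 2, 2, 0
```

When the branch index is batched under jax.vmap, JAX converts lax.switch into a select, so every branch is evaluated at every point and the matching result is kept; the extra cost grows with the number of branches.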
shoyer commented 1 year ago

OK great, thank you for sharing!

I agree that this is a very promising approach for implementing PDE kernels, and in general this is similar to the way I've implemented PDE solvers in JAX by hand (e.g., the wave equation solver).
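For contrast with the kernex style, a common hand-written pattern in plain JAX expresses the stencil with shifted slices and steps in time with jax.lax.scan. The sketch below is an illustrative leapfrog update for the 1D wave equation u_tt = c^2 u_xx with fixed zero ends. It is not the solver referenced above; the grid spacing, time step (c*dt/dx = 0.5), and initial condition are arbitrary assumptions.

```python
import jax
import jax.numpy as jnp


def wave_step(state, _, c=1.0, dx=0.01, dt=0.005):
    u_prev, u = state
    # Second spatial difference on interior points via shifted slices.
    lap = (u[:-2] - 2.0 * u[1:-1] + u[2:]) / dx**2
    # Leapfrog: u_next = 2u - u_prev + (c*dt)^2 * u_xx on the interior.
    u_interior = 2.0 * u[1:-1] - u_prev[1:-1] + (c * dt) ** 2 * lap
    # Fixed (zero) ends.
    u_next = jnp.zeros_like(u).at[1:-1].set(u_interior)
    return (u, u_next), None


def simulate(u0, num_steps):
    # Start from rest (approximately): use u0 as the previous time level too.
    (_, u_final), _ = jax.lax.scan(wave_step, (u0, u0), None, length=num_steps)
    return u_final


x = jnp.linspace(0.0, 1.0, 101)  # dx = 0.01, matching wave_step's default
u0 = jnp.sin(jnp.pi * x)
u = simulate(u0, 100)            # advance to t = 100 * 0.005 = 0.5
print(u.shape)
```

Slicing keeps the whole interior update as a handful of vectorized array operations, which XLA can fuse; the per-point index formulation trades that for the ability to attach a different function to each location.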

conv_general_dilated and conv_general_dilated_patches use XLA's Convolution operation, which is really optimized for convolutional neural networks with large numbers of channels. I wouldn't expect them to work well for PDE kernels, except perhaps on TPUs.