shoyer opened this issue 1 year ago
Hello Stephan,
> I would love to understand at a high level how this package works
In short, the answer is `jax.vmap` (for `kmap`) and `jax.lax.scan` (for `kscan`). The functions linked below transform the user function into a function that accepts view (window) indices:
https://github.com/ASEM000/kernex/blob/f5dd7f2bf7de70082da3ea1b402af14f1fd362ee/kernex/_src/utils.py#L133-L163
https://github.com/ASEM000/kernex/blob/f5dd7f2bf7de70082da3ea1b402af14f1fd362ee/kernex/_src/base.py#L63-L77

The transformed function retrieves the corresponding array portion with `jnp.ix_` and then applies the user function to it. In the case of `relative=True`, i.e. when the indexing is relative (the center is 0, like `numba.stencil`), I roll the array portion before applying the function.

For `kmap`, I use `jax.vmap` to vectorize the new view-indices-accepting function over the array of all possible view indices:
https://github.com/ASEM000/kernex/blob/f5dd7f2bf7de70082da3ea1b402af14f1fd362ee/kernex/_src/map.py#L37-L51
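To make the mechanism concrete, here is a minimal sketch of the `kmap` idea in plain JAX. This is not kernex's actual code: the function names, the 2-D / valid-windows restriction, and the scalar-returning kernel are my own simplifications.

```python
import jax
import jax.numpy as jnp

def kmap2d(func, array, kernel_size=(3, 3)):
    """Apply `func` to every (kh, kw) window of a 2-D array (valid windows only)."""
    kh, kw = kernel_size
    # start indices of every valid window along each axis
    rows = jnp.arange(array.shape[0] - kh + 1)
    cols = jnp.arange(array.shape[1] - kw + 1)

    def apply_at(i, j):
        # fetch the window with jnp.ix_ open-mesh indexing,
        # then apply the user function to that view
        view = array[jnp.ix_(i + jnp.arange(kh), j + jnp.arange(kw))]
        return func(view)

    # vectorize the view-index-accepting function over all window positions
    return jax.vmap(jax.vmap(apply_at, in_axes=(None, 0)), in_axes=(0, None))(rows, cols)

x = jnp.arange(25.0).reshape(5, 5)
print(kmap2d(lambda v: v.mean(), x))  # 3x3 mean filter, output shape (3, 3)
```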
For `kscan` (my prime motivation), I use `jax.lax.scan` to scan over the indices array:
https://github.com/ASEM000/kernex/blob/f5dd7f2bf7de70082da3ea1b402af14f1fd362ee/kernex/_src/scan.py#L36-L47
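For flavor, a rough 1-D sketch of the scan version, again not kernex's code: each step reads its window from the carried array and writes the result back before the next window is visited, which is what distinguishes a scan from a map here. Writing the result at the window's first index is my own assumption; where kernex places it depends on its padding and relative-indexing settings.

```python
import jax
import jax.numpy as jnp

def kscan_1d(func, array, kernel_size=3):
    # start index of every valid window, visited left to right
    starts = jnp.arange(array.shape[0] - kernel_size + 1)
    offsets = jnp.arange(kernel_size)

    def step(carry, start):
        # read the window from the *carried* array, apply the user function,
        # then write the result back so later windows see the updated value
        new_value = func(carry[start + offsets])
        carry = carry.at[start].set(new_value)
        return carry, new_value

    final_array, outputs = jax.lax.scan(step, array, starts)
    return final_array, outputs

x = jnp.arange(1.0, 8.0)
print(kscan_1d(lambda v: v.sum(), x))  # sequential, order-dependent updates
```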
> Does it support auto-diff?
Yes, definitely. The library relies on `jax.numpy`, `jax.vmap`, `jax.lax.scan`, and `jax.lax.switch` for its internals, all of which support automatic differentiation.
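Since each of those primitives has differentiation rules, `jax.grad` composes through the whole pipeline. A tiny self-contained illustration (a hypothetical hand-rolled 3-point stencil, not the kernex API):

```python
import jax
import jax.numpy as jnp

def stencil_sum(x):
    # gather every length-3 window via vmap, square it, and reduce
    starts = jnp.arange(x.shape[0] - 2)
    windows = jax.vmap(lambda i: x[i + jnp.arange(3)])(starts)
    return jnp.sum(windows ** 2)

x = jnp.arange(1.0, 6.0)
# the gradient of each element is 2*x scaled by how many windows contain it
print(jax.grad(stencil_sum)(x))
```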
> How does performance compare on CPU/GPU/TPU (or whichever configs you've tried)?
I benchmarked `jax.lax.conv_general_dilated_patches` and `jax.lax.conv_general_dilated` against equivalents based on `kmap`. The code is under `tests_and_benchmarks`. In general, `kmap` seems faster in many scenarios, especially on CPU; however, it needs more rigorous benchmarking, especially on TPU.
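As a rough illustration of the kind of comparison involved (my own sketch, not the code in `tests_and_benchmarks`): time a jitted 3x3 mean filter written in the vmap-over-window-indices style against the same filter expressed with `jax.lax.conv_general_dilated`.

```python
import time
import jax
import jax.numpy as jnp

x = jnp.ones((512, 512))

def mean3x3_vmap(a):
    # vmap-over-window-indices style (kmap-like)
    rows = jnp.arange(a.shape[0] - 2)
    cols = jnp.arange(a.shape[1] - 2)
    def at(i, j):
        return a[jnp.ix_(i + jnp.arange(3), j + jnp.arange(3))].mean()
    return jax.vmap(jax.vmap(at, in_axes=(None, 0)), in_axes=(0, None))(rows, cols)

def mean3x3_conv(a):
    lhs = a[None, None]                      # NCHW layout
    rhs = jnp.full((1, 1, 3, 3), 1.0 / 9.0)  # OIHW averaging kernel
    return jax.lax.conv_general_dilated(lhs, rhs, (1, 1), 'VALID')[0, 0]

for name, fn in (('vmap', mean3x3_vmap), ('conv', mean3x3_conv)):
    fn = jax.jit(fn)
    fn(x).block_until_ready()                # compile before timing
    t0 = time.perf_counter()
    for _ in range(10):
        fn(x).block_until_ready()
    print(name, (time.perf_counter() - t0) / 10, 'seconds/call')
```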
In general, my prime motivation is to solve PDEs using a stencil definition, which might require applying different functions at different locations of the array (e.g., at the boundary). This is the reason `kernex` offers the ability to use `kmap` and `kscan` along with `jax.lax.switch` to apply different functions to different portions of the array. The following example introduces the function mesh concept, where different stencils can be applied using indexing. The backbone of this feature is `jax.lax.switch`; a bare-bones sketch of that mechanism follows the example.
**Function mesh**

```python
import jax
import jax.numpy as jnp
import kernex as kex

F = kex.kmap(kernel_size=(1,))
F[0] = lambda x: x[0] ** 2
F[1:] = lambda x: x[0] ** 3

array = jnp.arange(1, 11).astype('float32')
print(F(array))
# [1., 8., 27., 64., 125., 216., 343., 512., 729., 1000.]

print(jax.grad(lambda x: jnp.sum(F(x)))(array))
# [2., 12., 27., 48., 75., 108., 147., 192., 243., 300.]
```

**Array equivalent**

```python
import jax
import jax.numpy as jnp

def F(x):
    f1 = lambda x: x ** 2
    f2 = lambda x: x ** 3
    x = x.at[0].set(f1(x[0]))
    x = x.at[1:].set(f2(x[1:]))
    return x

array = jnp.arange(1, 11).astype('float32')
print(F(array))
# [1., 8., 27., 64., 125., 216., 343., 512., 729., 1000.]

print(jax.grad(lambda x: jnp.sum(F(x)))(array))
# [2., 12., 27., 48., 75., 108., 147., 192., 243., 300.]
```
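And here is the bare `jax.lax.switch` mechanism the function mesh builds on, stripped of the kernex front end; the names and the position-to-function mapping below are my own, not kernex internals.

```python
import jax
import jax.numpy as jnp

# one branch per stencil; both take a view of shape (1,)
branches = (lambda v: v[0] ** 2,   # applied at position 0
            lambda v: v[0] ** 3)   # applied at positions 1:

def apply_mesh(x):
    def at(i):
        view = x[i + jnp.arange(1)]           # kernel_size=(1,) view
        func_id = jnp.where(i == 0, 0, 1)     # map position to function id
        return jax.lax.switch(func_id, branches, view)
    return jax.vmap(at)(jnp.arange(x.shape[0]))

array = jnp.arange(1, 11).astype('float32')
print(apply_mesh(array))                                   # matches F(array) above
print(jax.grad(lambda x: jnp.sum(apply_mesh(x)))(array))   # matches the gradient above
```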
OK great, thank you for sharing!
I agree that this is a very promising approach for implementing PDE kernels, and in general this is similar to the way I've implemented PDE solvers in JAX by hand (e.g., the wave equation solver).
`conv_general_dilated` and `conv_general_dilated_patches` use XLA's `Convolution` operation, which is really optimized for convolutional neural networks with large numbers of channels. I wouldn't expect them to work well for PDE kernels, except perhaps on TPUs.
This project looks really cool!
I would love to understand at a high level how this package works -- how do you actually implement stencil computations in JAX? Do you reuse `jax.lax.scan` or something else? Does it support auto-diff? How does performance compare on CPU/GPU/TPU (or whichever configs you've tried)?