ASEM000 / kernex

Stencil computations in JAX
MIT License

Support Pytrees #11

Open clemisch opened 1 year ago

clemisch commented 1 year ago

Does kernex support pytrees? I could not find an example. Supporting them would be very useful for moving-window filters with "global" weights, or simply with multiple inputs, such as the cross-channel bilateral filter in my case.

Repro:

import jax.numpy as jnp
import kernex

@kernex.kmap(kernel_size=(3, 3))
def kernel(tree):
    x, y = tree
    return jnp.sum(x * jnp.square(y))

data = jnp.arange(20 * 30).reshape((20, 30))
out = kernel((data, data))

raises

Traceback (most recent call last):
  File "/home/clemisch/kernex_tree.py", line 52, in <module>
    out = kernel((data, data))
          ^^^^^^^^^^^^^^^^^^^^
  File "/home/clemisch/venvs/11/lib64/python3.11/site-packages/kernex/interface/kernel_interface.py", line 131, in __call__
    self.shape = array.shape
                 ^^^^^^^^^^^
AttributeError: 'tuple' object has no attribute 'shape'
ASEM000 commented 1 year ago

Hello, thanks for your question. This is a reasonable request; I will look into it when I have time.

ASEM000 commented 1 year ago

Hello, in the meantime, could you try this?

The key point is to stack the arrays along some axis i, set the kernel size along axis i equal to the length of that axis, and use valid padding on that axis. In this example, i is the first axis.

I also recommend using jax.debug.print to confirm that the array views are what you expect.


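import jax
import jax.numpy as jnp
import kernex

# The inputs are stacked along a new leading axis 0, giving shape (2, 5, 5).
# A kernel size of 2 with "valid" padding on that axis yields a single window
# position spanning both inputs, so each window has shape (2, 3, 3).
@kernex.kmap(kernel_size=(2, 3, 3), padding=("valid", "valid", "valid"))
def kernel(tree):
    # Unpacking splits the (2, 3, 3) window into the two (3, 3) input views.
    x, y = tree
    jax.debug.print("x={x} \n\n y={y}\n", x=x, y=y)
    return jnp.sum(x * jnp.square(y))

data = jnp.arange(25).reshape(5, 5)
out = kernel(jnp.stack([data, data], axis=0))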

# x=[[ 0  1  2]
#  [ 5  6  7]
#  [10 11 12]] 

#  y=[[ 0  1  2]
#  [ 5  6  7]
#  [10 11 12]]

# x=[[ 1  2  3]
#  [ 6  7  8]
#  [11 12 13]] 

#  y=[[ 1  2  3]
#  [ 6  7  8]
#  [11 12 13]]

# x=[[ 2  3  4]
#  [ 7  8  9]
#  [12 13 14]] 

#  y=[[ 2  3  4]
#  [ 7  8  9]
#  [12 13 14]]

# x=[[ 5  6  7]
#  [10 11 12]
#  [15 16 17]] 

#  y=[[ 5  6  7]
#  [10 11 12]
#  [15 16 17]]

# x=[[ 6  7  8]
#  [11 12 13]
#  [16 17 18]] 

#  y=[[ 6  7  8]
#  [11 12 13]
#  [16 17 18]]

# x=[[ 7  8  9]
#  [12 13 14]
#  [17 18 19]] 

#  y=[[ 7  8  9]
#  [12 13 14]
#  [17 18 19]]

# x=[[10 11 12]
#  [15 16 17]
#  [20 21 22]] 

#  y=[[10 11 12]
#  [15 16 17]
#  [20 21 22]]

# x=[[11 12 13]
#  [16 17 18]
#  [21 22 23]] 

#  y=[[11 12 13]
#  [16 17 18]
#  [21 22 23]]

# x=[[12 13 14]
#  [17 18 19]
#  [22 23 24]] 

#  y=[[12 13 14]
#  [17 18 19]
#  [22 23 24]]
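To check what the stacked kmap computes without relying on kernex internals, here is a plain-JAX reference sketch. The helper name windowed_reference is hypothetical and not part of kernex. It evaluates sum(x * y**2) over every 3x3 window with no padding. Assuming kmap with valid padding on all three axes returns shape (1, 3, 3) for the (2, 5, 5) stacked input, out[0] should match ref.

```python
import jax.numpy as jnp

def windowed_reference(x, y, k=3):
    """Hypothetical helper: sum(x * y**2) over every k-by-k window, 'valid' padding."""
    rows = x.shape[0] - k + 1
    cols = x.shape[1] - k + 1
    out = jnp.zeros((rows, cols), dtype=x.dtype)
    for i in range(rows):
        for j in range(cols):
            # Same (3, 3) views that jax.debug.print shows for x and y above.
            wx = x[i:i + k, j:j + k]
            wy = y[i:i + k, j:j + k]
            out = out.at[i, j].set(jnp.sum(wx * jnp.square(wy)))
    return out

data = jnp.arange(25).reshape(5, 5)
ref = windowed_reference(data, data)
print(ref.shape)  # (3, 3)
```

For the first window ([[0, 1, 2], [5, 6, 7], [10, 11, 12]]) this is the sum of the cubes of those values, 4752.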
clemisch commented 1 year ago

Thanks, that works for me!

ASEM000 commented 1 year ago

As a follow-up, I think it would be simpler to let the user specify which arguments (argnums) the kernel windows over. For the previous example, the API might look something like kmap(.., argnums=(0,1))(lambda x,y: ... )

What do you think?

clemisch commented 1 year ago

Thanks for the follow-up and for including me in this.

To clarify, do you mean supporting multiple arguments instead of trees? So something like

@kernex.kmap(kernel_size=(3, 3), argnums=(0, 1))
def kernel(x, y):
    return jnp.sum(x * jnp.square(y))

or, for non-mapped local weights,

@kernex.kmap(kernel_size=(3, 3), argnums=(0,))
def kernel(x, y_local):
    return jnp.sum(x * jnp.square(y_local))

where y_local would not be windowed like x, but passed to every call as a constant (3, 3) array.
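A minimal sketch of these argnums semantics, written in plain JAX rather than kernex: window_map and extract_windows are hypothetical names, not kernex API, and only 2D inputs with valid (no) padding are handled. Arguments listed in argnums are cut into windows; the others, like y_local, are passed unchanged to every call.

```python
import jax
import jax.numpy as jnp

def extract_windows(a, kernel_size):
    """All 'valid' windows of a 2D array, shape (out_rows, out_cols, kh, kw)."""
    kh, kw = kernel_size
    out_rows = a.shape[0] - kh + 1
    out_cols = a.shape[1] - kw + 1
    # Broadcast index grids: element [i, j, p, q] reads a[i + p, j + q].
    r = jnp.arange(out_rows)[:, None, None, None] + jnp.arange(kh)[None, None, :, None]
    c = jnp.arange(out_cols)[None, :, None, None] + jnp.arange(kw)[None, None, None, :]
    return a[r, c]

def window_map(fn, kernel_size, argnums):
    """Hypothetical argnums-style window map (illustration only)."""
    def wrapped(*args):
        # Window only the arguments listed in argnums; pass the rest through.
        prepared = [extract_windows(a, kernel_size) if i in argnums else a
                    for i, a in enumerate(args)]
        in_axes = tuple(0 if i in argnums else None for i in range(len(args)))
        # Outer vmap over window rows, inner vmap over window columns.
        f = jax.vmap(jax.vmap(fn, in_axes=in_axes), in_axes=in_axes)
        return f(*prepared)
    return wrapped

def kernel(x, y):
    return jnp.sum(x * jnp.square(y))

def kernel_local(x, y_local):
    return jnp.sum(x * jnp.square(y_local))

data = jnp.arange(25).reshape(5, 5)
out = window_map(kernel, (3, 3), argnums=(0, 1))(data, data)  # shape (3, 3)

weights = jnp.ones((3, 3), dtype=data.dtype)  # constant y_local, not windowed
out_local = window_map(kernel_local, (3, 3), argnums=(0,))(data, weights)
```

With all-ones weights, out_local reduces to a plain 3x3 moving-window sum of data.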

TL;DR: Either approach is fine with me. I think supporting trees would be slightly more powerful, but any reasonable task should be expressible with multiple arguments instead of a tree.
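For comparison, tree support could be sketched the same way: window every leaf of the tree argument with jax.tree_util.tree_map, then vmap over window positions. tree_window_map and extract_windows are hypothetical helpers, not kernex API, again limited to 2D leaves and valid padding. With it, the kernel(tree) from the original repro runs unchanged.

```python
import jax
import jax.numpy as jnp

def extract_windows(a, kernel_size):
    """All 'valid' windows of a 2D array, shape (out_rows, out_cols, kh, kw)."""
    kh, kw = kernel_size
    out_rows = a.shape[0] - kh + 1
    out_cols = a.shape[1] - kw + 1
    r = jnp.arange(out_rows)[:, None, None, None] + jnp.arange(kh)[None, None, :, None]
    c = jnp.arange(out_cols)[None, :, None, None] + jnp.arange(kw)[None, None, None, :]
    return a[r, c]

def tree_window_map(fn, kernel_size):
    """Hypothetical: window each pytree leaf, then map fn over window positions."""
    def wrapped(tree):
        windows = jax.tree_util.tree_map(lambda a: extract_windows(a, kernel_size), tree)
        # vmap's default in_axes=0 maps the leading axis of every leaf.
        return jax.vmap(jax.vmap(fn))(windows)
    return wrapped

def kernel(tree):
    x, y = tree  # each leaf arrives as a (3, 3) window
    return jnp.sum(x * jnp.square(y))

data = jnp.arange(20 * 30).reshape((20, 30))
out = tree_window_map(kernel, (3, 3))((data, data))  # shape (18, 28)
```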