ASEM000 / kernex

Stencil computations in JAX
MIT License

Document border handling with padding #12

Open clemisch opened 1 year ago

clemisch commented 1 year ago

How does Kernex handle borders and out-of-bounds indices, especially with padding="same"? I skimmed the source code but did not find it. Very superficially it seems to me that it performs 0-padding. Is this the case?

If yes, it would be useful to support different schemes like "mirror", "edge", or ideally dropping the out-of-bounds indices.

Repro:

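import jax.numpy as jnp
import matplotlib.pyplot as plt
import kernex

# Mean over an 11x11 neighborhood; padding="same" keeps the output shape.
@kernex.kmap(kernel_size=(11, 11), padding="same")
def kernel(patch):
    return jnp.mean(patch)

data = jnp.ones((100, 100))
out = kernel(data)

# With ideal border handling, the mean of all-ones data is 1 everywhere.
plt.figure()
plt.plot(out[50], "+-", label="Kernex")
plt.plot(jnp.ones_like(out[50]), "+-", label="Ideal")
plt.legend()
plt.show()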

[Image: plot of out[50], comparing the "Kernex" output with the "Ideal" constant line]
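For readers without Kernex at hand, the border effect in question can be reproduced with plain JAX. This is a minimal sketch, not Kernex's implementation: `box_mean_same` is a hypothetical helper that pads with `jnp.pad` and then takes a sliding-window mean, and `"reflect"` is only one possible reading of the "mirror" scheme mentioned above.

```python
import jax
import jax.numpy as jnp


def box_mean_same(data, k, mode):
    """Mean over k x k windows with same-size output (hypothetical helper).

    Borders are handled purely by jnp.pad with the given mode.
    """
    r = k // 2
    padded = jnp.pad(data, r, mode=mode)
    # Sliding-window sum over the padded array, no further padding.
    summed = jax.lax.reduce_window(
        padded, 0.0, jax.lax.add, (k, k), (1, 1), "VALID"
    )
    return summed / (k * k)


data = jnp.ones((100, 100))
zero = box_mean_same(data, 11, "constant")  # zero padding
edge = box_mean_same(data, 11, "edge")      # repeat the border value
mirror = box_mean_same(data, 11, "reflect") # reflect about the border

# At the left end of row 50, only 6 of the 11 window columns are in bounds,
# so zero padding yields 66 / 121 = 6 / 11 instead of 1.
print(zero[50, 0], edge[50, 0], mirror[50, 0])
```

With all-ones input, the "edge" and "reflect" variants stay at 1 across the whole row, while the zero-padded variant drops off within five pixels of each border.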

ASEM000 commented 1 year ago

> it seems to me that it performs 0-padding. Is this the case?

Yes.

> If yes, it would be useful to support different schemes like "mirror", "edge", or ideally dropping the out-of-bounds indices.

I think it's simple to do this. I am thinking about something like this:


@kex.kmap(pad_kwargs={"value":0,  ... })
... 

WDYT?

clemisch commented 1 year ago

Thanks, that looks good! Do I understand correctly that Kernex always uses array padding for border handling?

Do you have experience with the performance of padding vs. adapting the computed indices at the borders, assuming the latter is even possible with JAX and/or beneficial on accelerators?

FYI, it's possible to control how JAX handles out-of-bounds indices with Array.at[ind].get(mode=...). The default clips into the valid range, which should be analogous to mode="edge" in np.pad.
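A small sketch of what these modes do on a 1D array, for illustration only. It uses indices past the upper end only, because negative indices follow NumPy-style wrap-around before any mode is applied, so the lower border would need separate handling.

```python
import jax.numpy as jnp

x = jnp.arange(5.0)             # [0., 1., 2., 3., 4.]
idx = jnp.array([3, 4, 5, 7])   # the last two are out of bounds

# "clip" clamps out-of-bounds indices to the last valid index,
# which matches edge padding on the upper side.
clipped = x.at[idx].get(mode="clip")

# "fill" returns fill_value for out-of-bounds indices,
# which matches zero padding when fill_value=0.
filled = x.at[idx].get(mode="fill", fill_value=0.0)

# Reference results obtained by padding explicitly and indexing in bounds.
edge_ref = jnp.pad(x, (0, 3), mode="edge")[idx]
zero_ref = jnp.pad(x, (0, 3), mode="constant")[idx]

print(clipped, edge_ref)
print(filled, zero_ref)
```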

ASEM000 commented 1 year ago

Hello,

I extensively experimented with exploiting the mode parameter in .at (~jax 0.2) to avoid padding; surprisingly, the padding solution was faster on CPU/GPU for my setups.

How does this perform now with newer JAX versions?

clemisch commented 1 year ago

Ah, that's very interesting!

I asked because I don't have good evidence myself. Naively, padding seems inefficient, but I don't know how it works under jax.jit and how the memory/FLOPs tradeoff behaves, especially on GPU.

IIRC (very anecdotally) using the ideal .at[ind].get(mode="drop") was noticeably slower, maybe even 2x in my use cases. This was completely outside Kernex.

IMHO if padding is reasonably efficient and removes the need for custom border handling code, then it is the correct choice for Kernex :-) Exposing the padding mode would still be very handy though.

clemisch commented 1 year ago

PS: It could be interesting to compare performance of

  1. Kernex's padding="same" with array padding, to
  2. no padding/border handling at all and .at[ind].get(mode="clip") mimicking pad(mode="edge").

If (2) is significantly faster (and you deem this relevant) Kernex could use it as the default for padding="same".
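One way such a comparison could be set up outside Kernex is sketched below. Everything here is an assumption for illustration: `mean_padded`, `mean_clipped`, and `bench` are hypothetical names, the stencil is the 11x11 mean from the repro, and neither function reflects how Kernex works internally. Because negative indices wrap around NumPy-style, variant (2) clamps the lower border with `jnp.maximum` and leaves only the upper border to `mode="clip"`.

```python
import time

import jax
import jax.numpy as jnp

K = 11       # kernel size from the repro
R = K // 2   # kernel radius


@jax.jit
def mean_padded(data):
    # (1) Explicit edge padding, then a sliding-window sum without padding.
    padded = jnp.pad(data, R, mode="edge")
    s = jax.lax.reduce_window(padded, 0.0, jax.lax.add, (K, K), (1, 1), "VALID")
    return s / (K * K)


@jax.jit
def mean_clipped(data):
    # (2) No padding: gather every window with border-adjusted indices.
    h, w = data.shape
    offs = jnp.arange(-R, R + 1)
    # Clamp below zero by hand (negative indices would wrap around);
    # indices past the upper end are left to mode="clip".
    ri = jnp.maximum(jnp.arange(h)[:, None] + offs[None, :], 0)  # (h, K)
    ci = jnp.maximum(jnp.arange(w)[:, None] + offs[None, :], 0)  # (w, K)
    # patches[i, j, a, b] = data[ri[i, a], ci[j, b]]
    patches = data.at[ri[:, None, :, None], ci[None, :, None, :]].get(mode="clip")
    return patches.mean(axis=(2, 3))


def bench(fn, x, n=20):
    """Average wall time per call in seconds, excluding compilation."""
    fn(x).block_until_ready()  # warm-up triggers compilation
    t0 = time.perf_counter()
    for _ in range(n):
        fn(x).block_until_ready()
    return (time.perf_counter() - t0) / n


x = jax.random.normal(jax.random.PRNGKey(0), (100, 100))
print("padded :", bench(mean_padded, x))
print("clipped:", bench(mean_clipped, x))
```

Both functions should agree up to floating-point rounding, so any timing difference reflects the border-handling strategy and the way the windows are formed, not the result. The numbers will depend on the backend, array size, and JAX version, and this gather-based variant materializes all windows, which a tuned implementation might avoid.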

ASEM000 commented 1 year ago

Ok, I will look into it and let you know.

ASEM000 commented 1 year ago

@clemisch feel free to reopen if you still have questions.

clemisch commented 1 year ago

Edit: see https://github.com/ASEM000/Kernex/issues/14

clemisch commented 1 year ago

FYI, I can't re-open. I created a new issue: https://github.com/ASEM000/Kernex/issues/14