Open clemisch opened 1 year ago
it seems to me that it performs 0-padding. Is this the case?
Yes.
If yes, it would be useful to support different schemes like "mirror", "edge", or ideally dropping the out-of-bounds indices.
I think it's simple to do this. I am thinking about something like this:
@kex.kmap(pad_kwargs={"value":0, ... })
...
WDYT?
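To make the proposal concrete, here is a minimal sketch of how such keyword arguments could be forwarded to `jnp.pad`. This is not Kernex code: `pad_same` is a hypothetical helper invented for illustration, it assumes an odd kernel size, and it uses the keyword names of `jnp.pad` (`mode`, `constant_values`) rather than the `"value"` key from the snippet above.

```python
import jax.numpy as jnp


def pad_same(x, kernel_size, **pad_kwargs):
    """Hypothetical helper: pad a 1-D array so that a sliding window of
    odd size `kernel_size` produces one output per input element.
    All extra keyword arguments are passed straight to jnp.pad."""
    half = kernel_size // 2
    return jnp.pad(x, half, **pad_kwargs)


x = jnp.array([1, 2, 3])

zero = pad_same(x, 3, mode="constant", constant_values=0)  # [0, 1, 2, 3, 0]
edge = pad_same(x, 3, mode="edge")                         # [1, 1, 2, 3, 3]
mirror = pad_same(x, 3, mode="reflect")                    # [2, 1, 2, 3, 2]
```

Forwarding the keywords unchanged would keep the set of supported border schemes identical to whatever `jnp.pad` supports, at the cost of exposing its naming.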
Thanks, that looks good! Do I understand correctly that Kernex always uses array padding for border handling?
Do you have experience with the performance of padding vs. adapting the computed indices at the borders? That is, if the latter is even possible with JAX and/or beneficial on accelerators.
FYI, it's possible to control how JAX handles out-of-bounds indices with Array.at[ind].get(mode=...). The default clips into the valid range, which should be analogous to mode=edge in np.pad.
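A small sketch of that correspondence, assuming a recent JAX version where `.at[...].get` accepts `mode="clip"` and `mode="fill"` with `fill_value`. The indices only run past the right border on purpose: negative indices count from the end, as in NumPy, so they are not treated as out of bounds.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])

# Indices 5, 6, 7 are past the right border of x.
idx = jnp.arange(3, 8)

# Index-based border handling, no padded copy of x is created.
clipped = x.at[idx].get(mode="clip")                 # [4, 5, 5, 5, 5]
filled = x.at[idx].get(mode="fill", fill_value=0.0)  # [4, 5, 0, 0, 0]

# Padding-based equivalents: pad on the right, then index normally.
edge_padded = jnp.pad(x, (0, 3), mode="edge")[idx]
zero_padded = jnp.pad(x, (0, 3), mode="constant", constant_values=0.0)[idx]
```

Here `mode="clip"` reproduces edge padding and `mode="fill"` with a zero fill value reproduces zero padding, for indices beyond the right border.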
Hello,
I extensively experimented with exploiting the mode parameter in .at (~jax 0.2) to avoid padding; surprisingly, the padding solution was faster on CPU/GPU for my setups.
How does this perform now with the newer JAX versions?
Ah, that's very interesting!
I asked because I don't have good evidence myself. Naively, padding seems inefficient, but I don't know how it works under jax.jit and how the memory/flops tradeoff behaves, especially on GPU.
IIRC (very anecdotally) using the ideal .at(mode="drop") was noticeably slower, maybe even 2x in my use cases. This was completely outside Kernex.
IMHO if padding is reasonably efficient and removes the need for custom border handling code, then it is the correct choice for Kernex :-) Exposing the padding mode would still be very handy though.
PS: It could be interesting to compare the performance of
1. padding="same" with array padding, to
2. .at(mode="clip") mimicking pad(mode="edge").
If (2) is significantly faster (and you deem this relevant), Kernex could use it as the default for padding="same".
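A rough way to run such a comparison outside Kernex is sketched below. Everything here is an assumption for illustration: the 1-D mean filter, the window size `K`, and the function names are made up. Variant (2) clamps the indices explicitly with `jnp.clip` instead of relying on `.at[...].get(mode="clip")`, because the left border produces negative indices, which would otherwise count from the end. Timings depend heavily on backend, array size, and JAX version, so none are quoted.

```python
import timeit

import jax
import jax.numpy as jnp

K = 3  # window size (odd), chosen arbitrarily for this sketch


@jax.jit
def mean_filter_pad(x):
    # Variant (1): edge-pad the array, then gather full windows.
    half = K // 2
    xp = jnp.pad(x, half, mode="edge")
    idx = jnp.arange(x.shape[0])[:, None] + jnp.arange(K)[None, :]
    return xp[idx].mean(axis=1)


@jax.jit
def mean_filter_clip(x):
    # Variant (2): no padding, clamp the window indices at the borders.
    half = K // 2
    n = x.shape[0]
    idx = jnp.arange(n)[:, None] + jnp.arange(-half, half + 1)[None, :]
    idx = jnp.clip(idx, 0, n - 1)
    return x[idx].mean(axis=1)


x = jnp.linspace(0.0, 1.0, 1 << 14)

# Warm-up calls so that compilation time is excluded from the timings.
out_pad = mean_filter_pad(x).block_until_ready()
out_clip = mean_filter_clip(x).block_until_ready()

t_pad = timeit.timeit(lambda: mean_filter_pad(x).block_until_ready(), number=20)
t_clip = timeit.timeit(lambda: mean_filter_clip(x).block_until_ready(), number=20)
print(f"pad: {t_pad:.4f}s  clip: {t_clip:.4f}s")
```

Both variants return the same values, so any difference in the printed timings reflects only the border-handling strategy.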
Ok, I will look into it and let you know.
@clemisch feel free to reopen if you still have questions.
FYI, I can't re-open. I created a new issue: https://github.com/ASEM000/Kernex/issues/14
How does Kernex handle borders and out-of-bounds indices, especially with padding="same"? I skimmed the source code but did not find it. Very superficially it seems to me that it performs 0-padding. Is this the case? If yes, it would be useful to support different schemes like "mirror", "edge", or ideally dropping the out-of-bounds indices.
Repro: