stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

Adaptive Pooling #50

Open · dlwh opened this issue 8 months ago


Adaptive pooling is a pain in JAX/XLA, at least if you want results identical to torch's in the general case. (Equinox doesn't aim for torch equivalence here, and other major JAX frameworks don't implement it at all, except sometimes for the easy case where `input % output == 0`.)
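Why the divisible case is the easy one: when the input size is a multiple of the output size, the windows are disjoint and all have size `input // output`, so adaptive pooling collapses to ordinary fixed-window pooling. A minimal sketch in plain `jax.numpy`/`jax.lax` (not Haliax's API; both function names are hypothetical), showing the reshape form and the equivalent `lax.reduce_window` form:

```python
import jax.numpy as jnp
from jax import lax


def adaptive_avg_pool_1d_divisible(x, n_out):
    """Hypothetical helper: adaptive average pooling over the last axis of x,
    valid only when x.shape[-1] is divisible by n_out."""
    n_in = x.shape[-1]
    if n_in % n_out != 0:
        raise ValueError("input size must be divisible by output size")
    k = n_in // n_out  # window size == stride, windows don't overlap
    return x.reshape(*x.shape[:-1], n_out, k).mean(axis=-1)


def adaptive_avg_pool_1d_divisible_rw(x, n_out):
    """Same computation expressed as a single reduce_window call."""
    k = x.shape[-1] // n_out
    # Window/stride of 1 on every leading axis, k on the pooled axis.
    window = (1,) * (x.ndim - 1) + (k,)
    summed = lax.reduce_window(x, 0.0, lax.add, window, window, "VALID")
    return summed / k
```

For `jnp.arange(6.0)` and `n_out=3`, both return the values 0.5, 2.5, 4.5. Once the sizes don't divide evenly, the window size and stride are no longer constant, which is exactly what a single `reduce_window` call can't express.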

https://stackoverflow.com/a/63603993/1736826 seems to describe correctly how it works, at least in the 1-D case. You end up with overlapping windows of differing sizes, which (I think?) can't be turned into a `reduce_window` call. You might be able to get it done with a second, mask-like argument to `reduce_window`, but I haven't figured it out yet.
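One way to sidestep `reduce_window` entirely is to materialize the windows as a dense mask. A minimal sketch, assuming the window rule described in that answer (output `i` covers `[floor(i*in/out), ceil((i+1)*in/out))`): build an `(out, in)` boolean mask whose row `i` marks window `i`, then reduce with a matmul (average) or a masked max. This is a swapped-in masking technique, not a `reduce_window` formulation and not a Haliax API; all names are hypothetical, and it hasn't been checked against torch here.

```python
import jax
import jax.numpy as jnp


def adaptive_windows(n_in: int, n_out: int) -> list[tuple[int, int]]:
    """Half-open [start, end) window per output index (assumed torch rule):
    start = floor(i * n_in / n_out), end = ceil((i + 1) * n_in / n_out),
    computed with integer arithmetic since all values are non-negative."""
    return [
        ((i * n_in) // n_out, ((i + 1) * n_in + n_out - 1) // n_out)
        for i in range(n_out)
    ]


def adaptive_pool_mask(n_in: int, n_out: int) -> jax.Array:
    """Boolean (n_out, n_in) mask; row i is True exactly on window i."""
    j = jnp.arange(n_in)[None, :]
    i = jnp.arange(n_out)[:, None]
    start = (i * n_in) // n_out
    end = ((i + 1) * n_in + n_out - 1) // n_out
    return (j >= start) & (j < end)


def adaptive_avg_pool_1d(x: jax.Array, n_out: int) -> jax.Array:
    """Adaptive average pooling over the last axis of x, shape (..., n_in)."""
    mask = adaptive_pool_mask(x.shape[-1], n_out).astype(x.dtype)
    # (..., n_in) @ (n_in, n_out) sums each window; divide by window sizes.
    return (x @ mask.T) / mask.sum(axis=-1)


def adaptive_max_pool_1d(x: jax.Array, n_out: int) -> jax.Array:
    """Adaptive max pooling over the last axis of x, shape (..., n_in)."""
    mask = adaptive_pool_mask(x.shape[-1], n_out)
    # Broadcast to (..., n_out, n_in) and hide elements outside each window.
    masked = jnp.where(mask, x[..., None, :], -jnp.inf)
    return masked.max(axis=-1)


x = jnp.arange(5.0)
print(adaptive_windows(5, 3))  # [(0, 2), (1, 4), (3, 5)]
avg = adaptive_avg_pool_1d(x, 3)  # values 0.5, 2.0, 3.5
mx = adaptive_max_pool_1d(x, 3)   # values 1.0, 3.0, 4.0
```

The `5 -> 3` case shows the problem from the issue: windows of sizes 2, 3, 2 that overlap. Every window is non-empty (its width is at least `in/out`), so there is no division by zero and no `-inf` in the output. The cost is an `O(in * out)` mask, which is fine for typical pooling sizes but not free. Because each 2-D adaptive window is a rectangle, 2-D pooling can be done by applying the 1-D version along each spatial axis in turn.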

Maybe going for torch equivalence here isn't worth it? Maybe it can be done with Pallas?

The easiest thing seems to be roughly symmetric padding, but no one seems to do it that way?