jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Implement higher order interpolation in jax's map_coordinates #3928

Open fischcheng opened 4 years ago

fischcheng commented 4 years ago

Hi, I've been trying to use JAX to speed up a program that uses scipy.ndimage.map_coordinates extensively. I was able to use JAX's map_coordinates to get a significant speedup, and jit sped things up even more when looping many times. Pretty amazing.

I am aware that the current map_coordinates in JAX only supports some "modes" and only orders up to 1, but our program has been using order=3. Is there a plan to implement this anytime soon? I have been trying to implement the feature myself but haven't made much progress so far.

https://github.com/scipy/scipy/blob/v1.5.2/scipy/ndimage/src/ni_splines.c

I looked at the scipy.ndimage C code and found where scipy's map_coordinates computes the interpolation weights, but I'm not entirely sure what those "weights" are, or how they come up with such coefficients without actually solving the spline interpolation system. I would very much appreciate anybody's help!
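
As far as I can tell, the order-3 weights are the standard cubic B-spline basis evaluated at the fractional offset t = x - floor(x), one weight for each of the four neighbours floor(x) - 1 through floor(x) + 2. A minimal NumPy sketch (the helper name is hypothetical, not SciPy's) is below. The weights are closed-form polynomials, so nothing is solved at this step; the solve happens earlier, because the weights multiply spline coefficients produced by a prefilter (scipy.ndimage.map_coordinates runs one when prefilter=True), not the raw samples.

```python
import numpy as np


def cubic_bspline_weights(t):
    """Cubic B-spline weights for neighbours floor(x)-1, floor(x), floor(x)+1, floor(x)+2.

    `t` is the fractional offset x - floor(x) in [0, 1). Hypothetical helper for
    illustration; the interpolated value is sum_k w[k] * coeffs[floor(x) - 1 + k],
    where `coeffs` are prefiltered spline coefficients, not the raw samples.
    """
    t = np.asarray(t, dtype=float)
    w0 = (1 - t) ** 3 / 6
    w1 = (3 * t**3 - 6 * t**2 + 4) / 6
    w2 = (-3 * t**3 + 3 * t**2 + 3 * t + 1) / 6
    w3 = t**3 / 6
    # Last axis indexes the four neighbours.
    return np.stack([w0, w1, w2, w3], axis=-1)
```

At t = 0 the weights are [1/6, 2/3, 1/6, 0], so applying them directly to raw samples would blur the data even at grid points; that is why the prefilter step is needed.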

shoyer commented 4 years ago

I'm glad you've found map_coordinates helpful for your work!

There were two reasons why I didn't implement higher order interpolation in map_coordinates:

  1. It's a lot more complicated. Linear interpolation is simple, but higher-order splines require solving a sparse linear system (tridiagonal, if I remember correctly), which would be a little tricky to do in JAX currently (we would probably need to add a custom_call into the BLAS/LAPACK solvers that SciPy uses).
  2. The handling of boundary conditions in SciPy is arguably inconsistent (e.g., it has discontinuities even for higher order interpolation). It's not entirely clear to me how to implement the "correct" version. See https://github.com/scipy/scipy/issues/2640 and the various linked issues for discussion.
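
To make point 1 concrete, here is a minimal 1-D NumPy sketch of the cubic prefilter. The helper name is hypothetical, and the mirror boundary rule c[-1] = c[1], c[n] = c[n-2] is an assumption. Requiring the spline to pass through every sample gives (c[i-1] + 4*c[i] + c[i+1]) / 6 == f[i], which is indeed tridiagonal; it is solved here with the Thomas algorithm. The Python loops are for readability only; in JAX they would presumably have to become a jax.lax.scan (or a dedicated tridiagonal solver) to stay jit-friendly. As far as I know, SciPy's spline_filter computes these coefficients with a recursive (IIR) filter rather than an explicit matrix solve.

```python
import numpy as np


def cubic_prefilter_mirror(samples):
    """Cubic B-spline coefficients c with (c[i-1] + 4*c[i] + c[i+1]) / 6 == samples[i].

    Hypothetical helper; assumes the mirror extension c[-1] = c[1], c[n] = c[n-2].
    """
    f = np.asarray(samples, dtype=float)
    n = f.shape[0]
    if n < 2:
        return f.copy()

    # Tridiagonal system. The mirror rule folds the missing neighbour into the
    # first and last rows, which is where the 2s come from.
    sub = np.ones(n)   # entries below the diagonal (sub[0] unused)
    sub[0], sub[-1] = 0.0, 2.0
    sup = np.ones(n)   # entries above the diagonal (sup[-1] unused)
    sup[0], sup[-1] = 2.0, 0.0
    diag = np.full(n, 4.0)
    rhs = 6.0 * f

    # Thomas algorithm, forward elimination (the matrix is diagonally dominant).
    cp = np.empty(n)
    dp = np.empty(n)
    cp[0] = sup[0] / diag[0]
    dp[0] = rhs[0] / diag[0]
    for i in range(1, n):
        m = diag[i] - sub[i] * cp[i - 1]
        cp[i] = sup[i] / m
        dp[i] = (rhs[i] - sub[i] * dp[i - 1]) / m

    # Back substitution.
    c = np.empty(n)
    c[-1] = dp[-1]
    for i in range(n - 2, -1, -1):
        c[i] = dp[i] - cp[i] * c[i + 1]
    return c
```

The boundary rule enters the first and last rows of the system, so for higher orders the choice of boundary mode in point 2 affects the coefficients themselves, not only the evaluation step.
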

hawkinsp commented 4 years ago

@shoyer For (1) wouldn't an approach along the lines of https://github.com/google/jax/blob/8a8bb702d29c8036807d53b6110fc2fe2051559a/jax/image/scale.py#L55 work here? I had naively assumed that adding other interpolation methods to map_coordinates would mostly be a matter of changing the image resize code to sample at arbitrary points rather than an integer grid. But I don't know much about image processing...
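
A minimal NumPy sketch of the approach suggested here: gather the four neighbours of each arbitrary coordinate and weight them with a local cubic kernel, with no linear solve at all. The kernel below is the Keys cubic convolution kernel with a = -0.5, which, as far as I can tell, is what the linked scale.py uses for "cubic" resizing. The helper names and the clamp-to-edge boundary handling are assumptions for illustration. One caveat: this is cubic convolution, not the B-spline interpolation SciPy performs for order=3, so results would differ from scipy.ndimage.map_coordinates.

```python
import numpy as np


def keys_cubic_kernel(x, a=-0.5):
    """Keys cubic convolution kernel; zero for |x| >= 2."""
    x = np.abs(x)
    near = ((a + 2) * x - (a + 3)) * x * x + 1               # |x| <= 1
    far = ((a * x - 5 * a) * x + 8 * a) * x - 4 * a          # 1 < |x| < 2
    return np.where(x <= 1, near, np.where(x < 2, far, 0.0))


def sample_cubic_convolution_1d(samples, coords):
    """Sample a 1-D array at arbitrary (fractional) coordinates.

    Hypothetical helper. Out-of-range neighbours are clamped to the edge
    ("nearest"-style), which is an assumption, not SciPy's behaviour.
    """
    samples = np.asarray(samples, dtype=float)
    coords = np.asarray(coords, dtype=float)
    base = np.floor(coords).astype(int)
    result = np.zeros_like(coords)
    for offset in (-1, 0, 1, 2):
        idx = base + offset
        weight = keys_cubic_kernel(coords - idx)
        result = result + weight * samples[np.clip(idx, 0, len(samples) - 1)]
    return result
```

Both this kernel and the prefiltered B-spline reproduce the samples exactly at integer coordinates; the trade-off is that the Keys kernel is purely local (cheap and easy to vectorize), while SciPy-compatible order=3 output needs the global prefilter.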

LouisDesdoigts commented 3 years ago

@shoyer I was wondering what the status is on implementing higher-order interpolation, either in map_coordinates or as some other higher-order interpolation method. Is this something that is currently in the works, or is it a low-priority issue? Cheers

shoyer commented 3 years ago

I don't think anybody has been working on this for JAX. Higher order interpolation would be very welcome, but the boundary handling in SciPy is broken, which makes it hard to know what to implement. I don't plan to work on this personally, but could probably review pull requests.

It does look like some progress is slowly being made on the SciPy side (https://github.com/scipy/scipy/issues/12773). If somebody wants to help, they could dig into that issue and the linked PRs to figure out what's changing in SciPy and, at a high level, what a (correct) implementation would look like for JAX.
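
Whatever the "correct" version turns out to be, a JAX implementation would have to spell out how out-of-range neighbour indices map back into the array. A minimal NumPy sketch (hypothetical helper map_index), following the extension patterns drawn in the SciPy documentation for nearest (a a a a | a b c d | d d d d), mirror (d c b | a b c d | c b a), reflect (d c b a | a b c d | d c b a) and grid-wrap (a b c d | a b c d | a b c d):

```python
import numpy as np


def map_index(idx, n, mode):
    """Map (possibly out-of-range) integer indices into [0, n) for a boundary mode.

    Hypothetical helper; mode names follow the SciPy docs' diagrams.
    """
    idx = np.asarray(idx)
    if mode == "nearest":
        # Repeat the edge value.
        return np.clip(idx, 0, n - 1)
    if mode == "grid-wrap":
        # Plain periodic extension.
        return np.mod(idx, n)
    if mode == "reflect":
        # Half-sample symmetric: the edge sample is repeated; period 2n.
        period = 2 * n
        r = np.mod(idx, period)
        return np.where(r < n, r, period - 1 - r)
    if mode == "mirror":
        # Whole-sample symmetric: the edge sample is not repeated; period 2n - 2.
        if n == 1:
            return np.zeros_like(idx)
        period = 2 * n - 2
        r = np.mod(idx, period)
        return np.where(r < n, r, period - r)
    raise ValueError(f"unsupported mode: {mode}")
```

For order > 1, the same rule would presumably also need to be used inside the prefilter so that coefficients and evaluation agree near the edges (an assumption on my part, not a statement of what SciPy does).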

ksanjeevan commented 2 years ago

Wondering if there are any updates on this?

JamesAllingham commented 2 years ago

Just wanted to drop by and say that this is also a feature I would find useful! Unfortunately, having had a quick look at the SciPy side, I think working on this is a bit beyond me.

vboulanger commented 1 year ago

I would also be interested in this!

LouisDesdoigts commented 1 year ago

Also, it looks like the problem has been solved on the SciPy side!

shoyer commented 1 year ago

See also the related discussion in https://github.com/google/jax/issues/5687

stevenygd commented 1 year ago

I would really love this feature. Any progress, or is help needed?

LouisDesdoigts commented 1 year ago

This is being worked on slowly in the background at the moment; you can follow the PR linked above for progress. If you send me your email, I can add you to the thread with some others looking to get this merged. Alternatively, you can build my JAX fork for a working version with different boundary conditions.