fischcheng opened 4 years ago
I'm glad you've found `map_coordinates` helpful for your work!
There were two reasons why I didn't implement higher-order interpolation in `map_coordinates`:
@shoyer For (1), wouldn't an approach along the lines of https://github.com/google/jax/blob/8a8bb702d29c8036807d53b6110fc2fe2051559a/jax/image/scale.py#L55 work here? I had naively assumed that adding other interpolation methods to `map_coordinates` would mostly be a matter of changing the image resize code to sample at arbitrary points rather than on an integer grid. But I don't know much about image processing...
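A minimal sketch of that kernel-based idea in 1-D, assuming a Keys cubic convolution kernel with `a = -0.5` (a common choice; the `jax.image.resize` documentation describes its `"cubic"` method as using the Keys cubic kernel) and clamping at the edges. The names `keys_cubic_kernel` and `cubic_convolution_1d` are hypothetical. Note that this is local cubic convolution, not the B-spline interpolation SciPy uses for `order=3`, so its results will not match `scipy.ndimage.map_coordinates`.

```python
import jax
import jax.numpy as jnp


def keys_cubic_kernel(x, a=-0.5):
    """Keys cubic convolution kernel; a = -0.5 is a common (assumed) choice."""
    x = jnp.abs(x)
    near = ((a + 2) * x - (a + 3)) * x * x + 1       # 0 <= |x| <= 1
    far = ((a * x - 5 * a) * x + 8 * a) * x - 4 * a  # 1 < |x| < 2
    return jnp.where(x <= 1, near, jnp.where(x < 2, far, 0.0))


def cubic_convolution_1d(values, coords, a=-0.5):
    """Sample a 1-D signal at arbitrary (non-integer) coordinates."""
    n = values.shape[0]
    base = jnp.floor(coords).astype(jnp.int32)
    idx = base[:, None] + jnp.arange(-1, 3)[None, :]  # 4 taps: base-1 .. base+2
    w = keys_cubic_kernel(coords[:, None] - idx, a)   # weight from distance to each tap
    taps = values[jnp.clip(idx, 0, n - 1)]            # clamp at edges ('nearest'-style, assumed)
    return jnp.sum(w * taps, axis=1)


fast = jax.jit(cubic_convolution_1d)
# A linear ramp is reproduced in the interior, so this is close to [1.25, 3.5, 6.75].
print(fast(jnp.arange(10.0), jnp.array([1.25, 3.5, 6.75])))
```

Each output only touches its four neighbouring samples, which is why this style vectorizes easily under `jit`; the catch is that it is a different interpolant from SciPy's spline-based `order=3`.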
@shoyer I was wondering what the status is of implementing either higher-order interpolation in `map_coordinates` or some other higher-order interpolation method. Is this something that is currently in the works, or is it a low-priority issue? Cheers
I don't think anybody has been working on this for JAX. Higher-order interpolation would be very welcome, but the boundary handling in SciPy is broken, which makes it hard to know what to implement. I don't plan to work on this personally, but I could probably review pull requests.
It does look like some progress is slowly being made on the SciPy side (https://github.com/scipy/scipy/issues/12773). If somebody wants to help, they could dig into that issue and the linked PRs to figure out what's changing in SciPy and, at a high level, what a (correct) implementation would look like for JAX.
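Whatever SciPy settles on, any implementation needs an exact rule for extending the input past its edges. A minimal sketch of the integer-index rules for two of SciPy's modes, following the patterns given in the `scipy.ndimage` documentation: 'reflect' repeats the edge sample (half-sample symmetric) and 'mirror' does not (whole-sample symmetric). The helper names are hypothetical, and this covers only sample lookup, not how the spline prefilter should treat the boundary, which is a separate question.

```python
import jax.numpy as jnp


def reflect_index(i, n):
    """SciPy 'reflect' (half-sample symmetric): d c b a | a b c d | d c b a."""
    i = jnp.mod(i, 2 * n)                      # period 2n; edge sample is repeated
    return jnp.where(i >= n, 2 * n - 1 - i, i)


def mirror_index(i, n):
    """SciPy 'mirror' (whole-sample symmetric): d c b | a b c d | c b a. Needs n > 1."""
    i = jnp.mod(i, 2 * n - 2)                  # period 2n - 2; edge sample not repeated
    return jnp.where(i >= n, 2 * n - 2 - i, i)


idx = jnp.arange(-4, 8)
print(reflect_index(idx, 4))  # [3 2 1 0 0 1 2 3 3 2 1 0]
print(mirror_index(idx, 4))   # [2 3 2 1 0 1 2 3 2 1 0 1]
```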
Wondering if there are any updates on this?
Just wanted to drop by and say that this is also a feature that I would find useful! Unfortunately, having had a quick look at the SciPy side, I think it is a bit beyond me to work on this.
I would also be interested in this!
It also looks like the problem has been solved on the SciPy side!
See also the related discussion in https://github.com/google/jax/issues/5687
I would really love this feature. Is there any progress, or is help needed?
This is being worked on slowly in the background at the moment; you can follow the PR linked above for progress. If you send me your email, I can add you to the thread with some others looking to get this merged. Alternatively, you can build my JAX fork for a working version with different boundary conditions.
Hi, I've been trying to use JAX to speed up a program that uses `scipy.ndimage.map_coordinates` extensively. I was able to use JAX's `map_coordinates` to get a significant speedup, and using `jit` sped things up even more when looping many times. Pretty amazing.

I am aware that the current `map_coordinates` in JAX only supports some modes and orders up to 1, but our program has been using `order=3`. Is there a plan to implement this anytime soon? I am trying to implement the feature myself but haven't made much progress so far. https://github.com/scipy/scipy/blob/v1.5.2/scipy/ndimage/src/ni_splines.c
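For reference, a minimal sketch of the usage pattern described above: `jax.scipy.ndimage.map_coordinates` called inside `jax.jit` with `order=1`, the highest order the thread says JAX supports. The image, the coordinates and the `sample` function are made up for illustration; `order=3` is the piece that would be needed to match the SciPy program.

```python
import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

image = jnp.arange(20.0).reshape(4, 5)  # image[r, c] = 5 * r + c


@jax.jit
def sample(img, rows, cols):
    # order=1 (linear) is the highest order available per this thread;
    # mode="nearest" clamps coordinates that fall outside the image.
    return map_coordinates(img, [rows, cols], order=1, mode="nearest")


rows = jnp.array([0.5, 1.25, 2.0])
cols = jnp.array([1.0, 3.5, 4.0])
# Linear interpolation reproduces the ramp 5 * r + c, so this is about [3.5, 9.75, 14.0].
print(sample(image, rows, cols))
```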
I went through the scipy.ndimage C code and found where SciPy's `map_coordinates` computes the weights, but I'm not entirely sure what those "weights" are, or how they came up with such coefficients without actually solving the spline interpolation system. I would very much appreciate anybody's help!
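As far as I understand it, the weights in `ni_splines.c` are values of the cubic B-spline basis function at the fractional offset of the sample point, so no linear system is solved at that stage. The system is solved once, beforehand, by the prefilter (`scipy.ndimage.spline_filter`, applied by `map_coordinates` when `prefilter=True`), which turns the data into B-spline coefficients; SciPy does this with a recursive filter (pole `sqrt(3) - 2` for cubic splines). The 1-D sketch below mimics both steps in JAX under stated assumptions: a dense `jnp.linalg.solve` instead of the recursive filter, whole-sample-symmetric ('mirror') boundaries, at least 2 samples, and hypothetical helper names. I have not checked it bit-for-bit against SciPy.

```python
import jax.numpy as jnp


def cubic_bspline_weights(t):
    """Cubic B-spline basis at taps (base-1, base, base+1, base+2) for offset t in [0, 1)."""
    w0 = (1 - t) ** 3 / 6
    w1 = (3 * t**3 - 6 * t**2 + 4) / 6
    w2 = (-3 * t**3 + 3 * t**2 + 3 * t + 1) / 6
    w3 = t**3 / 6
    return jnp.stack([w0, w1, w2, w3], axis=-1)


def mirror_index(i, n):
    """Whole-sample-symmetric index extension: d c b | a b c d | c b a."""
    i = jnp.mod(i, 2 * n - 2)
    return jnp.where(i >= n, 2 * n - 2 - i, i)


def prefilter_mirror(f):
    """Solve (c[k-1] + 4 c[k] + c[k+1]) / 6 = f[k] for the B-spline coefficients c."""
    n = f.shape[0]
    A = (jnp.diag(jnp.full(n, 4.0))
         + jnp.diag(jnp.ones(n - 1), 1)
         + jnp.diag(jnp.ones(n - 1), -1)) / 6.0
    # Mirror boundary: c[-1] = c[1] and c[n] = c[n-2], so fold those neighbours back in.
    A = A.at[0, 1].set(2.0 / 6.0).at[n - 1, n - 2].set(2.0 / 6.0)
    return jnp.linalg.solve(A, f)


def cubic_spline_interp_1d(f, x):
    """Prefilter once, then combine 4 coefficients per point with B-spline weights."""
    n = f.shape[0]
    c = prefilter_mirror(f)
    base = jnp.floor(x).astype(jnp.int32)
    t = x - base
    idx = mirror_index(base[:, None] + jnp.arange(-1, 3)[None, :], n)
    return jnp.sum(cubic_bspline_weights(t) * c[idx], axis=1)


f = jnp.array([0.0, 1.0, 4.0, 9.0, 16.0, 25.0])
print(cubic_spline_interp_1d(f, jnp.array([1.0, 2.5, 4.75])))
```

Evaluating at integer positions returns the input exactly, which is the interpolation property the prefilter exists to guarantee; summing the raw data with these weights instead would smooth the signal rather than interpolate it.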