I think it may be helpful to look at the function you're optimizing:
```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# lossfunc is the registration loss defined in the original test case below
plt.figure()
x = jnp.linspace(-1, 2, num=91)
losses = jax.jit(jax.vmap(lossfunc))(jnp.stack([x, x], axis=1))
plt.plot(x, jax.device_get(losses), '-s')
plt.ylabel('loss')

plt.figure()
grads = jax.jit(jax.vmap(jax.grad(lossfunc)))(jnp.stack([x, x], axis=1))
plt.plot(x, jax.device_get(grads), '-s')
plt.ylabel('grad(loss)')
```
So yes, it's a little weird that the gradient is exactly zero at these points instead of picking the value from one of the sides (which might make more sense?), but you're going to have trouble optimizing this function with gradient-based methods no matter how you calculate them, because the gradients on either side of this point go in opposite directions!
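To see the zero gradient in isolation, here is a minimal 1D sketch (my own example, not code from this issue) showing the derivative of a linear `map_coordinates` interpolation with respect to the coordinate:

```python
import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

image = jnp.arange(10.0)  # simple 1D "image" with constant slope 1

def sample(coord):
    # linearly interpolate the image at a single, possibly fractional, coordinate
    return map_coordinates(image, [jnp.reshape(coord, (1,))], order=1)[0]

print(jax.grad(sample)(3.5))  # 1.0 between grid points
print(jax.grad(sample)(3.0))  # exactly 0.0 at an integer coordinate (the behaviour discussed here)
```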
Maybe there's something different about your real use-case here?
Thanks @shoyer. The only obvious difference with the real use case is that the images are not random. A more complete example that includes the optimisation over the translation can be found here: https://colab.research.google.com/drive/1lkV7zBPL4YLiKwTxz1uL9K188uiF6BLI?usp=sharing
I guess the classical approximation of the gradient (-2(x1 - x2)∇x2) has a smoothing effect, without which local minima at grid points are an issue. Switching the interpolation order to 3 instead of 1 might help (at the cost of computational time), but for now this is not implemented in jax:
`NotImplementedError: jax.scipy.ndimage.map_coordinates currently requires order<=1`
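For reference, here is a rough sketch of that classical approximation (my own illustration with made-up names `x1`, `x2`, `u`; not the code from the linked notebook). It warps x2 by the translation, takes the spatial gradient of the warped image with `np.gradient`, and averages -2(x1 - x2)∇x2 over the pixels:

```python
import numpy as np
from scipy.ndimage import map_coordinates

def approx_mse_gradient(x1, x2, u):
    """Approximate d MSE(x1, x2(. + u)) / du for a 2D translation u = (uy, ux)."""
    grid = np.meshgrid(np.arange(x1.shape[0]), np.arange(x1.shape[1]), indexing='ij')
    coords = [g + ui for g, ui in zip(grid, u)]
    x2_warped = map_coordinates(x2, coords, order=1)
    # spatial gradient of the warped image, as np.gradient computes it
    gy, gx = np.gradient(x2_warped)
    diff = x1 - x2_warped
    # -2 (x1 - x2) * grad(x2), averaged over pixels, one entry per translation axis
    return np.array([np.mean(-2.0 * diff * gy), np.mean(-2.0 * diff * gx)])
```

Because `np.gradient` uses neighbouring grid points, this estimate varies smoothly across integer translations rather than dropping to zero there.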
I don't know the details from the signal processing literature, but I suspect adding some sort of low-pass filter, either before or after resampling, is important to avoid aliasing issues when doing this sort of alignment.
Many thanks @shoyer for spending time on this. Anti-aliasing is not typically needed in such a simple translation problem. I have expanded the example a bit (and fixed a small bug in the manual approximate gradient) to rely on a very smooth set of images:
The following graph shows the gradients along a diagonal translation as computed with the manual approximate gradient (mx, my), the finite-difference one (ax, ay), and the jax one (jx, jy), which is the one with the spikes:
Code snippet:
OK, I agree this looks pretty bad! We shouldn't have the gradient deviate at a single point.
I bet the problem is when `lower == upper` on these lines:
https://github.com/google/jax/blob/db71f3c5fc5226c3e9c87dd9f056d1b63cfa0286/jax/scipy/ndimage.py#L45-L53
Rather than computing `upper = jnp.ceil(coordinate)`, we should probably just set `upper = lower + 1`.
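To make the suggestion concrete, here is a hedged sketch (not the actual jax source) of how the pair of interpolation neighbours and weights could be built so they stay well-defined at integer coordinates:

```python
import jax.numpy as jnp

def linear_neighbours(coordinate):
    # lower/upper grid neighbours and weights for linear interpolation
    lower = jnp.floor(coordinate)
    upper = lower + 1                  # instead of jnp.ceil(coordinate)
    upper_weight = coordinate - lower  # in [0, 1); 0 exactly on the grid
    lower_weight = 1.0 - upper_weight
    return (lower, lower_weight), (upper, upper_weight)
```

With `ceil`, an integer coordinate makes `lower == upper`, so the two weight derivatives (+1 and -1) multiply the same pixel value and cancel exactly, which is presumably where the zero gradient comes from.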
Sorry if this is not a very minimal test case, but let me explain my use case and issue as I believe the context will help.
I am trying to use jax for a toy image registration problem. Given two images x1 and x2, I want to find the translation u that minimises the difference between x1(.) and x2(. + u), as measured in terms of mean square error (MSE). The resampled version of x2 after the (non-integer) translation is computed with `map_coordinates`.

The computation of the gradient of the cost function in this context is usually done by assuming the images are continuous, computing the gradient of the MSE as -2(x1 - x2)∇x2, and computing ∇x2 with something similar to `np.gradient`.

Trying to mimic this setup with jax to avoid computing the gradient manually (this would be useful, for example, as soon as one wants to change the MSE loss for something else) fails to converge to a suitable translation (at least if initialised with an integer translation), as the gradient from jax is exactly zero for integer translations.
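For concreteness, a minimal sketch of the kind of setup described above (illustrative names only; the actual test case is the snippet further down):

```python
import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

def mse_loss(u, x1, x2):
    # warp x2 by the (possibly non-integer) translation u and compare to x1
    grid = jnp.meshgrid(jnp.arange(x1.shape[0]), jnp.arange(x1.shape[1]), indexing='ij')
    coords = [g + ui for g, ui in zip(grid, u)]
    x2_warped = map_coordinates(x2, coords, order=1)
    return jnp.mean((x1 - x2_warped) ** 2)

# gradient of the loss with respect to the translation u
grad_u = jax.grad(mse_loss)
```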
Below is a test case to illustrate the zero gradient issue. I understand there is a discontinuity of the gradient at such integer points but I was expecting to nonetheless get a proper sub-gradient.
I am not sure if I am doing something wrong or simply misunderstanding something, but so far I haven't managed to get the image registration to converge with jax derivatives, even though it does with a numerical approximation of the gradient or the classical continuous approximation I was referring to.
Outputs: