jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Zero gradient when resampling an image at grid location using map_coordinates #3024

Closed: tvercaut closed this issue 4 years ago

tvercaut commented 4 years ago

Sorry if this is not a very minimal test case, but let me explain my use case and issue, as I believe the context will help.

I am trying to use jax for a toy image registration problem. Given two images x1 and x2, I want to find the translation u that minimises the difference between x1(.) and x2(.+u), as measured by the mean squared error (MSE). The resampled version of x2 after the (non-integer) translation is computed with map_coordinates.

In this context, the gradient of the cost function is usually computed by assuming the images are continuous. The gradient of the MSE is then -2(x1-x2)∇x2, and ∇x2 is computed with something similar to np.gradient.
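For reference, the classical expression follows from applying the chain rule to the MSE while treating x2 as a differentiable continuous image. In the block below, p ranges over the N pixels kept in the loss. This notation is introduced only for this sketch:

```latex
\begin{aligned}
L(\mathbf{u}) &= \frac{1}{N}\sum_{\mathbf{p}} \big(x_2(\mathbf{p}+\mathbf{u}) - x_1(\mathbf{p})\big)^2 \\
\nabla_{\mathbf{u}} L(\mathbf{u}) &= \frac{2}{N}\sum_{\mathbf{p}} \big(x_2(\mathbf{p}+\mathbf{u}) - x_1(\mathbf{p})\big)\,\nabla x_2(\mathbf{p}+\mathbf{u}) \\
&= -\frac{2}{N}\sum_{\mathbf{p}} \big(x_1(\mathbf{p}) - x_2(\mathbf{p}+\mathbf{u})\big)\,\nabla x_2(\mathbf{p}+\mathbf{u})
\end{aligned}
```

The only approximation is in how ∇x2 is evaluated. Estimating it with a finite-difference operator on the resampled image, as mg_lossfunc does with np.gradient, gives a gradient that varies smoothly with u even between grid points.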

I tried to mimic this setup with jax to avoid computing the gradient manually. This would be useful, for example, as soon as one wants to replace the MSE loss with something else. However, the optimisation fails to converge to a suitable translation, at least when initialised with an integer translation, because the gradient from jax is exactly zero at integer translations.

Below is a test case that illustrates the zero-gradient issue. I understand that the gradient is discontinuous at such integer points, but I was nonetheless expecting to get a proper sub-gradient.

I am not sure whether I am doing something wrong or simply misunderstanding something. So far, I haven't managed to get the image registration to converge with jax derivatives. It does converge with a numerical approximation of the gradient, or with the classical continuous approximation I was referring to.

```python
import jax
import jax.numpy as jnp
import numpy as onp
from jax.scipy import ndimage as jndimage
from scipy import ndimage as ondimage
import scipy as oscipy
import scipy.optimize  # makes oscipy.optimize available on older SciPy versions

# This needs to run at startup
# https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#Double-(64bit)-precision
jax.config.update('jax_enable_x64', True)

# Exclude a border in the computation of the loss to try and avoid numerical issues
excl_border = 2

def run(np,ndimage):
  print(f"\nRunning on {np}")
  onp.random.seed(0)
  x1 = onp.random.randn(20,10)
  x2 = onp.random.randn(20,10)

  grid_x, grid_y = np.meshgrid(np.arange(x1.shape[1]), np.arange(x1.shape[0]))

  def lossfunc(du):
    # Get translated grid
    def_grid_x = grid_x + du[0]
    def_grid_y = grid_y + du[1]

    # Resample image (uses jax's map_coordinates on both runs)
    tmpx2_warped = jndimage.map_coordinates(x2, [def_grid_y, def_grid_x], order=1)

    # Compute the MSE between the warped image and the fixed image
    diff_im = tmpx2_warped[excl_border:-excl_border,excl_border:-excl_border]-x1[excl_border:-excl_border,excl_border:-excl_border]

    imloss = np.mean((diff_im)**2)

    return imloss

  def mg_lossfunc(du):
    # Get translated grid
    def_grid_x = grid_x + du[0]
    def_grid_y = grid_y + du[1]

    # Resample image
    tmpx2_warped = ndimage.map_coordinates(x2, [def_grid_y, def_grid_x], order=1)

    # Compute the MSE between the warped image and the fixed image
    diff_im = tmpx2_warped[excl_border:-excl_border,excl_border:-excl_border]-x1[excl_border:-excl_border,excl_border:-excl_border]
    # Classical approximation: finite-difference gradient of the resampled image.
    # np.gradient returns [d/d(axis 0), d/d(axis 1)], i.e. [d/dy, d/dx] for a (rows, cols) image.
    jm = np.gradient(tmpx2_warped)
    jmrx = jm[0][excl_border:-excl_border,excl_border:-excl_border]
    jmry = jm[1][excl_border:-excl_border,excl_border:-excl_border]
    jdmx = -2.*diff_im*jmrx
    jdmy = -2.*diff_im*jmry

    return np.array([np.mean(jdmx), np.mean(jdmy)])

  print(f"loss at 0,0: {lossfunc([0., 0.])}")
  print(f"loss at 0.1,0.1: {lossfunc([0.1, 0.1])}")

  print(f"Manual approx gradient at 0,0: {mg_lossfunc([0., 0.])}")
  print(f"Manual approx gradient at 0.1,0.1: {mg_lossfunc([0.1, 0.1])}")
  # Finite difference gradient with a large step as the image is sampled on a grid and interpolated
  epsgrad = 0.1
  print(f"Numerical approx gradient at 0,0: {oscipy.optimize.approx_fprime([0., 0.],lossfunc,epsgrad)}")
  print(f"Numerical approx gradient at 0.1,0.1: {oscipy.optimize.approx_fprime([0.1, 0.1],lossfunc,epsgrad)}")
  if np==jnp:
    jg_lossfunc = lambda du:np.asarray(jax.jit(jax.grad(lossfunc))(du))
    print(f"Jax gradient at 0,0: {jg_lossfunc([0., 0.])}")
    print(f"Jax gradient at 0.1,0.1: {jg_lossfunc([0.1, 0.1])}")

run(onp,ondimage)
run(jnp,jndimage)
```

Outputs:

```
Running on <module 'numpy' from '/usr/local/lib/python3.6/dist-packages/numpy/__init__.py'>
loss at 0,0: 1.9514397057693333
loss at 0.1,0.1: 1.6584230856711766
Manual approx gradient at 0,0: [ 0.01189463 -0.21033731]
Manual approx gradient at 0.1,0.1: [-0.00626396 -0.12798883]
Numerical approx gradient at 0,0: [-1.72714353 -1.47912221]
Numerical approx gradient at 0.1,0.1: [-1.15234194 -0.9063109 ]

Running on <module 'jax.numpy' from '/usr/local/lib/python3.6/dist-packages/jax/numpy/__init__.py'>
loss at 0,0: 1.9514397057693333
loss at 0.1,0.1: 1.6584230856711766
Manual approx gradient at 0,0: [ 0.01189463 -0.21033731]
Manual approx gradient at 0.1,0.1: [-0.00626396 -0.12798883]
Numerical approx gradient at 0,0: [-1.72714353 -1.47912221]
Numerical approx gradient at 0.1,0.1: [-1.15234194 -0.9063109 ]
Jax gradient at 0,0: [0. 0.]
Jax gradient at 0.1,0.1: [-1.30169297 -1.05466678]
```
shoyer commented 4 years ago

I think it may be helpful to look at the function you're optimizing:

```python
import matplotlib.pyplot as plt

# Assumes jax, jnp and lossfunc from the snippet above are in scope
# (e.g. with lossfunc lifted out of run() to module level).
plt.figure()
x = jnp.linspace(-1, 2, num=91)
losses = jax.jit(jax.vmap(lossfunc))(jnp.stack([x, x], axis=1))
plt.plot(x, jax.device_get(losses), '-s')
plt.ylabel('loss')

plt.figure()
grads = jax.jit(jax.vmap(jax.grad(lossfunc)))(jnp.stack([x, x], axis=1))
plt.plot(x, jax.device_get(grads), '-s')
plt.ylabel('grad(loss)')
```

[Figures: loss and grad(loss) plotted against the diagonal translation x]

So yes, it's a little weird that the gradient is exactly zero at these points instead of taking the value from one of the sides (which might make more sense?). However, you're going to have trouble optimizing this function with gradient-based methods no matter how you calculate the gradients, because the gradients on either side of this point go in opposite directions!

Maybe there's something different about your real use-case here?

tvercaut commented 4 years ago

Thanks @shoyer. The only obvious difference from the real use case is that the images are not random. A more complete example that includes the optimisation over the translation can be found here: https://colab.research.google.com/drive/1lkV7zBPL4YLiKwTxz1uL9K188uiF6BLI?usp=sharing

I guess the classical approximation of the gradient (-2(x1-x2)∇x2) has a smoothing effect, without which local minima at grid points are an issue. Switching the interpolation order from 1 to 3 might help (at the cost of computation time), but this is not currently implemented in jax:

NotImplementedError: jax.scipy.ndimage.map_coordinates currently requires order<=1
shoyer commented 4 years ago

I don't know the details from the signal processing literature, but I suspect that adding some sort of low-pass filter, either before or after resampling, is important to avoid aliasing issues when doing this sort of alignment.
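Below is a minimal sketch of that suggestion, applying a low-pass filter before resampling. It assumes a separable Gaussian filter. The helpers `gaussian_kernel1d` and `gaussian_blur`, and the `sigma` and `radius` values, are hypothetical choices for illustration and are not part of jax. Blurring smooths the image, but it does not by itself change how linear interpolation is differentiated at integer coordinates.

```python
import jax
import jax.numpy as jnp

def gaussian_kernel1d(sigma, radius):
    # Normalised 1D Gaussian taps at the integer offsets -radius..radius.
    offsets = jnp.arange(-radius, radius + 1, dtype=jnp.float32)
    taps = jnp.exp(-0.5 * (offsets / sigma) ** 2)
    return taps / taps.sum()

def gaussian_blur(image, sigma=1.0, radius=3):
    # Hypothetical separable low-pass filter: convolve each row, then each column.
    # mode='same' keeps the image shape; pixels near the border see implicit zero padding.
    taps = gaussian_kernel1d(sigma, radius)
    blur_1d = lambda v: jnp.convolve(v, taps, mode='same')
    image = jax.vmap(blur_1d)(image)                         # filter along axis 1
    image = jax.vmap(blur_1d, in_axes=1, out_axes=1)(image)  # filter along axis 0
    return image

# Usage: smooth the moving image before resampling it, e.g.
# jndimage.map_coordinates(gaussian_blur(x2), [def_grid_y, def_grid_x], order=1)
impulse = jnp.zeros((20, 20)).at[10, 10].set(1.0)
blurred = gaussian_blur(impulse)
print(blurred.shape, float(blurred.sum()))
```

Because the filter is separable, the cost is two 1D convolutions per pixel instead of one 2D convolution. The trade-off is that zero padding darkens the border, which the `excl_border` crop in the loss partly hides.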

tvercaut commented 4 years ago

Many thanks @shoyer for spending time on this. Anti-aliasing is not typically needed in such a simple translation problem. I have expanded the example a bit (and fixed a small bug in the manual approximate gradient) so that it relies on a pair of very smooth images:

[Figure: the smooth test images x1 and x2]

The following graph shows the gradients along a diagonal translation, as computed with the manual approximate gradient (mx, my), the finite-difference approximation (ax, ay) and jax (jx, jy, the curves with the spikes):

[Figure: gradient components along the diagonal translation]
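The (ax, ay) curve comes from `oscipy.optimize.approx_fprime`, which takes a one-sided (forward) difference with step `epsgrad = 0.1`. The sketch below shows a central difference with the same step instead, to illustrate what a large-step finite difference does across a kink. The helper `central_diff_grad` and the toy `kink` function are hypothetical and not part of the original code. At a kink, the central difference returns the average of the left and right slopes, whereas a forward difference returns the right-hand slope.

```python
import numpy as np

def central_diff_grad(f, u, eps=0.1):
    # Hypothetical helper: central finite differences, one coordinate at a time,
    # g_i = (f(u + eps * e_i) - f(u - eps * e_i)) / (2 * eps)
    u = np.asarray(u, dtype=float)
    grad = np.zeros_like(u)
    for i in range(u.size):
        step = np.zeros_like(u)
        step[i] = eps
        grad[i] = (f(u + step) - f(u - step)) / (2.0 * eps)
    return grad

# Toy 1D kink at 0 with slope -1 on the left and +3 on the right.
kink = lambda u: np.where(u[0] < 0, -u[0], 3.0 * u[0])
print(central_diff_grad(kink, [0.0]))
```

Either way, a finite difference over a step of 0.1 measures the function on both sides of (or beyond) a grid point. It therefore cannot produce an isolated zero exactly at that point, unlike the spikes in the jax curve.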

Code snippet:

```python
import jax
import jax.numpy as jnp
import numpy as onp
from jax.scipy import ndimage as jndimage
from scipy import ndimage as ondimage
import scipy as oscipy
import scipy.optimize  # makes oscipy.optimize available on older SciPy versions
import matplotlib.pyplot as plt

# This needs to run at startup
# https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#Double-(64bit)-precision
jax.config.update('jax_enable_x64', True)

# Exclude a border in the computation of the loss to try and avoid numerical issues
excl_border = 2

def run(np,ndimage):
  print(f"\n===\nRunning on {np}")
  h=30
  w=40
  grid_x, grid_y = np.meshgrid(np.arange(w), np.arange(h))
  # Two smooth Gaussian blobs; x2 is x1 shifted by (2, 2) pixels
  cx1 = np.round(w/2.)
  cy1 = np.round(h/2.)
  sx = w/4.
  sy = h/4.
  cx2 = cx1+2.
  cy2 = cy1+2.
  x1 = 10*np.exp( -( ((grid_x-cx1)/sx)**2 + ((grid_y-cy1)/sy)**2 ) )
  x2 = 10*np.exp( -( ((grid_x-cx2)/sx)**2 + ((grid_y-cy2)/sy)**2 ) )

  fig, axs = plt.subplots(1,2)
  axs[0].set_title('x1')
  axs[0].imshow(x1)
  axs[1].set_title('x2')
  axs[1].imshow(x2)
  plt.show()

  def lossfunc(du):
    # Get translated grid
    def_grid_x = grid_x + du[0]
    def_grid_y = grid_y + du[1]

    # Resample image (uses jax's map_coordinates on both runs)
    tmpx2_warped = jndimage.map_coordinates(x2, [def_grid_y, def_grid_x], order=1)

    # Compute the MSE between the warped image and the fixed image
    diff_im = tmpx2_warped[excl_border:-excl_border,excl_border:-excl_border]-x1[excl_border:-excl_border,excl_border:-excl_border]

    imloss = np.mean((diff_im)**2)

    return imloss

  def mg_lossfunc(du):
    # Get translated grid
    def_grid_x = grid_x + du[0]
    def_grid_y = grid_y + du[1]

    # Resample image
    tmpx2_warped = ndimage.map_coordinates(x2, [def_grid_y, def_grid_x], order=1)

    # Compute the MSE between the warped image and the fixed image
    diff_im = tmpx2_warped[excl_border:-excl_border,excl_border:-excl_border]-x1[excl_border:-excl_border,excl_border:-excl_border]
    # np.gradient returns [d/dy, d/dx] for a (rows, cols) image
    jm = np.gradient(tmpx2_warped)
    jmrx = jm[1][excl_border:-excl_border,excl_border:-excl_border]
    jmry = jm[0][excl_border:-excl_border,excl_border:-excl_border]
    jdmx = 2.*diff_im*jmrx
    jdmy = 2.*diff_im*jmry

    return np.array([np.mean(jdmx), np.mean(jdmy)])

  print(f"loss at 0,0: {lossfunc([0., 0.])}")
  print(f"loss at 0.1,0.1: {lossfunc([0.1, 0.1])}")
  print(f"loss at -0.1,-0.1: {lossfunc([-0.1, -0.1])}")

  plt.figure(figsize=(30, 2))
  uu = np.arange(-1.1, 3.5, 0.05)
  losses = list(map(lossfunc, np.stack([uu, uu], axis=1)))
  plt.plot(uu, losses, '-s')
  plt.ylabel('loss')
  plt.show()

  print(f"Manual approx gradient at 0,0: {mg_lossfunc([0., 0.])}")
  print(f"Manual approx gradient at 0.1,0.1: {mg_lossfunc([0.1, 0.1])}")
  print(f"Manual approx gradient at -0.1,-0.1: {mg_lossfunc([-0.1, -0.1])}")
  print(f"Manual approx gradient at 2,-1: {mg_lossfunc([2., -1.])}")
  print(f"Manual approx gradient at 2.1,-1.1: {mg_lossfunc([2.1, -1.1])}")

  # Finite difference gradient with a large step as the image is sampled on a grid and interpolated
  epsgrad = 0.1
  ag_lossfunc = lambda du: oscipy.optimize.approx_fprime(du,lossfunc,epsgrad)
  print(f"Numerical approx gradient at 0,0: {ag_lossfunc([0., 0.])}")
  print(f"Numerical approx gradient at 0.1,0.1: {ag_lossfunc([0.1, 0.1])}")
  print(f"Numerical approx gradient at -0.1,-0.1: {ag_lossfunc([-0.1, -0.1])}")
  print(f"Numerical approx gradient at 2.,-1.: {ag_lossfunc([2., -1.])}")
  print(f"Numerical approx gradient at 2.1,-1.1: {ag_lossfunc([2.1, -1.1])}")

  if np==jnp:
    jg_lossfunc = lambda du:np.asarray(jax.jit(jax.grad(lossfunc))(du))
    print(f"Jax gradient at 0,0: {jg_lossfunc([0., 0.])}")
    print(f"Jax gradient at 0.1,0.1: {jg_lossfunc([0.1, 0.1])}")
    print(f"Jax gradient at -0.1,-0.1: {jg_lossfunc([-0.1, -0.1])}")
    print(f"Jax gradient at 2.1,-1.1: {jg_lossfunc([2., -1.])}")  # note: evaluated at (2, -1)
    print(f"Jax gradient at 2.1,-1.1: {jg_lossfunc([2.1, -1.1])}")

  plt.figure(figsize=(30, 2))
  mg_grads = list(map(mg_lossfunc, np.stack([uu, uu], axis=1)))
  plt.plot(uu, mg_grads, '-o')
  ag_grads = list(map(ag_lossfunc, np.stack([uu, uu], axis=1)))
  plt.plot(uu, ag_grads, '-+')
  plt.legend(['mx', 'my', 'ax', 'ay'])
  if np==jnp:
    jg_lossfunc = lambda du:np.asarray(jax.jit(jax.grad(lossfunc))(du))
    jg_grads = list(map(jg_lossfunc, np.stack([uu, uu], axis=1)))
    plt.plot(uu, jg_grads, '-d')
    plt.legend(['mx', 'my', 'ax', 'ay', 'jx', 'jy'])
  plt.ylabel('grad(loss)')
  plt.show()

  print("\nBFGS optimisation\n")
  u = np.zeros(2)
  opt_opt={'disp': True, 'maxiter': 200, 'eps': epsgrad, 'gtol':1e-9}
  res = oscipy.optimize.minimize(lossfunc, u, method="BFGS", options=opt_opt)
  print(f"\nNumerical approx gradient - loss at optim end {res.x}: {lossfunc(res.x)}")
  res = oscipy.optimize.minimize(lossfunc, u, jac=mg_lossfunc, method="BFGS", options=opt_opt)
  print(f"\nManual approx gradient - loss at optim end {res.x}: {lossfunc(res.x)}")
  if np==jnp:
    jg_lossfunc = lambda du:np.asarray(jax.jit(jax.grad(lossfunc))(du))
    res = oscipy.optimize.minimize(lossfunc, u, jac=jg_lossfunc, method="BFGS", options=opt_opt)
    print(f"\nJax grad - loss at optim end {res.x}: {lossfunc(res.x)}")

run(onp,ondimage)
run(jnp,jndimage)
```
```
===
Running on
loss at 0,0: 1.3439400676144586
loss at 0.1,0.1: 1.214026608789585
loss at -0.1,-0.1: 1.4720960079911962
Manual approx gradient at 0,0: [-0.4687936 -0.82254061]
Manual approx gradient at 0.1,0.1: [-0.44707712 -0.78433099]
Manual approx gradient at -0.1,-0.1: [-0.48807353 -0.85495853]
Manual approx gradient at 2,-1: [-7.75443519e-05 -1.18900303e+00]
Manual approx gradient at 2.1,-1.1: [ 0.02652188 -1.21587652]
Numerical approx gradient at 0,0: [-0.46826943 -0.83002476]
Numerical approx gradient at 0.1,0.1: [-0.44431941 -0.78719149]
Numerical approx gradient at -0.1,-0.1: [-0.46686583 -0.81231689]
Numerical approx gradient at 2.,-1.: [ 0.00328248 -1.22299187]
Numerical approx gradient at 2.1,-1.1: [ 0.02816145 -1.15391493]

BFGS optimisation

Warning: Desired error not necessarily achieved due to precision loss.
Current function value: 0.000708
Iterations: 4
Function evaluations: 324
Gradient evaluations: 78

Numerical approx gradient - loss at optim end [1.95190614 1.95647443]: 0.0007083630658263352
Warning: Desired error not necessarily achieved due to precision loss.
Current function value: 0.000000
Iterations: 8
Function evaluations: 50
Gradient evaluations: 39

Manual approx gradient - loss at optim end [1.99890299 1.99994993]: 1.507664203680706e-07

===
Running on
loss at 0,0: 1.3439400676144586
loss at 0.1,0.1: 1.214026608789585
loss at -0.1,-0.1: 1.4720960079911962
Manual approx gradient at 0,0: [-0.4687936 -0.82254061]
Manual approx gradient at 0.1,0.1: [-0.44707712 -0.78433099]
Manual approx gradient at -0.1,-0.1: [-0.48807353 -0.85495853]
Manual approx gradient at 2,-1: [-7.75443519e-05 -1.18900303e+00]
Manual approx gradient at 2.1,-1.1: [ 0.02296511 -1.21586866]
Numerical approx gradient at 0,0: [-0.46826943 -0.83002476]
Numerical approx gradient at 0.1,0.1: [-0.44431941 -0.78719149]
Numerical approx gradient at -0.1,-0.1: [-0.46686583 -0.81231689]
Numerical approx gradient at 2.,-1.: [ 0.00328248 -1.22299187]
Numerical approx gradient at 2.1,-1.1: [ 0.02816145 -1.15391493]
Jax gradient at 0,0: [0. 0.]
Jax gradient at 0.1,0.1: [-0.45671462 -0.80902832]
Jax gradient at -0.1,-0.1: [-0.47918073 -0.83378842]
Jax gradient at 2.1,-1.1: [0. 0.]
Jax gradient at 2.1,-1.1: [ 0.01547956 -1.17478728]

BFGS optimisation

Warning: Desired error not necessarily achieved due to precision loss.
Current function value: 0.000708
Iterations: 4
Function evaluations: 324
Gradient evaluations: 78

Numerical approx gradient - loss at optim end [1.95190614 1.95647443]: 0.0007083630658263352
Optimization terminated successfully.
Current function value: 0.000000
Iterations: 9
Function evaluations: 10
Gradient evaluations: 10

Manual approx gradient - loss at optim end [2. 2.]: 5.814722077624332e-20
Optimization terminated successfully.
Current function value: 1.343940
Iterations: 0
Function evaluations: 1
Gradient evaluations: 1

Jax grad - loss at optim end [0. 0.]: 1.3439400676144586
```
shoyer commented 4 years ago

OK, I agree this looks pretty bad! We shouldn't have the gradient deviate at a single point.

shoyer commented 4 years ago

I bet the problem is the case where `lower == upper` in these lines: https://github.com/google/jax/blob/db71f3c5fc5226c3e9c87dd9f056d1b63cfa0286/jax/scipy/ndimage.py#L45-L53

Rather than computing `upper = jnp.ceil(coordinate)`, we should probably just set `upper = lower + 1`.
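The diagnosis above can be checked with a minimal 1D sketch. It contains two hypothetical, simplified re-implementations of linear interpolation, `interp_floor_ceil` and `interp_lower_plus_one`, modelled on the description in these two comments. This is not the actual jax source, and boundary handling is omitted. With floor/ceil, an integer coordinate gives `lower == upper`. The two weights' derivatives (-1 and +1) then multiply the same sample and cancel. With `upper = lower + 1`, the derivative is the slope to the right-hand neighbour.

```python
import jax
import jax.numpy as jnp

# Hypothetical 1D linear interpolation; boundary handling is omitted.

def interp_floor_ceil(values, coordinate):
    # Neighbours from floor/ceil: at an integer coordinate, lower == upper.
    lower = jnp.floor(coordinate)
    upper = jnp.ceil(coordinate)
    l_weight = 1.0 - (coordinate - lower)
    u_weight = 1.0 - l_weight  # d/dcoordinate = +1 even when upper == lower
    return (l_weight * values[lower.astype(jnp.int32)]
            + u_weight * values[upper.astype(jnp.int32)])

def interp_lower_plus_one(values, coordinate):
    # Neighbours are always two distinct samples: lower and lower + 1.
    lower = jnp.floor(coordinate)
    upper = lower + 1.0
    l_weight = upper - coordinate
    u_weight = coordinate - lower
    return (l_weight * values[lower.astype(jnp.int32)]
            + u_weight * values[upper.astype(jnp.int32)])

values = jnp.array([0.0, 1.0, 4.0, 9.0])
grad_ceil = jax.grad(interp_floor_ceil, argnums=1)
grad_plus_one = jax.grad(interp_lower_plus_one, argnums=1)

for c in (1.0, 1.5):
    print(c, grad_ceil(values, c), grad_plus_one(values, c))
# At c = 1.0 the floor/ceil version gives 0.0, while the lower + 1 version
# gives values[2] - values[1] = 3.0, i.e. the slope on the right-hand side.
# At c = 1.5 both versions give 3.0.
```

Both versions return the same interpolated values; only the derivative at integer coordinates differs. The `lower + 1` choice amounts to picking a one-sided (right) derivative at grid points, which matches the earlier suggestion of taking the value from one of the sides.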