tensorflow / addons

Useful extra functionality for TensorFlow 2.x maintained by SIG-addons
Apache License 2.0

Gradients of the resampler do not match finite differences on an integer-pixel grid. #2535

Status: Open. andrevitorelli opened this issue 3 years ago.

andrevitorelli commented 3 years ago

System information

Describe the bug

The gradients of tensorflow_addons.image.resampler do not match finite-difference estimates from numdifftools when the warp lies on integer pixel coordinates, but they do match when the warp lies between pixels.
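A likely explanation (our assumption, not verified against the kernel source): bilinear interpolation is piecewise linear, with a kink at every integer coordinate. If the gradient kernel picks the interpolation cell with x0 = floor(x), then at an integer x it returns the one-sided slope f[x+1] - f[x], whereas a central finite difference with a step below one pixel returns the average of the left and right slopes. A minimal 1D NumPy sketch, where interp1d, interp1d_grad and central_diff are hypothetical helpers:

```python
import numpy as np

def interp1d(f, x):
    """Linear interpolation of samples f at position x, choosing the cell via floor(x)."""
    x0 = int(np.floor(x))
    t = x - x0
    return (1 - t) * f[x0] + t * f[x0 + 1]

def interp1d_grad(f, x):
    """Analytic derivative with the same floor-based cell choice: slope of [x0, x0 + 1]."""
    x0 = int(np.floor(x))
    return f[x0 + 1] - f[x0]

def central_diff(f, x, h=0.04):
    """Central finite difference of interp1d, with the step used in the issue."""
    return (interp1d(f, x + h) - interp1d(f, x - h)) / (2 * h)

f = np.array([0.0, 1.0, 5.0, 2.0, 3.0])
for x in (2.0, 2.5):  # integer vs half-pixel position
    print(x, interp1d_grad(f, x), central_diff(f, x))
```

Here the analytic slope is -3 at both positions. The central difference gives about 0.5 at x = 2.0 (the mean of the left slope 4 and the right slope -3) and about -3 at x = 2.5, where both x - h and x + h stay in the same cell.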

Code to reproduce the issue

If we run:

import numpy as np
from matplotlib import pyplot as plt
from tensorflow_addons.image import resampler
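try:
    from scipy.datasets import face  # SciPy >= 1.10 (requires pooch)
except ImportError:
    from scipy.misc import face  # older SciPy releases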
import numdifftools
import tensorflow as tf

#get an image
image = face(gray=True)[-512:-512+128,-512:-512+128].astype('float32')
image_tf = tf.convert_to_tensor(image.reshape([1,128,128, 1]))

#set a warp
warp = np.stack(np.meshgrid(np.arange(128), np.arange(128)), axis=-1).astype('float32')
warp_tf = tf.convert_to_tensor(warp.reshape([1,128,128,2]))

#define a shift
shift = tf.zeros([1,2])

#calculate derivatives via tf.GradientTape
with tf.GradientTape() as tape:
    tape.watch(shift)
    ws = tf.reshape(shift,[1,1,1,2]) + warp_tf
    o = resampler(image_tf, ws)
autodiff_jacobian = tape.batch_jacobian(o, shift) 

#calculate derivatives via numdifftools
def fn(shift):
    shift = tf.convert_to_tensor(shift.astype('float32'))
    ws = tf.reshape(shift,[1,1,1,2]) + warp_tf
    o = resampler(image_tf, ws)
    return o.numpy().flatten()

numdiff_jacobian = numdifftools.Jacobian(fn, order=4, step=0.04)
numdiff_jacobian = numdiff_jacobian(np.zeros([2])).reshape([128,128,2])

#display residuals
plt.figure(figsize=(15,5))
plt.subplot(121)
residual1 = abs(autodiff_jacobian[0,:,:,0,0] - numdiff_jacobian[:,:,0])
plt.imshow(residual1[2:-2,2:-2]) ; plt.colorbar()
plt.subplot(122)
residual2 = abs(autodiff_jacobian[0,:,:,0,1] - numdiff_jacobian[:,:,1])
plt.imshow(residual2[2:-2,2:-2]) ; plt.colorbar()

We see large residuals, of the same order as the pixel values:

[Figure: integer_pixels, residual maps for the integer-pixel warp]
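Under the same reading (the gradient kernel uses the cell [x0, x0 + 1] with x0 = floor(x), which is our assumption), the size of this residual can be written down. Let I_k be the pixel values along a row and f the interpolant. For any central-difference step h < 1 (numdifftools may combine several steps; we assume they all stay below one pixel), at an integer x:

```latex
\frac{f(x+h)-f(x-h)}{2h} = \frac{I_{x+1}-I_{x-1}}{2},
\qquad
\left| \left(I_{x+1}-I_{x}\right) - \frac{I_{x+1}-I_{x-1}}{2} \right|
= \frac{1}{2}\left| I_{x+1} - 2I_{x} + I_{x-1} \right|
```

The residual is therefore half the discrete second difference of the image (and likewise along y), which in textured regions can be comparable to the pixel values themselves.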

But if we offset the warp by half a pixel:

#set a warp
warp = np.stack(np.meshgrid(np.arange(128), np.arange(128)), axis=-1).astype('float32')
warp_tf = tf.convert_to_tensor(warp.reshape([1,128,128,2])+.5) #add a half-step 

#define a shift
shift = tf.zeros([1,2])

#calculate derivatives via tf.GradientTape
with tf.GradientTape() as tape:
    tape.watch(shift)
    ws = tf.reshape(shift,[1,1,1,2]) + warp_tf
    o = resampler(image_tf, ws)
autodiff_jacobian = tape.batch_jacobian(o, shift) 

#calculate derivatives via numdifftools
def fn(shift):
    shift = tf.convert_to_tensor(shift.astype('float32'))
    ws = tf.reshape(shift,[1,1,1,2]) + warp_tf
    o = resampler(image_tf, ws)
    return o.numpy().flatten()

numdiff_jacobian = numdifftools.Jacobian(fn, order=4, step=0.04)
numdiff_jacobian = numdiff_jacobian(np.zeros([2])).reshape([128,128,2])

#display residuals
plt.figure(figsize=(15,5))
plt.subplot(121)
residual1 = abs(autodiff_jacobian[0,:,:,0,0] - numdiff_jacobian[:,:,0])
plt.imshow(residual1[2:-2,2:-2]) ; plt.colorbar()
plt.subplot(122)
residual2 = abs(autodiff_jacobian[0,:,:,0,1] - numdiff_jacobian[:,:,1])
plt.imshow(residual2[2:-2,2:-2]) ; plt.colorbar()

The residuals are now about three orders of magnitude smaller:

[Figure: half_pixels, residual maps for the half-pixel warp]
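This fits the same explanation: on the half-pixel grid, x - h and x + h fall in the same cell whenever h < 0.5, so the central difference equals the cell slope exactly and only rounding error remains (float32 in the script above). The NumPy-only sketch below mirrors both experiments for the x-derivative. It uses a random 32 by 32 image instead of the face crop, and hypothetical helpers (bilinear, bilinear_dx, max_residual) that stand in for the addons kernel; we assume, without having checked the kernel source, that the kernel chooses cells the same way.

```python
import numpy as np

def bilinear(img, x, y):
    """Sample img (H, W) at float coords (x along columns, y along rows), floor-based cells."""
    x0 = np.floor(x).astype(int)
    y0 = np.floor(y).astype(int)
    tx, ty = x - x0, y - y0
    return ((1 - tx) * (1 - ty) * img[y0, x0] + tx * (1 - ty) * img[y0, x0 + 1]
            + (1 - tx) * ty * img[y0 + 1, x0] + tx * ty * img[y0 + 1, x0 + 1])

def bilinear_dx(img, x, y):
    """Analytic d/dx of bilinear() using the same floor-based cell choice."""
    x0 = np.floor(x).astype(int)
    y0 = np.floor(y).astype(int)
    ty = y - y0
    return ((1 - ty) * (img[y0, x0 + 1] - img[y0, x0])
            + ty * (img[y0 + 1, x0 + 1] - img[y0 + 1, x0]))

def max_residual(img, offset, h=0.04):
    """Max |analytic - central difference| over an interior grid shifted by `offset`."""
    # Stay 2 pixels from the border, like the [2:-2, 2:-2] crop in the script.
    ys, xs = np.meshgrid(np.arange(2, img.shape[0] - 2),
                         np.arange(2, img.shape[1] - 2), indexing="ij")
    x, y = xs + offset, ys + offset
    fd = (bilinear(img, x + h, y) - bilinear(img, x - h, y)) / (2 * h)
    return np.max(np.abs(bilinear_dx(img, x, y) - fd))

rng = np.random.default_rng(0)
img = rng.uniform(0, 255, size=(32, 32))
print("integer grid:", max_residual(img, 0.0))
print("half-pixel grid:", max_residual(img, 0.5))
```

On the integer grid the maximum residual equals the largest value of half the second difference along x; on the half-pixel grid it drops to float64 rounding level.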

Other info / logs

We (@EiffL, @DR-Zero, and I) are currently re-implementing the resampler kernels to provide better interpolators for our project autometacal and its dependency GalFlow, so these kernels may change. Any insights here would be helpful. Thank you.
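One option for such an interpolator (our illustration, not necessarily what autometacal or GalFlow will use) is a kernel with a continuous first derivative, such as Keys' cubic convolution kernel with a = -0.5. Because the interpolant has no kink at integer coordinates, the analytic derivative and a central difference agree there up to an error that vanishes as the step shrinks. A 1D NumPy sketch, where keys_kernel, keys_kernel_deriv, cubic_interp and cubic_interp_grad are hypothetical names:

```python
import numpy as np

A = -0.5  # Keys' free parameter; -0.5 is the common choice

def keys_kernel(s, a=A):
    """Keys cubic convolution kernel W(s), supported on |s| < 2."""
    s = np.abs(s)
    return np.where(s <= 1, (a + 2) * s**3 - (a + 3) * s**2 + 1,
                    np.where(s < 2, a * s**3 - 5 * a * s**2 + 8 * a * s - 4 * a, 0.0))

def keys_kernel_deriv(s, a=A):
    """dW/ds; continuous everywhere, including s = 0, +-1 and +-2."""
    sign = np.sign(s)
    s = np.abs(s)
    d = np.where(s <= 1, 3 * (a + 2) * s**2 - 2 * (a + 3) * s,
                 np.where(s < 2, 3 * a * s**2 - 10 * a * s + 8 * a, 0.0))
    return sign * d

def cubic_interp(f, x):
    """g(x) = sum_k f[k] W(x - k) over the 4 samples around x."""
    k = np.arange(int(np.floor(x)) - 1, int(np.floor(x)) + 3)
    return np.sum(f[k] * keys_kernel(x - k))

def cubic_interp_grad(f, x):
    """Analytic g'(x) = sum_k f[k] W'(x - k)."""
    k = np.arange(int(np.floor(x)) - 1, int(np.floor(x)) + 3)
    return np.sum(f[k] * keys_kernel_deriv(x - k))

f = np.array([0.0, 1.0, 5.0, 2.0, 3.0, 4.0])
h = 1e-5
print(cubic_interp_grad(f, 2.0))
print((cubic_interp(f, 2.0 + h) - cubic_interp(f, 2.0 - h)) / (2 * h))
```

For these samples the analytic derivative at x = 2 is exactly (f[3] - f[1]) / 2 = 0.5, and the central difference approaches it as h shrinks. With bilinear interpolation, by contrast, the mismatch at integer points stays at half the second difference for every step below one pixel.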

bhack commented 3 years ago

I don't think the code owner of this is still available (/cc @autoih). If you want to send a PR, you are welcome to.

andrevitorelli commented 3 years ago

We will try to do so asap, thanks!