kevinzakka / spatial-transformer-network

A Tensorflow implementation of Spatial Transformer Networks.
MIT License

Regarding get_pixel_value in STN method #18

Closed: Nd-sole closed this issue 6 years ago

Nd-sole commented 6 years ago

Hi @kevinzakka and @robotrory, can you please explain this function?

```python
def get_pixel_value(img, x, y):
    """
    Utility function to get pixel value for coordinate
    vectors x and y from a 4D tensor image.

    Input
    -----
    - img: tensor of shape (B, H, W, C)
    - x: tensor of shape (B, H, W)
    - y: tensor of shape (B, H, W)

    Returns
    -------
    - output: tensor of shape (B, H, W, C)
    """
    shape = tf.shape(x)
    batch_size = shape[0]
    height = shape[1]
    width = shape[2]

    batch_idx = tf.range(0, batch_size)
    batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1))
    b = tf.tile(batch_idx, (1, height, width))

    indices = tf.stack([b, y, x], 3)

    return tf.gather_nd(img, indices)
```

Why do you write `indices = tf.stack([b, y, x], 3)` with the coordinates in reverse order? Shouldn't it be `indices = tf.stack([b, x, y], 3)`? I am extending your method to 3D and was wondering about the reason for the reversed order. Would `indices = tf.stack([b, z, y, x], 3)` be correct then?

kevinzakka commented 6 years ago

@nainadhingra2012 the reason is that a numpy array is indexed first by rows, then by columns. Rows in an image correspond to the height, i.e. the y values, and columns correspond to the width, i.e. the x values.

So to answer your question, it's not in reverse order. Your 3D indices would be `indices = tf.stack([b, y, x, z], 3)`, assuming your image is in NHWC format.
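
To make the row-major convention concrete, here is a tiny sketch (the array values are made up for illustration):

```python
import numpy as np
import tensorflow as tf

# one 2x3 single-channel image: shape (B, H, W, C) = (1, 2, 3, 1)
img = tf.constant(np.arange(6).reshape(1, 2, 3, 1), dtype=tf.float32)

# fetch the pixel at column x=2, row y=1: the gather_nd index is (b, y, x),
# i.e. row before column, so this picks img[0, 1, 2, 0] == 5
indices = tf.constant([[0, 1, 2]])
value = tf.gather_nd(img, indices)  # -> [[5.]]
```

The same logic carries over to higher dimensions: the index tuple simply mirrors the order of the tensor's axes.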

Nd-sole commented 6 years ago

Hi @kevinzakka and @robotrory, is it possible to check this extended STN for 3D? It is not working as expected for the identity transform. I think there is a problem in the interpolation, but I just extended it the way it was done in 2D. Could you please look at it once?

```python
import tensorflow as tf
from utils.print_utils import *
import numpy as np


def spatial_transformer_network(input_fmap, theta, out_dims=None, **kwargs):
    """
    Spatial Transformer Network layer implementation as described in [1],
    extended to 3D volumes.

    The layer is composed of 3 elements:

    - localisation_net: takes the original image as input and outputs
      the parameters of the affine transformation that should be applied
      to the input image.

    - affine_grid_generator: generates a grid of (x, y, z) coordinates that
      correspond to a set of points where the input should be sampled
      to produce the transformed output.

    - bilinear_sampler: takes as input the original image and the grid
      and produces the output image using trilinear interpolation.

    Input
    -----
    - input_fmap: output of the previous layer. Can be the input if the
      spatial transformer layer is at the beginning of the architecture.
      Should be a tensor of shape (B, H, W, L, C).

    - theta: affine transform tensor of shape (B, 12). Permits cropping,
      translation and isotropic scaling. Initialize to the identity
      transform. It is the output of the localization network.

    Returns
    -------
    - out_fmap: transformed input feature map. Tensor of size (B, H, W, L, C).

    Notes
    -----
    [1]: 'Spatial Transformer Networks', Jaderberg et. al,
         (https://arxiv.org/abs/1506.02025)
    """
    # grab input dimensions
    B = tf.shape(input_fmap)[0]
    H = tf.shape(input_fmap)[1]
    W = tf.shape(input_fmap)[2]
    L = tf.shape(input_fmap)[3]

    # reshape theta to (B, 3, 4)
    theta = tf.reshape(theta, [B, 3, 4])

    # generate grids of same size or upsample/downsample if specified
    if out_dims:
        out_H = out_dims[0]
        out_W = out_dims[1]
        out_L = out_dims[2]
        (batch_grids, batch_grids_I) = affine_grid_generator(out_H, out_W, out_L, theta)
    else:
        (batch_grids, batch_grids_I) = affine_grid_generator(H, W, L, theta)

    x_s = batch_grids[:, 0, :, :, :]
    y_s = batch_grids[:, 1, :, :, :]
    z_s = batch_grids[:, 2, :, :, :]

    # sample input with grid to get output
    out_fmap = bilinear_sampler(input_fmap, x_s, y_s, z_s)
    return out_fmap

def get_pixel_value(img, x, y, z):
    """
    Utility function to get pixel value for coordinate
    vectors x, y and z from a 5D tensor image.

    Input
    -----
    - img: tensor of shape (B, H, W, L, C)
    - x: tensor of shape (B, H, W, L)
    - y: tensor of shape (B, H, W, L)
    - z: tensor of shape (B, H, W, L)

    Returns
    -------
    - output: tensor of shape (B, H, W, L, C)
    """
    shape = tf.shape(x)
    batch_size = shape[0]
    height = shape[1]
    width = shape[2]
    length = shape[3]

    batch_idx = tf.range(0, batch_size)
    batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1, 1))
    b = tf.tile(batch_idx, (1, height, width, length))

    # index order mirrors the tensor axes (B, H, W, L):
    # batch, row (y), column (x), depth (z)
    indices = tf.stack([b, y, x, z], 4)

    return tf.gather_nd(img, indices)

def affine_grid_generator(height, width, length, theta):
    """
    This function returns a sampling grid, which when used with the
    trilinear sampler on the input feature map, will create an output
    feature map that is an affine transformation [1] of the input
    feature map.

    Input
    -----
    - height: desired height of grid/output. Used
      to downsample or upsample.

    - width: desired width of grid/output. Used
      to downsample or upsample.

    - length: desired length (depth) of grid/output. Used
      to downsample or upsample.

    - theta: affine transform matrices of shape (num_batch, 3, 4).
      For each volume in the batch, we have 12 theta parameters of
      the form (3x4) that define the affine transformation T.

    Returns
    -------
    - normalized grid (-1, 1) of shape (num_batch, 3, H, W, L).
      The 2nd dimension has 3 components: (x, y, z) which are the
      sampling points of the original volume for each point in the
      target volume.

    Note
    ----
    [1]: the affine transformation allows cropping, translation,
         and isotropic scaling.
    """
    # grab batch size
    num_batch = tf.shape(theta)[0]

    # create normalized 3D grid
    x = tf.linspace(-1.0, 1.0, width)
    y = tf.linspace(-1.0, 1.0, height)
    z = tf.linspace(-1.0, 1.0, length)
    x_t, y_t, z_t = tf.meshgrid(x, y, z)

    # flatten
    x_t_flat = tf.reshape(x_t, [-1])
    y_t_flat = tf.reshape(y_t, [-1])
    z_t_flat = tf.reshape(z_t, [-1])

    # reshape to [x_t, y_t, z_t, 1] - (homogeneous form)
    ones = tf.ones_like(x_t_flat)
    sampling_grid = tf.stack([x_t_flat, y_t_flat, z_t_flat, ones])

    # repeat grid num_batch times
    sampling_grid = tf.expand_dims(sampling_grid, axis=0)
    sampling_grid = tf.tile(sampling_grid, tf.stack([num_batch, 1, 1]))

    # cast to float32 (required for matmul)
    theta = tf.cast(theta, 'float32')
    sampling_grid = tf.cast(sampling_grid, 'float32')

    # transform the sampling grid - batch multiply
    batch_grids = tf.matmul(theta, sampling_grid)

    # also build an identity transform for comparison (new added)
    theta_I = tf.constant([[1.0, 0.0, 0.0, 0.0],
                           [0.0, 1.0, 0.0, 0.0],
                           [0.0, 0.0, 1.0, 0.0]])
    theta_I = tf.expand_dims(theta_I, axis=0)
    theta_I_batch = tf.tile(theta_I, [num_batch, 1, 1])
    batch_grids_I = tf.matmul(theta_I_batch, sampling_grid)

    # batch grid has shape (num_batch, 3, H*W*L)
    # reshape to (num_batch, 3, H, W, L)
    batch_grids = tf.reshape(batch_grids, [num_batch, 3, height, width, length])
    batch_grids_I = tf.reshape(batch_grids_I, [num_batch, 3, height, width, length])
    return (batch_grids, batch_grids_I)

def bilinear_sampler(img, x, y, z):
    """
    Performs trilinear sampling of the input volumes according to the
    normalized coordinates provided by the sampling grid. Note that
    the sampling is done identically for each channel of the input.

    To test if the function works properly, the output image should be
    identical to the input image when theta is initialized to the
    identity transform.

    Input
    -----
    - img: batch of volumes in (B, H, W, L, C) layout.
    - x, y, z: the output of affine_grid_generator.

    Returns
    -------
    - interpolated volumes according to grids. Same size as grid.
    """
    B = tf.shape(img)[0]
    H = tf.shape(img)[1]
    W = tf.shape(img)[2]
    L = tf.shape(img)[3]
    C = tf.shape(img)[4]

    max_y = tf.cast(H - 1, 'int32')
    max_x = tf.cast(W - 1, 'int32')
    max_z = tf.cast(L - 1, 'int32')
    zero = tf.zeros([], dtype='int32')

    # cast indices as float32 (for rescaling)
    x = tf.cast(x, 'float32')
    y = tf.cast(y, 'float32')
    z = tf.cast(z, 'float32')

    # rescale x, y and z from [-1, 1] to [0, W-1], [0, H-1], [0, L-1]
    x = 0.5 * ((x + 1.0) * tf.cast(max_x, 'float32'))
    y = 0.5 * ((y + 1.0) * tf.cast(max_y, 'float32'))
    z = 0.5 * ((z + 1.0) * tf.cast(max_z, 'float32'))

    # grab the 8 nearest corner points for each (x_i, y_i, z_i),
    # i.e. we need a cube around the point of interest
    x0 = tf.cast(tf.floor(x), 'int32')
    x1 = x0 + 1
    y0 = tf.cast(tf.floor(y), 'int32')
    y1 = y0 + 1
    z0 = tf.cast(tf.floor(z), 'int32')
    z1 = z0 + 1

    # clip to the valid index range to not violate img boundaries
    x0 = tf.clip_by_value(x0, zero, max_x)
    x1 = tf.clip_by_value(x1, zero, max_x)
    y0 = tf.clip_by_value(y0, zero, max_y)
    y1 = tf.clip_by_value(y1, zero, max_y)
    z0 = tf.clip_by_value(z0, zero, max_z)
    z1 = tf.clip_by_value(z1, zero, max_z)

    # get voxel value at corner coords
    Ia0 = get_pixel_value(img, x0, y0, z0)
    Ia1 = get_pixel_value(img, x0, y0, z1)
    Ib0 = get_pixel_value(img, x0, y1, z0)
    Ib1 = get_pixel_value(img, x0, y1, z1)
    Ic0 = get_pixel_value(img, x1, y0, z0)
    Ic1 = get_pixel_value(img, x1, y0, z1)
    Id0 = get_pixel_value(img, x1, y1, z0)
    Id1 = get_pixel_value(img, x1, y1, z1)

    # recast as float for delta calculation
    x0 = tf.cast(x0, 'float32')
    x1 = tf.cast(x1, 'float32')
    y0 = tf.cast(y0, 'float32')
    y1 = tf.cast(y1, 'float32')
    z0 = tf.cast(z0, 'float32')
    z1 = tf.cast(z1, 'float32')

    # calculate trilinear weights: each corner is weighted by the volume
    # of the sub-cuboid diagonally opposite to it
    wa0 = (x1 - x) * (y1 - y) * (z - z0)
    wa1 = (x1 - x) * (y1 - y) * (z1 - z)
    wb0 = (x1 - x) * (y - y0) * (z - z0)
    wb1 = (x1 - x) * (y - y0) * (z1 - z)
    wc0 = (x - x0) * (y1 - y) * (z - z0)
    wc1 = (x - x0) * (y1 - y) * (z1 - z)
    wd0 = (x - x0) * (y - y0) * (z - z0)
    wd1 = (x - x0) * (y - y0) * (z1 - z)

    # add dimension to broadcast against the channel axis
    wa0 = tf.expand_dims(wa0, axis=4)
    wa1 = tf.expand_dims(wa1, axis=4)
    wb0 = tf.expand_dims(wb0, axis=4)
    wb1 = tf.expand_dims(wb1, axis=4)
    wc0 = tf.expand_dims(wc0, axis=4)
    wc1 = tf.expand_dims(wc1, axis=4)
    wd0 = tf.expand_dims(wd0, axis=4)
    wd1 = tf.expand_dims(wd1, axis=4)

    # compute output: w*0 carries the (z - z0) factor and so weights the
    # z1 corner (I*1); w*1 carries (z1 - z) and weights the z0 corner (I*0)
    out = tf.add_n([wa0 * Ia1, wa1 * Ia0,
                    wb0 * Ib1, wb1 * Ib0,
                    wc0 * Ic1, wc1 * Ic0,
                    wd0 * Id1, wd1 * Id0])

    return out
```
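
For reference, a minimal identity-transform check for the code above might look like this (a sketch assuming TF 1.x `tf.placeholder`/`tf.Session` and an arbitrary 8x8x8 input; border voxels are excluded since the corner clipping can zero them out):

```python
import numpy as np
import tensorflow as tf

B, H, W, L, C = 2, 8, 8, 8, 1
input_fmap = tf.placeholder(tf.float32, [B, H, W, L, C])

# 12 affine parameters per volume; the 3D identity is [[1,0,0,0],[0,1,0,0],[0,0,1,0]]
identity = np.array([[1, 0, 0, 0],
                     [0, 1, 0, 0],
                     [0, 0, 1, 0]], dtype=np.float32)
theta = tf.constant(np.tile(identity.flatten(), (B, 1)))

out_fmap = spatial_transformer_network(input_fmap, theta)

with tf.Session() as sess:
    vol = np.random.rand(B, H, W, L, C).astype(np.float32)
    out = sess.run(out_fmap, feed_dict={input_fmap: vol})
    # interior voxels should be numerically unchanged under the identity transform
    print(np.abs(out - vol)[:, 1:-1, 1:-1, 1:-1, :].max())
```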
HuntleyZ commented 5 years ago

Hi @nainadhingra2012, I've completed the 3D version of the STN based on this project, please see my repository for more info :)