Closed Nd-sole closed 6 years ago
@nainadhingra2012 the reason is that a numpy array is first indexed by rows then columns. Rows in an image correspond to the height which are the y
values, and columns correspond to the width which are the x
values.
So to answer your question, it's not in reverse order. Your 3D indices would be indices = tf.stack([b, y, x, z], 3)
assuming your image is in NHWC
format.
Hi @kevinzakka and @robotrory , Is it possible to check this extended stn for 3D. It is not working as expected for identity transform. I think there is possible problem in interpolation but I just extended as was given for 2D. Could you please look at it once? import tensorflow as tf from utils.print_utils import * import numpy as np
def spatial_transformer_network(input_fmap,theta, out_dims=None, **kwargs): """ Spatial Transformer Network layer implementation as described in [1].
The layer is composed of 3 elements:
- localisation_net: takes the original image as input and outputs
the parameters of the affine transformation that should be applied
to the input image.
- affine_grid_generator: generates a grid of (x,y) coordinates that
correspond to a set of points where the input should be sampled
to produce the transformed output.
- bilinear_sampler: takes as input the original image and the grid
and produces the output image using bilinear interpolation.
Input
-----
- input_fmap: output of the previous layer. Can be input if spatial
transformer layer is at the beginning of architecture. Should be
a tensor of shape (B, H, W, C).
- theta: affine transform tensor of shape (B, 6). Permits cropping,
translation and isotropic scaling. Initialize to identity matrix.
It is the output of the localization network.
Returns
-------
- out_fmap: transformed input feature map. Tensor of size (B, H, W, C).
Notes
-----
[1]: 'Spatial Transformer Networks', Jaderberg et. al,
(https://arxiv.org/abs/1506.02025)
# generate grids of same size or upsample/downsample if specified
if out_dims:
out_H = out_dims[0]
out_W = out_dims[1]
out_L = out_dims[2]
(batch_grids,batch_grids_I) = affine_grid_generator(out_H, out_W,out_L,theta)
else:
(batch_grids,batch_grids_I) = affine_grid_generator(H, W,L, theta)
x_s = batch_grids[:, 0, :, :]
y_s = batch_grids[:, 1, :, :]
z_s = batch_grids[:, 2, :, :]
# sample input with grid to get output
out_fmap= bilinear_sampler(input_fmap, x_s, y_s,z_s)
return out_fmap
def get_pixel_value(img, x, y,z): """ Utility function to get pixel value for coordinate vectors x and y from a 4D tensor image.
Input
-----
- img: tensor of shape (B, H, W, C)
- x: flattened tensor of shape (B*H*W, )
- y: flattened tensor of shape (B*H*W, )
Returns
-------
- output: tensor of shape (B, H, W, C)
"""
shape = tf.shape(x)
batch_size = shape[0]
height = shape[1]
width = shape[2]
length = shape[3]
batch_idx = tf.range(0, batch_size)
batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1,1))
b = tf.tile(batch_idx, (1, height, width,length))
print("b",b.shape)
indices = tf.stack([b, y, x,z], 4)
return tf.gather_nd(img, indices)
def affine_grid_generator(height, width,length, theta): """ This function returns a sampling grid, which when used with the bilinear sampler on the input feature map, will create an output feature map that is an affine transformation [1] of the input feature map.
Input
-----
- height: desired height of grid/output. Used
to downsample or upsample.
- width: desired width of grid/output. Used
to downsample or upsample.
- theta: affine transform matrices of shape (num_batch, 2, 3).
For each image in the batch, we have 6 theta parameters of
the form (2x3) that define the affine transformation T.
Returns
-------
- normalized gird (-1, 1) of shape (num_batch, 2, H, W).
The 2nd dimension has 2 components: (x, y) which are the
sampling points of the original image for each point in the
target image.
Note
----
[1]: the affine transformation allows cropping, translation,
and isotropic scaling.
"""
# grab batch size
num_batch = tf.shape(theta)[0]
print("num_batch",num_batch)
# create normalized 2D grid
x = tf.linspace(-1.0, 1.0, width)
y = tf.linspace(-1.0, 1.0, height)
z = tf.linspace(-1.0, 1.0, length)
x_t, y_t,z_t = tf.meshgrid(x, y, z)
# flatten
x_t_flat = tf.reshape(x_t, [-1])
y_t_flat = tf.reshape(y_t, [-1])
z_t_flat = tf.reshape(z_t, [-1])
# reshape to [x_t, y_t , 1] - (homogeneous form)
ones = tf.ones_like(x_t_flat)
sampling_grid = tf.stack([x_t_flat, y_t_flat,z_t_flat, ones])
print("sampling_grid",sampling_grid.shape)
# repeat grid num_batch times
sampling_grid = tf.expand_dims(sampling_grid, axis=0)
print("sampling_grid",sampling_grid.shape)
sampling_grid = tf.tile(sampling_grid, tf.stack([num_batch, 1, 1]))
# cast to float32 (required for matmul)
theta = tf.cast(theta, 'float32')
sampling_grid = tf.cast(sampling_grid, 'float32')
# transform the sampling grid - batch multiply
batch_grids = tf.matmul(theta, sampling_grid)
# generating identity transform
theta_I = tf.constant([[1.0,0.0,0.0,0.0], [0.0, 1.0,0.0,0.0], [0.0, 0.0,1.0,0.0]])
theta_I=tf.expand_dims(theta_I,axis=0)
theta_I_batch=tf.tile(theta_I, [num_batch, 1, 1])
batch_grids_I=tf.matmul(theta_I_batch, sampling_grid) #new added
# batch grid has shape (num_batch, 3, H*W*L)
# reshape to (num_batch, H, W,3)
batch_grids = tf.reshape(batch_grids, [num_batch, 3, height, width,length])
batch_grids_I= tf.reshape(batch_grids_I, [num_batch, 3, height, width,length])
return (batch_grids,batch_grids_I)
def bilinear_sampler(img, x, y,z): """ Performs bilinear sampling of the input images according to the normalized coordinates provided by the sampling grid. Note that the sampling is done identically for each channel of the input.
To test if the function works properly, output image should be
identical to input image when theta is initialized to identity
transform.
Input
-----
- img: batch of images in (B, H, W, C) layout.
- grid: x, y which is the output of affine_grid_generator.
Returns
-------
- interpolated images according to grids. Same size as grid.
"""
B = tf.shape(img)[0]
H = tf.shape(img)[1]
W = tf.shape(img)[2]
L = tf.shape(img)[3]
C = tf.shape(img)[4]
max_y = tf.cast(H - 1, 'int32')
max_x = tf.cast(W - 1, 'int32')
max_z = tf.cast(L - 1, 'int32')
zero = tf.zeros([], dtype='int32')
# cast indices as float32 (for rescaling)
x = tf.cast(x, 'float32')
y = tf.cast(y, 'float32')
z = tf.cast(y, 'float32')
# rescale x and y to [0, W/H]
x = 0.5 * ((x + 1.0) * tf.cast(max_x , 'float32'))
y = 0.5 * ((y + 1.0) * tf.cast(max_y, 'float32'))
z = 0.5 * ((z + 1.0) * tf.cast(max_z, 'float32'))
# grab 4 nearest corner points for each (x_i, y_i)
# i.e. we need a rectangle around the point of interest
x0 = tf.cast(tf.floor(x), 'int32')
x1 = x0 + 1
y0 = tf.cast(tf.floor(y), 'int32')
y1 = y0 + 1
z0 = tf.cast(tf.floor(z), 'int32')
z1 = z0 + 1
# clip to range [0, H/W] to not violate img boundaries
x0 = tf.clip_by_value(x0, zero, max_x)
print("x0",x0.shape)
x1 = tf.clip_by_value(x1, zero, max_x)
y0 = tf.clip_by_value(y0, zero, max_y)
y1 = tf.clip_by_value(y1, zero, max_y)
z0 = tf.clip_by_value(z0, zero, max_z)
z1 = tf.clip_by_value(z1, zero, max_z)
# get pixel value at corner coords
Ia0 = get_pixel_value(img, x0, y0,z0)
Ia1 = get_pixel_value(img, x0, y0,z1)
Ib0 = get_pixel_value(img, x0, y1,z0)
Ib1 = get_pixel_value(img, x0, y1,z1)
Ic0 = get_pixel_value(img, x1, y0,z0)
Ic1 = get_pixel_value(img, x1, y0,z1)
Id0 = get_pixel_value(img, x1, y1,z0)
Id1 = get_pixel_value(img, x1, y1,z1)
# recast as float for delta calculation
x0 = tf.cast(x0, 'float32')
x1 = tf.cast(x1, 'float32')
y0 = tf.cast(y0, 'float32')
y1 = tf.cast(y1, 'float32')
z0 = tf.cast(z0, 'float32')
z1 = tf.cast(z1, 'float32')
# calculate deltas
wa0 = (x1-x) * (y1-y) * (z-z0)
wa1 = (x1-x) * (y1-y) * (z1-z)
wb0 = (x1-x) * (y-y0) * (z-z0)
wb1 = (x1-x) * (y-y0) * (z1-z)
wc0 = (x-x0) * (y1-y) * (z-z0)
wc1 = (x-x0) * (y1-y) * (z1-z)
wd0 = (x-x0) * (y-y0) * (z-z0)
wd1 = (x-x0) * (y-y0) * (z1-z)
# add dimension for addition
wa0 = tf.expand_dims(wa0, axis=4)
wa1 = tf.expand_dims(wa1, axis=4)
wb0 = tf.expand_dims(wb0, axis=4)
wb1 = tf.expand_dims(wb1, axis=4)
wc0 = tf.expand_dims(wc0, axis=4)
wc1 = tf.expand_dims(wc1, axis=4)
wd0 = tf.expand_dims(wd0, axis=4)
wd1 = tf.expand_dims(wd1, axis=4)
# compute output
out = tf.add_n([wa0*Ia1, wa1*Ia0, wb0*Ib1, wb1*Ib0, wc0*Ic1, wc1*Ic0, wd0*Id1, wd1*Id0])
return out
Hi, @nainadhingra2012 , I've complete the 3D version STN based on this project, pls see my this repository for more info :)
Hi @kevinzakka and @robotrory, Can you please tell me, def get_pixel_value(img, x, y): """ Utility function to get pixel value for coordinate vectors x and y from a 4D tensor image. Input
why you write indices = tf.stack([b, y, x], 3) in reverse order, shouldn't it be indices = tf.stack([b,x, y], 3). I am extending your method for 3D and I was wondering the reason of placing in reverse order. indices = tf.stack([b, z, y, x], 3), is this correct then?