I think there might be a problem with the function _phase_shift(I, r). My code was crashing when an image was resized from 1x1 to 2x2. It is probably caused by some kind of dimension mismatch.
I changed it to:
    def _phase_shift(I, r):
        # Helper function with main phase shift operation
        bsize, a, b, c = I.get_shape().as_list()
        X = tf.reshape(I, (bsize, a, b, r, r))
        X = tf.transpose(X, (0, 1, 2, 4, 3))  # bsize, a, b, 1, 1
        X = tf.split(1, a, X)  # a, [bsize, b, r, r]
        X = tf.concat(2, [tf.squeeze(x, squeeze_dims=1) for x in X])  # bsize, b, a*r, r
        X = tf.split(1, b, X)  # b, [bsize, a*r, r]
        X = tf.concat(2, [tf.squeeze(x, squeeze_dims=1) for x in X])  # bsize, a*r, b*r
        return tf.reshape(X, (bsize, a*r, b*r, 1))
It's just adding squeeze_dims=1 to the two tf.squeeze calls. Now it seems to work with batch size = 1.
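A likely explanation for the crash, as far as I can tell: a squeeze with no axis given removes every size-1 dimension, so with batch size 1 (or a 1x1 image) it also drops the batch and spatial axes, not just the one that was split. Below is a minimal sketch of that effect. It uses NumPy rather than TensorFlow, on the assumption that np.squeeze treats size-1 axes the same way tf.squeeze does; the input sizes are made up for illustration.

```python
import numpy as np

# Hypothetical input: batch size 1, a 1x1 image, upscaling factor r = 2.
bsize, a, b, r = 1, 1, 1, 2
X = np.arange(bsize * a * b * r * r).reshape(bsize, a, b, r, r)

# Split along axis 1 into `a` pieces, each of shape (bsize, 1, b, r, r).
pieces = np.split(X, a, axis=1)

# Squeezing every size-1 axis also drops the batch and width axes here.
squeezed_all = [np.squeeze(p) for p in pieces]
# Squeezing only axis 1 removes the split axis and keeps the others.
squeezed_axis = [np.squeeze(p, axis=1) for p in pieces]

print(squeezed_all[0].shape)   # (2, 2)
print(squeezed_axis[0].shape)  # (1, 1, 2, 2)
```

With the axis given, each piece keeps its batch dimension, so the following concat along axis 2 sees the rank it expects.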