JadBatmobile closed this issue 5 years ago.
I managed to get it working in my own segmentation network. I had to modify the `call()` method. Please note that I was feeding the inputs channels-first (this is my Keras configuration), so the transposes were not necessary and are therefore commented out. What was necessary to get the output shape to be `[None, c, h, w]` was reading the unaries and rgb tensors as `inputs[0]` and `inputs[1]` respectively, rather than `inputs[0][0, :, :, :]` and `inputs[1][0, :, :, :]`:
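To make the shape bookkeeping concrete, here is a small NumPy sketch (not the layer itself) that traces the original channels-last path and the channels-first variant. It assumes NumPy's `np.transpose(x, axes)` behaves like `tf.transpose(x, perm)` for these permutations; the sizes `h, w, c` are hypothetical values chosen only for illustration.

```python
import numpy as np

# Hypothetical sizes for illustration only.
h, w, c = 4, 5, 3
rng = np.random.default_rng(0)

# Original layer input: channels-last with a batch axis, shape (1, h, w, c).
unaries_nhwc = rng.random((1, h, w, c)).astype(np.float32)

# Original path: drop the batch axis, then move channels to the front.
# np.transpose(..., (2, 0, 1)) mirrors tf.transpose(..., perm=(2, 0, 1)).
unaries_chw = np.transpose(unaries_nhwc[0, :, :, :], (2, 0, 1))
print(unaries_chw.shape)  # (3, 4, 5)

# Original return: restore the batch axis, then go back to channels-last.
q_nhwc = np.transpose(np.reshape(unaries_chw, (1, c, h, w)), (0, 2, 3, 1))
print(q_nhwc.shape)  # (1, 4, 5, 3)

# Channels-first input read as inputs[0] directly keeps the batch axis.
unaries_nchw = rng.random((1, c, h, w)).astype(np.float32)
print(unaries_nchw.shape)     # (1, 3, 4, 5)
print(unaries_nchw[0].shape)  # (3, 4, 5), the layout the loop body reshapes to
```

The round trip in the original path is lossless (`q_nhwc` equals `unaries_nhwc`), which is why dropping both transposes is harmless when the data is already channels-first.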
```python
def call(self, inputs):
    # This is commented out from the original:
    # unaries = tf.transpose(inputs[0][0, :, :, :], perm=(2, 0, 1))
    # rgb = tf.transpose(inputs[1][0, :, :, :], perm=(2, 0, 1))
    # unaries = inputs[0][0, :, :, :]
    # rgb = inputs[1][0, :, :, :]
    unaries = inputs[0]
    rgb = inputs[1]
    # Input is channels first

    c, h, w = self.num_classes, self.image_dims[0], self.image_dims[1]
    all_ones = np.ones((c, h, w), dtype=np.float32)

    # Prepare filter normalization coefficients
    spatial_norm_vals = custom_module.high_dim_filter(all_ones, rgb, bilateral=False,
                                                      theta_gamma=self.theta_gamma)
    bilateral_norm_vals = custom_module.high_dim_filter(all_ones, rgb, bilateral=True,
                                                        theta_alpha=self.theta_alpha,
                                                        theta_beta=self.theta_beta)
    q_values = unaries

    for i in range(self.num_iterations):
        # Axis 0 is the class axis only when q_values is 3-D (c, h, w)
        softmax_out = tf.nn.softmax(q_values, 0)

        # Spatial filtering
        spatial_out = custom_module.high_dim_filter(softmax_out, rgb, bilateral=False,
                                                    theta_gamma=self.theta_gamma)
        spatial_out = spatial_out / spatial_norm_vals

        # Bilateral filtering
        bilateral_out = custom_module.high_dim_filter(softmax_out, rgb, bilateral=True,
                                                      theta_alpha=self.theta_alpha,
                                                      theta_beta=self.theta_beta)
        bilateral_out = bilateral_out / bilateral_norm_vals

        # Weighting filter outputs
        message_passing = (tf.matmul(self.spatial_ker_weights,
                                     tf.reshape(spatial_out, (c, -1))) +
                           tf.matmul(self.bilateral_ker_weights,
                                     tf.reshape(bilateral_out, (c, -1))))

        # Compatibility transform
        pairwise = tf.matmul(self.compatibility_matrix, message_passing)

        # Adding unary potentials
        pairwise = tf.reshape(pairwise, (c, h, w))
        q_values = unaries - pairwise

    # Output is channels first
    # This is commented out from the original:
    # return tf.transpose(tf.reshape(q_values, (1, c, h, w)), perm=(0, 2, 3, 1))
    # return tf.reshape(q_values, (1, c, h, w))
    return q_values
```
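For anyone adapting this, the loop body above can be sketched in plain NumPy to check shapes. This is only an illustration under stated assumptions: `placeholder_filter` is a hypothetical 3x3 box blur standing in for `custom_module.high_dim_filter` (it is not the permutohedral lattice filter the layer actually calls), and the kernel and compatibility weights are hypothetical identity-based matrices. It also shows why the softmax axis matters: if `q_values` keeps a leading batch axis of size 1, `softmax(..., axis=0)` normalizes over the batch and every entry becomes 1.

```python
import numpy as np

def softmax(x, axis):
    # Numerically stable softmax along `axis`, analogous to tf.nn.softmax(x, axis).
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def placeholder_filter(x):
    # HYPOTHETICAL stand-in for custom_module.high_dim_filter:
    # a per-channel 3x3 box blur over (h, w) with edge padding.
    _, h, w = x.shape
    padded = np.pad(x, ((0, 0), (1, 1), (1, 1)), mode="edge")
    out = np.zeros_like(x)
    for dy in range(3):
        for dx in range(3):
            out += padded[:, dy:dy + h, dx:dx + w]
    return out / 9.0

def mean_field_step(q_values, unaries, spatial_w, bilateral_w, compat):
    # Mirrors the structure of one loop iteration on (c, h, w) tensors.
    c, h, w = unaries.shape
    norm = placeholder_filter(np.ones_like(unaries))  # filter normalization
    softmax_out = softmax(q_values, 0)  # axis 0 is the class axis here
    spatial_out = placeholder_filter(softmax_out) / norm
    bilateral_out = placeholder_filter(softmax_out) / norm  # same stand-in
    message_passing = (spatial_w @ spatial_out.reshape(c, -1) +
                       bilateral_w @ bilateral_out.reshape(c, -1))
    pairwise = (compat @ message_passing).reshape(c, h, w)
    return unaries - pairwise

rng = np.random.default_rng(0)
c, h, w = 3, 4, 5
unaries = rng.standard_normal((c, h, w)).astype(np.float32)
eye = np.eye(c, dtype=np.float32)

q = unaries
for _ in range(5):
    q = mean_field_step(q, unaries, eye, eye, -eye)  # hypothetical weights
print(q.shape)  # (3, 4, 5)

# Pitfall: with a leading batch axis, axis 0 is the batch, not the classes.
batched = unaries[np.newaxis]  # shape (1, c, h, w)
print(np.allclose(softmax(batched, 0), 1.0))  # True
```

If the tensors reaching the loop are 4-D `(1, c, h, w)`, one option is to select the class axis explicitly (axis 1) or to squeeze the batch axis before the loop, so that the `(c, -1)` reshapes and the softmax see the layout they expect.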
Thanks @JadBatmobile for posting the answer.
I am attempting to use the crf-rnn layer in my own network and have encountered some problems, which are detailed in the following Stack Overflow question. If somebody can help with this, I would appreciate it immensely.