sadeepj / crfasrnn_keras

CRF-RNN Keras/Tensorflow version
http://crfasrnn.torr.vision
MIT License

Using crfrnn_layer in a different network #61

Closed by JadBatmobile 5 years ago

JadBatmobile commented 5 years ago

I am attempting to use the CRF-RNN layer in my own network. I have run into some problems, which are detailed in the following Stack Overflow question. If somebody can help with this, I would appreciate it immensely.

https://stackoverflow.com/questions/55582630/dimenstions-of-custom-layer-output-keras

JadBatmobile commented 5 years ago

I managed to get it working in my own segmentation network by modifying the call() function. Note that I was feeding the inputs as channels first (this is my Keras configuration), so transposing was not necessary and is therefore commented out. To get the output dimension to be [None, c, h, w], I had to read the unaries and rgb tensors as inputs[0] and inputs[1] (respectively), rather than inputs[0][0, :, :, :] and inputs[1][0, :, :, :].
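The indexing difference can be illustrated with plain NumPy arrays. This is only a sketch: the sizes are made up, and NumPy arrays stand in for the Keras tensors, but the shape behaviour of `[0, :, :, :]` and of the `(2, 0, 1)` transpose is the same.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
batch, c, h, w = 1, 3, 4, 5

# A channels-first batch, shaped like the unaries tensor passed to call().
unaries_batch = np.zeros((batch, c, h, w), dtype=np.float32)

# Original-style indexing: selecting element 0 drops the batch axis.
unaries_single = unaries_batch[0, :, :, :]
print(unaries_single.shape)  # (3, 4, 5)

# Using the tensor as-is keeps the batch axis.
print(unaries_batch.shape)  # (1, 3, 4, 5)

# Channels-last data would still need the transpose from the original code.
channels_last = np.zeros((h, w, c), dtype=np.float32)
channels_first = np.transpose(channels_last, (2, 0, 1))
print(channels_first.shape)  # (3, 4, 5)
```

The modified call() follows.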

```python
# Method of the CRF-RNN layer class. np, tf and custom_module are the
# names used in the original layer file.
def call(self, inputs):
    # Inputs are channels first.

    # This is commented out from the original:
    # unaries = tf.transpose(inputs[0][0, :, :, :], perm=(2, 0, 1))
    # rgb = tf.transpose(inputs[1][0, :, :, :], perm=(2, 0, 1))
    # unaries = inputs[0][0, :, :, :]
    # rgb = inputs[1][0, :, :, :]

    unaries = inputs[0]
    rgb = inputs[1]

    # Input is channels first.
    c, h, w = self.num_classes, self.image_dims[0], self.image_dims[1]

    all_ones = np.ones((c, h, w), dtype=np.float32)

    # Prepare filter normalization coefficients
    spatial_norm_vals = custom_module.high_dim_filter(all_ones, rgb, bilateral=False,
                                                      theta_gamma=self.theta_gamma)
    bilateral_norm_vals = custom_module.high_dim_filter(all_ones, rgb, bilateral=True,
                                                        theta_alpha=self.theta_alpha,
                                                        theta_beta=self.theta_beta)
    q_values = unaries

    for i in range(self.num_iterations):
        softmax_out = tf.nn.softmax(q_values, 0)

        # Spatial filtering
        spatial_out = custom_module.high_dim_filter(softmax_out, rgb, bilateral=False,
                                                    theta_gamma=self.theta_gamma)
        spatial_out = spatial_out / spatial_norm_vals

        # Bilateral filtering
        bilateral_out = custom_module.high_dim_filter(softmax_out, rgb, bilateral=True,
                                                      theta_alpha=self.theta_alpha,
                                                      theta_beta=self.theta_beta)
        bilateral_out = bilateral_out / bilateral_norm_vals

        # Weighting filter outputs
        message_passing = (tf.matmul(self.spatial_ker_weights,
                                     tf.reshape(spatial_out, (c, -1))) +
                           tf.matmul(self.bilateral_ker_weights,
                                     tf.reshape(bilateral_out, (c, -1))))

        # Compatibility transform
        pairwise = tf.matmul(self.compatibility_matrix, message_passing)

        # Adding unary potentials
        pairwise = tf.reshape(pairwise, (c, h, w))
        q_values = unaries - pairwise

    # Output is channels first.
    # This is commented out from the original:
    # return tf.transpose(tf.reshape(q_values, (1, c, h, w)), perm=(0, 2, 3, 1))
    # return tf.reshape(q_values, (1, c, h, w))
    return q_values
```
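For readers who want to follow the shapes through one iteration of the loop without building the custom op, here is a rough NumPy sketch of a single update step. It is not the layer's implementation: `identity_filter` is a hypothetical stand-in for `custom_module.high_dim_filter` that returns its input unchanged (so the normalization values would all be ones and the division is omitted), and the weight matrices are assumed to be `(c, c)` so that the matrix products line up as in the code above.

```python
import numpy as np


def softmax(x, axis=0):
    # Numerically stable softmax along the class axis.
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)


def identity_filter(q):
    # Hypothetical stand-in for custom_module.high_dim_filter.
    # It performs no filtering at all; it only preserves the shape.
    return q


def mean_field_step(unaries, q_values, spatial_w, bilateral_w, compat):
    # unaries and q_values are channels first: (c, h, w).
    c, h, w = unaries.shape
    softmax_out = softmax(q_values, axis=0)

    # Filtering (normalization omitted because the stand-in is the identity).
    spatial_out = identity_filter(softmax_out)
    bilateral_out = identity_filter(softmax_out)

    # Weighting filter outputs: (c, c) @ (c, h*w) -> (c, h*w).
    message = (spatial_w @ spatial_out.reshape(c, -1)
               + bilateral_w @ bilateral_out.reshape(c, -1))

    # Compatibility transform, then back to (c, h, w).
    pairwise = (compat @ message).reshape(c, h, w)

    # Adding unary potentials.
    return unaries - pairwise


# Made-up sizes and random data, for illustration only.
c, h, w = 3, 4, 5
rng = np.random.default_rng(0)
unaries = rng.standard_normal((c, h, w)).astype(np.float32)
eye = np.eye(c, dtype=np.float32)

q = mean_field_step(unaries, unaries, eye, eye, eye)
print(q.shape)  # (3, 4, 5)
```

The point of the sketch is only that the result keeps the channels-first `(c, h, w)` shape at every step, which is why no final transpose is needed when the inputs are already channels first.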
sadeepj commented 5 years ago

Thanks @JadBatmobile for posting the answer.