kwotsin / TensorFlow-ENet

TensorFlow implementation of ENet
MIT License

Using the network on 4-Channel Images #12

anandcu3 closed this issue 7 years ago

anandcu3 commented 7 years ago

I'm trying to train the network on 4-channel images. The changes I made to the code are:

tf.image.decode_image(image, channels=4) in train_enet.py

and changing the number of channels in preprocessing.py. Now, when I try to train the network, I get the following error:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot reshape a tensor with 734400 elements to shape [4,172800] (691200 elements) for 'ENet/unpool_1/Reshape_1' (op: 'Reshape') with input shapes: [4,1,90,120,17], [2] and with input tensors computed as partial shapes: input[1] = [4,172800].

Traceback (most recent call last):
  File "train_enet.py", line 337, in <module>
    run()
  File "train_enet.py", line 162, in run
    skip_connections=skip_connections)
  File "enet.py", line 476, in ENet
    pooling_indices=pooling_indices_1, output_shape=inputs_shape_1, scope=bottleneck_scope_name+'_0')
  File "Anaconda3\lib\site-packages\tensorflow\contrib\framework\python\ops\arg_scope.py", line 181, in func_with_args
    return func(*args, **current_args)
  File "enet.py", line 321, in bottleneck
    net_unpool = unpool(net_unpool, pooling_indices, output_shape=output_shape, scope='unpool')
  File "enet.py", line 108, in unpool
    indices = tf.transpose(tf.reshape(tf.stack([b, y, x, f]), [4, updates_size]))
  File "Anaconda3\lib\site-packages\tensorflow\python\ops\gen_array_ops.py", line 2451, in reshape
    name=name)
  File "Anaconda3\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 767, in apply_op
    op_def=op_def)
  File "Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 2508, in create_op
    set_shapes_for_outputs(ret)
  File "Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 1873, in set_shapes_for_outputs
    shapes = shape_func(op)
  File "Anaconda3\lib\site-packages\tensorflow\python\framework\ops.py", line 1823, in call_with_requiring
    return call_cpp_shape_fn(op, require_shape_fn=True)
  File "Anaconda3\lib\site-packages\tensorflow\python\framework\common_shapes.py", line 610, in call_cpp_shape_fn
    debug_python_shape_fn, require_shape_fn)
  File "Anaconda3\lib\site-packages\tensorflow\python\framework\common_shapes.py", line 676, in _call_cpp_shape_fn_impl
    raise ValueError(err.message)
ValueError: Cannot reshape a tensor with 734400 elements to shape [4,172800] (691200 elements) for 'ENet/unpool_1/Reshape_1' (op: 'Reshape') with input shapes: [4,1,90,120,17], [2] and with input tensors computed as partial shapes: input[1] = [4,172800].
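The element counts in the message can be checked by hand. The sketch below is plain arithmetic on the numbers in the error. It assumes, as the tf.stack([b, y, x, f]) at enet.py line 108 suggests, that the leading 4 in [4,1,90,120,17] is the four stacked index components, which makes the batch size 1 and the feature map 90x120:

```python
from math import prod

# Shapes copied from the error message
stacked_indices_shape = (4, 1, 90, 120, 17)  # tf.stack([b, y, x, f]) -> 4 components
updates_size = 172800                        # second dim of the target shape [4, 172800]

indices_elements = prod(stacked_indices_shape)  # what the reshape receives
target_elements = 4 * updates_size              # what the reshape asks for
print(indices_elements, target_elements)

# Channels implied by each side, for batch 1 and a 90x120 feature map
index_channels = stacked_indices_shape[-1]
update_channels = updates_size // (1 * 90 * 120)
print(index_channels, update_channels)
```

This gives 734400 vs 691200 elements and 17 vs 16 channels: the pooling indices carry 17 channels, while the tensor being unpooled has 16. With a 4-channel input, 17 is the 4 max-pooled input channels plus 13 from the conv branch, one more than the 16 the decoder expects.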

Do I need to make any other changes? Will this network even work with 4-channel images?

kwotsin commented 7 years ago

This looks like a shape mismatch error, so try using get_shape() to check which shapes don't match. I haven't tried 4-channel images yet, so YMMV.

anandcu3 commented 7 years ago

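I fixed this issue. The problem was in the initial block: its conv branch had a fixed number of filters, so with a 4-channel input the concatenated output had 17 channels instead of 16 (the 17 in the input shape [4,1,90,120,17] above). I changed the initial block to accept the number of channels as an input. The changes are the new num_channels argument and the 16-num_channels filter count in the conv branch: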

def initial_block(inputs, num_channels=3, is_training=True, scope='initial_block'):
    '''
    The initial block for Enet has 2 branches: The convolution branch and Maxpool branch.

    The conv branch outputs (16 - num_channels) feature maps, while the maxpool branch outputs num_channels feature maps, one per image channel.
    Both outputs are then concatenated to give an output of 16 feature maps.

    NOTE: Does not need to store pooling indices since they won't be used later for the final upsampling.

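    INPUTS:
    - inputs(Tensor): A 4D tensor of shape [batch_size, height, width, channels]
    - num_channels(int): The number of channels in the input images (3 for RGB, 4 for 4-channel images)

    OUTPUTS:
    - net_concatenated(Tensor): a 4D Tensor that contains the concatenated outputs of the conv and maxpool branches (16 channels)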
    '''
    #Convolutional branch
    net_conv = slim.conv2d(inputs, 16-num_channels, [3,3], stride=2, activation_fn=None, scope=scope+'_conv')
    net_conv = slim.batch_norm(net_conv, is_training=is_training, fused=True, scope=scope+'_batchnorm')
    net_conv = prelu(net_conv, scope=scope+'_prelu')
    #Max pool branch
    net_pool = slim.max_pool2d(inputs, [2,2], stride=2, scope=scope+'_max_pool')
    # Concatenated output. Does it matter whether max pool or conv comes first? Probably not.
    net_concatenated = tf.concat([net_conv, net_pool], axis=3, name=scope+'_concat')
    return net_concatenated
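To see why the fix works for any channel count, here is a minimal shape-only sketch in plain Python. The helper initial_block_output_shape is hypothetical (not part of enet.py). It assumes an even 360x480 input, which is consistent with the 90x120 maps in the error after two stride-2 stages, and it uses 13 as the original hard-coded filter count, inferred from the error (17 - 4) rather than read from the repository:

```python
def initial_block_output_shape(input_shape, conv_filters):
    """Output shape of ENet's initial block (illustrative helper, not from enet.py).

    input_shape: (batch, height, width, channels) in NHWC order.
    conv_filters: number of filters in the 3x3, stride-2 conv branch.
    Assumes even height and width, so the stride-2 conv and the
    2x2/stride-2 max pool both produce (height // 2, width // 2).
    """
    batch, height, width, channels = input_shape
    # Conv branch adds conv_filters maps; max-pool branch keeps the input channels.
    return (batch, height // 2, width // 2, conv_filters + channels)


FIXED_FILTERS = 13  # inferred from the error message: 17 - 4

for c in (3, 4):
    shape = (1, 360, 480, c)  # assumed input size (see lead-in)
    before = initial_block_output_shape(shape, FIXED_FILTERS)
    after = initial_block_output_shape(shape, 16 - c)  # filter count after the fix
    print(c, before, after)
# 3 (1, 180, 240, 16) (1, 180, 240, 16)
# 4 (1, 180, 240, 17) (1, 180, 240, 16)
```

Because the default is num_channels=3, the call to initial_block inside ENet() would also need num_channels=4 for 4-channel input; otherwise the conv branch still gets 13 filters.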