JayBad opened this issue 7 years ago
Maybe you'll find https://arxiv.org/abs/1701.06659 interesting.
Thanks for the reference; it does indeed seem to address SSD's problem with small objects. I will take a look.
The link is broken. Can you provide the link, please?
My bad, here it is again. Any success with detecting small objects?
I have not tested it yet, and I really don't know when I will have the time to, but I will keep you posted if I try it.
@JayBad thanks! I'd appreciate that. I'll do the same.
Did you find a DSSD implementation on GitHub? Recently, I've wanted to use the SSD model to detect small and blurry objects.
I haven't found any implementations of DSSD just yet. But, from reading the paper, it seems quite straightforward. I added the extra layers myself and I'm training them right now. I'll let you know if I've had any success.
@natlachaman thank you for your message, I'm looking forward to hearing from you soon.
Hi @natlachaman,
Can you share your DSSD implementation on GitHub?
Thanks.
@humayun Super sorry for the delay; for some reason I overlooked your message. About my implementation: I can't share the entire GitHub repository with you, since I'm working under a non-disclosure agreement for a local business in my city. But I can share how I define the model and the DSSD module here:
# imports assumed (not in the original post); Keras 2.x with the TensorFlow backend
import tensorflow as tf
from keras.models import Model
from keras.layers import (Input, Conv2D, Conv2DTranspose, BatchNormalization,
                          Activation, MaxPooling2D, ZeroPadding2D, Reshape,
                          Add, Multiply, Concatenate)

def ssd_net(inputs,
            num_classes,
            feat_layers,
            anchor_sizes,
            anchor_ratios,
            normalizations,
            weights_path='./'):

    def conv2d_block(net, key, block, nlayers, nkernels, convstr, spool, poolstr):
        # BN -> ReLU -> Conv2D, repeated nlayers times, followed by max pooling
        for l in range(1, nlayers + 1):
            # Batch norm
            bn_name = 'bn{}_{}'.format(block, l)
            if l < 2:
                net[bn_name] = BatchNormalization()(net[key])
            else:
                net[bn_name] = BatchNormalization()(net[conv_name])
            # ReLU activation
            relu_name = 'relu{}_{}'.format(block, l)
            net[relu_name] = Activation('relu')(net[bn_name])
            # Conv2D layer
            conv_name = 'conv{}_{}'.format(block, l)
            net[conv_name] = Conv2D(nkernels[l - 1],
                                    (3, 3),
                                    padding='same',
                                    strides=convstr[l - 1],
                                    name=conv_name)(net[relu_name])
        # MaxPool2D layer at the end of the block
        max_name = 'max{}'.format(block)
        net[max_name] = MaxPooling2D(spool, strides=poolstr, padding='same', name=max_name)(net[conv_name])
        return max_name

    def conv2d_block_padding(net, key, block, skernels, nkernels, convstr):
        # 1x1 conv -> BN -> zero padding -> strided conv -> BN
        net['conv{}_1'.format(block)] = Conv2D(nkernels[0],
                                               skernels[0],
                                               padding='same',
                                               name='conv{}_1'.format(block))(net[key])
        net['bn{}_1'.format(block)] = BatchNormalization()(net['conv{}_1'.format(block)])
        net['pad{}'.format(block)] = ZeroPadding2D((1, 1))(net['bn{}_1'.format(block)])
        net['conv{}_2'.format(block)] = Conv2D(nkernels[1],
                                               skernels[1],
                                               padding='valid',
                                               strides=convstr,
                                               name='conv{}_2'.format(block))(net['pad{}'.format(block)])
        net['bn{}_2'.format(block)] = BatchNormalization()(net['conv{}_2'.format(block)])
        return 'bn{}_2'.format(block)

    net = dict()
    net['input'] = Input(shape=inputs)
    # VGG-style base network (blocks 1-5)
    last = conv2d_block(net, 'input', 1, 2, nkernels=[64]*2, convstr=[(1, 1)]*2, spool=(2, 2), poolstr=(2, 2))
    last = conv2d_block(net, last, 2, 2, nkernels=[128]*2, convstr=[(1, 1)]*2, spool=(2, 2), poolstr=(2, 2))
    last = conv2d_block(net, last, 3, 3, nkernels=[256]*3, convstr=[(1, 1)]*3, spool=(2, 2), poolstr=(2, 2))
    last = conv2d_block(net, last, 4, 3, nkernels=[512]*3, convstr=[(1, 1)]*3, spool=(2, 2), poolstr=(2, 2))
    last = conv2d_block(net, last, 5, 3, nkernels=[512]*3, convstr=[(1, 1)]*3, spool=(2, 2), poolstr=(2, 2))
    # `last` holds a layer name, so index into `net` (fixes a bug in the snippet as originally posted)
    net['conv6'] = Conv2D(1024, (1, 1), padding='same', name='conv6')(net[last])
    net['conv7'] = Conv2D(1024, (1, 1), padding='same', name='conv7')(net['conv6'])
    # extra SSD feature blocks (8-12)
    last = conv2d_block_padding(net, 'conv7', 8, skernels=[(1, 1), (3, 3)], nkernels=(256, 512), convstr=(2, 2))
    last = conv2d_block_padding(net, last, 9, skernels=[(1, 1), (3, 3)], nkernels=(128, 256), convstr=(2, 2))
    last = conv2d_block_padding(net, last, 10, skernels=[(1, 1), (3, 3)], nkernels=(128, 256), convstr=(2, 2))
    last = conv2d_block_padding(net, last, 11, skernels=[(1, 1), (3, 3)], nkernels=(128, 256), convstr=(2, 2))
    _ = conv2d_block_padding(net, last, 12, skernels=[(1, 1), (4, 4)], nkernels=(128, 256), convstr=(1, 1))

    # DSSD layers: walk the feature layers from deepest to shallowest,
    # inserting a deconvolutional module between each consecutive pair
    dssd_layers = []
    for i, layer in enumerate(reversed(feat_layers)):
        if i < 1:
            dssd_layers.append(layer)
        else:
            deconvolutional_module(net, dssd_layers[-1], layer)
            dssd_layers.append('dssd_{}'.format(layer))

    # Prediction and localisation layers.
    predictions = []
    logits = []
    localisations = []
    for i, layer in enumerate(dssd_layers):
        with tf.variable_scope(layer + '_box'):
            p, l = ssd_multibox_layer(net,
                                      layer,
                                      num_classes,
                                      anchor_sizes[i],
                                      anchor_ratios[i],
                                      normalizations[i])
        predictions.append(Activation('softmax')(p))
        logits.append(p)
        localisations.append(l)

    # concatenate all box proposals and their class predictions
    net['logits'] = Concatenate(axis=1, name='conf')(logits)
    net['localisations'] = Concatenate(axis=1, name='loc')(localisations)
    net['predictions'] = Concatenate(axis=1, name='pred')(predictions)
    # kept for convenience; the Model below only returns logits and localisations
    net['output'] = Concatenate(axis=2, name='output')([net['predictions'],
                                                        net['localisations'],
                                                        net['logits']])

    model = Model(inputs=net['input'],
                  outputs=[net['logits'], net['localisations']], name='ssd_multibox')
    model.summary()
    return model
And the SSDMultibox and DSSD modules are defined as:
def deconvolutional_module(net, deconv, ssd_layer):
    # deconvolution path (`deconv` is a layer name, so index into `net`)
    x = Conv2DTranspose(512, [2, 2])(net[deconv])
    x = Conv2D(512, [3, 3], padding='same')(x)
    x = BatchNormalization()(x)
    # ssd layer path
    y = Conv2D(512, [3, 3], padding='same')(net[ssd_layer])
    y = BatchNormalization()(y)
    y = Activation('relu')(y)
    y = Conv2D(512, [3, 3], padding='same')(y)
    y = BatchNormalization()(y)
    # merge and output
    z = Multiply()([x, y])
    net['dssd_{}'.format(ssd_layer)] = Activation('relu')(z)
def ssd_multibox_layer(net,
                       layer,
                       num_classes,
                       sizes,
                       ratios=[1],
                       normalization=-1):
    """Construct a multibox layer, return class and localization predictions.
    """
    base_layer = net[layer]
    # added sequence from DSSD (prediction module with a residual connection)
    x = Conv2D(256, [1, 1], padding='same')(base_layer)
    x = Conv2D(256, [1, 1], padding='same')(x)
    x = Conv2D(1024, [1, 1], padding='same')(x)
    y = Conv2D(1024, [1, 1], padding='same')(base_layer)
    base_layer = Add()([x, y])
    # Spatial L2 norm (for exploiting gradients);
    # Normalize is the custom L2-normalization layer used in the Keras SSD ports
    if normalization > 0:
        base_layer = Normalize(20, name='norm_{}'.format(layer))(base_layer)
    num_anchors = len(sizes) + len(ratios)
    # Location.
    loc_pred = Conv2D(num_anchors * 4, [3, 3], padding='same', name='loc_{}'.format(layer))(base_layer)
    # tfe is the tf_extended module from @balancap's SSD-Tensorflow repo
    shape = tfe.tensors.get_shape(loc_pred, 4)
    loc_pred = Reshape((shape[1]*shape[2]*num_anchors, 4))(loc_pred)
    # Class prediction.
    cls_pred = Conv2D(num_anchors * num_classes, [3, 3], padding='same', name='cls_{}'.format(layer))(base_layer)
    shape = tfe.tensors.get_shape(cls_pred, 4)
    cls_pred = Reshape((shape[1]*shape[2]*num_anchors, num_classes))(cls_pred)
    return cls_pred, loc_pred
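In case it helps, a hypothetical call could look like this (the layer names, anchor sizes and ratios below are placeholders loosely modelled on the SSD-300 defaults in @balancap's repo, not the values from my actual project):

feat_layers = ['conv4_3', 'conv7', 'bn8_2', 'bn9_2', 'bn10_2', 'bn11_2']
anchor_sizes = [(21., 45.), (45., 99.), (99., 153.),
                (153., 207.), (207., 261.), (261., 315.)]
anchor_ratios = [[1, 2, .5]] * 2 + [[1, 2, .5, 3, 1. / 3]] * 4
normalizations = [20, -1, -1, -1, -1, -1]

model = ssd_net(inputs=(300, 300, 3),
                num_classes=21,  # 20 Pascal VOC classes + background
                feat_layers=feat_layers,
                anchor_sizes=anchor_sizes,
                anchor_ratios=anchor_ratios,
                normalizations=normalizations)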
Please note that a lot of the code is inspired by @balancap's repository. I merely rewrote the training/testing pipelines and the model in Keras.
I hope it is of some help.
Good luck!
@natlachaman Thank you a lot for providing your code. I tried to implement your DSSD layers in Tensorflow and I'm having some issues with tensor shapes.
The code I'm using is:
# imports assumed (not in the original post); TF 1.x
import tensorflow as tf
import tensorflow.contrib.slim as slim

def dssd(net, end_points, featLayer):
    """
    Adding deconvolutional layers to SSD, based on
    https://arxiv.org/pdf/1701.06659.pdf
    and
    https://github.com/balancap/SSD-Tensorflow/issues/116
    """
    dssd_layers = []
    print(featLayer)
    for i, layer in enumerate(reversed(featLayer)):
        print(layer)
        if i < 1:
            dssd_layers.append(end_points[layer])
        else:
            end_point = 'dssd_block' + str(i)
            # Deconvolutional Module
            with tf.variable_scope(end_point):
                # deconvolution path
                upconv = slim.convolution2d_transpose(dssd_layers[-1], 512, [2, 2], scope='conv2x2')
                upconv = slim.conv2d(upconv, 512, [3, 3], scope='conv3x3', padding='SAME')
                upconv = slim.batch_norm(upconv, fused=True)
                # ssd layer path
                fromFeatLayer = slim.conv2d(end_points[layer], 512, [3, 3], scope='conv3x3_1', padding='SAME')
                fromFeatLayer = slim.batch_norm(fromFeatLayer, fused=True)
                fromFeatLayer = tf.nn.relu(fromFeatLayer)
                fromFeatLayer = slim.conv2d(fromFeatLayer, 512, [3, 3], scope='conv3x3_2', padding='SAME')
                fromFeatLayer = slim.batch_norm(fromFeatLayer, fused=True)
                print(upconv.get_shape())
                print(fromFeatLayer.get_shape())
                # merge and output
                merge = tf.multiply(upconv, fromFeatLayer)
                print(merge.get_shape())
                # net['dssd_{}'.format(layer)] = tf.nn.relu(merge)
                net = tf.nn.relu(merge)
                end_points[end_point] = net
            # dssd_layers.append('dssd_{}'.format(layer))
            dssd_layers.append(end_points[end_point])
    return net, end_points, dssd_layers
where:
net = output of the complete network (ResNet-50 + SSD)
end_points = list of network endpoints for every layer
featLayer = list of layer names which should be used for getting the tensors to calculate the bounding boxes: ['resnet_v2_50/block2/unit_3/bottleneck_v2/conv3', 'ssd_block7', 'ssd_block8', 'ssd_block9', 'ssd_block10']
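So a call looks roughly like this (a sketch of my pipeline; net and end_points come out of the ResNet-50 + SSD graph described above):

featLayer = ['resnet_v2_50/block2/unit_3/bottleneck_v2/conv3',
             'ssd_block7', 'ssd_block8', 'ssd_block9', 'ssd_block10']
net, end_points, dssd_layers = dssd(net, end_points, featLayer)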
I only use ssd blocks up to 10. When I run it, I get the following error:
Traceback (most recent call last):
  File "train_model.py", line 29, in <module>
    ssd_trainer.start_training()
  File "D:\Devel\CNN_Depth_Estimation\Multi2\modular_SSD_tensorflow\trainer\trainer.py", line 156, in start_training
    predictions, localisations, logits, depths, end_points = self.g_ssd.get_model(image, self.keep_probs, self.loadDepth)
  File "D:\Devel\CNN_Depth_Estimation\Multi2\modular_SSD_tensorflow\ssd\ssdmodel.py", line 90, in get_model
    net, end_points, bbox_layers = self._dssd_blocks(net, end_points, self.params.feature_layers)
  File "D:\Devel\CNN_Depth_Estimation\Multi2\modular_SSD_tensorflow\ssd\ssd_blocks.py", line 311, in dssd
    merge = tf.multiply(upconv, fromFeatLayer)
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\math_ops.py", line 321, in multiply
    return gen_math_ops._mul(x, y, name)
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\gen_math_ops.py", line 3100, in _mul
    "Mul", x=x, y=y, name=name)
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 3162, in create_op
    compute_device=compute_device)
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 3208, in _create_op_helper
    set_shapes_for_outputs(op)
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 2427, in set_shapes_for_outputs
    return _set_shapes_for_outputs(op)
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 2400, in _set_shapes_for_outputs
    shapes = shape_func(op)
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 2330, in call_with_requiring
    return call_cpp_shape_fn(op, require_shape_fn=True)
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\common_shapes.py", line 627, in call_cpp_shape_fn
    require_shape_fn)
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\common_shapes.py", line 691, in _call_cpp_shape_fn_impl
    raise ValueError(err.message)
ValueError: Dimensions must be equal, but are 3 and 6 for 'dssd/dssd_block2/Mul' (op: 'Mul') with input shapes: [32,3,3,512], [32,6,6,512].
When printing out some layer names and shapes, I get:
ssd_block10
ssd_block9
(32, 1, 1, 512)
(32, 3, 3, 512)
(32, 3, 3, 512)
ssd_block8
(32, 3, 3, 512)
(32, 6, 6, 512)
This means that doing tf.multiply(ssd_block10, ssd_block9) works fine, since ssd_block10's size-1 dimensions can be broadcast, but doing tf.multiply(ssd_block9, ssd_block8) fails.
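A minimal repro of that broadcasting behaviour (shapes copied from the printout above; just a sketch, not part of my training code):

import tensorflow as tf

a = tf.zeros([32, 1, 1, 512])  # deconv path output for ssd_block10
b = tf.zeros([32, 3, 3, 512])  # ssd_block9
c = tf.zeros([32, 6, 6, 512])  # ssd_block8
ok = tf.multiply(a, b)   # fine: the size-1 dims broadcast to (32, 3, 3, 512)
bad = tf.multiply(b, c)  # ValueError: Dimensions must be equal, but are 3 and 6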
So, didn't you get the same error? Or does Keras handle this differently? Could you please help me?
Thanks.
Hi @Cuky88 I haven't used this code in a couple of months now. I tried running it again today with different detection layers (aka ssd layers) and indeed ran into the same problem as you did, on the same line of code:
Traceback (most recent call last):
  File "train.py", line 516, in <module>
    train_function(args)
  File "train.py", line 375, in train_function
    ssd_net, model = model_settings(args)
  File "train.py", line 348, in model_settings
    model = ssd_net.net()
  File "/home/tu/knaxq01/Deployment/ssd/network.py", line 133, in net
    args=self.args)
  File "/home/tu/knaxq01/Deployment/ssd/network.py", line 615, in ssd_net
    deconvolutional_module(net, dssd_layers[-1], layer)
  File "/home/tu/knaxq01/Deployment/ssd/network.py", line 448, in deconvolutional_module
    z = Multiply()([x, y])
  File "/opt/bwhpc/common/cs/keras/2.1.0-tensorflow-1.4-python-3.5/lib/python3.5/site-packages/keras/engine/topology.py", line 578, in __call__
    self.build(input_shapes)
  File "/opt/bwhpc/common/cs/keras/2.1.0-tensorflow-1.4-python-3.5/lib/python3.5/site-packages/keras/layers/merge.py", line 84, in build
    output_shape = self._compute_elemwise_op_output_shape(output_shape, shape)
  File "/opt/bwhpc/common/cs/keras/2.1.0-tensorflow-1.4-python-3.5/lib/python3.5/site-packages/keras/layers/merge.py", line 55, in _compute_elemwise_op_output_shape
    str(shape1) + ' ' + str(shape2))
ValueError: Operands could not be broadcast together with shapes (5, 5, 512) (8, 8, 512)
I didn't have much time today, but I'll look into it tomorrow and get back to you. But at least you know it's not a Tensorflow/Keras issue; there is something wrong in the implementation I posted. Sorry about that. Like I said, I'll try to get back to you tomorrow, or over the weekend at the latest.
@Cuky88 Ok, found it! Going through the paper again, it turns out that the ssd layer that is passed to the deconvolutional module needs to be twice the size of the deconv layer passed to the module. Then the deconvolutional path's output remains the same size as the deconv layer, and the ssd path's output is reduced by half. So when both output tensors are multiplied, they have the same dimensions.
Having said that, the error was in the "padding". I added an assertion statement that will warn you in case the condition explained above is not met:
def deconvolutional_module(net, deconv_layer, ssd_layer):
    assert net[ssd_layer]._keras_shape[1] == 2 * net[deconv_layer]._keras_shape[1], \
        'deconv layer (HxWxD) needs to be half the size of the ssd layer (2Hx2WxD)'
    # deconvolution path (note: the Keras keyword is `strides`, not `stride`)
    x = Conv2DTranspose(512, [2, 2], strides=2, padding='valid')(net[deconv_layer])
    x = Conv2D(512, [3, 3], padding='same')(x)
    x = BatchNormalization()(x)
    # ssd layer path
    y = Conv2D(512, [3, 3], padding='same')(net[ssd_layer])
    y = BatchNormalization()(y)
    y = Activation('relu')(y)
    y = Conv2D(512, [3, 3], padding='same')(y)
    y = BatchNormalization()(y)
    # merge and output
    z = Multiply()([x, y])
    print(z)
    net['dssd_{}'.format(ssd_layer)] = Activation('relu')(z)
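(For what it's worth, the DSSD paper compares element-wise sum and element-wise product for this merge step and reports slightly better accuracy with the product, which is why Multiply() is used here rather than Add().)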
Hope this helps. The translation to Tensorflow should be straightforward. In case you're in doubt about how Keras handles this in the backend, just take a look here in the official repo.
Good luck! If you run into problems again, just hit me up!
@natlachaman Great job! So, as far as I understand, you just changed the padding to valid in the ssd layer path, right? If I do this in Tensorflow, I get the following error on the first deconv module with ssd_block10 & ssd_block9:
Traceback (most recent call last):
  File "train_model.py", line 29, in <module>
    ssd_trainer.start_training()
  File "D:\Devel\CNN_Depth_Estimation\Multi2\modular_SSD_tensorflow\trainer\trainer.py", line 156, in start_training
    predictions, localisations, logits, depths, end_points = self.g_ssd.get_model(image, self.keep_probs, self.loadDepth)
  File "D:\Devel\CNN_Depth_Estimation\Multi2\modular_SSD_tensorflow\ssd\ssdmodel.py", line 90, in get_model
    net, end_points, bbox_layers = self._dssd_blocks(net, end_points, self.params.feature_layers)
  File "D:\Devel\CNN_Depth_Estimation\Multi2\modular_SSD_tensorflow\ssd\ssd_blocks.py", line 304, in dssd
    fromFeatLayer = slim.conv2d(fromFeatLayer, 512, [3, 3], scope='conv3x3_2', padding='VALID')
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\framework\python\ops\arg_scope.py", line 182, in func_with_args
    return func(*args, **current_args)
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\layers\python\layers\layers.py", line 1057, in convolution
    outputs = layer.apply(inputs)
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\layers\base.py", line 762, in apply
    return self.__call__(inputs, *args, **kwargs)
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\layers\base.py", line 652, in __call__
    outputs = self.call(inputs, *args, **kwargs)
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\layers\convolutional.py", line 167, in call
    outputs = self._convolution_op(inputs, self.kernel)
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\nn_ops.py", line 838, in __call__
    return self.conv_op(inp, filter)
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\nn_ops.py", line 502, in __call__
    return self.call(inp, filter)
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\nn_ops.py", line 190, in __call__
    name=self.name)
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\ops\gen_nn_ops.py", line 725, in conv2d
    data_format=data_format, dilations=dilations, name=name)
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 3162, in create_op
    compute_device=compute_device)
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 3208, in _create_op_helper
    set_shapes_for_outputs(op)
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 2427, in set_shapes_for_outputs
    return _set_shapes_for_outputs(op)
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 2400, in _set_shapes_for_outputs
    shapes = shape_func(op)
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\ops.py", line 2330, in call_with_requiring
    return call_cpp_shape_fn(op, require_shape_fn=True)
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\common_shapes.py", line 627, in call_cpp_shape_fn
    require_shape_fn)
  File "C:\ProgramData\Miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\common_shapes.py", line 691, in _call_cpp_shape_fn_impl
    raise ValueError(err.message)
ValueError: Negative dimension size caused by subtracting 3 from 1 for 'dssd/dssd_block1/conv3x3_2/Conv2D' (op: 'Conv2D') with input shapes: [32,1,1,512], [3,3,512,512].
I'm not sure how to solve this kind of problem; I'm not deeply familiar with Tensorflow. I tried changing the padding to valid in the "ssd layer path" and also added 2x2 padding to the deconv path, but nothing works. Can you please tell me if you changed something in the ssd_blocks?
So I'm not sure if I understood. Does this mean that I need to adjust the sizes of the ssd blocks? Those are defined in:
# slim is tf.contrib.slim; custom_layers comes from @balancap's SSD-Tensorflow repo
def ssd321(net, end_points):
    # block 6: 3x3 conv
    net = slim.conv2d(net, 1024, [3, 3], rate=6, scope='conv6')
    net = slim.batch_norm(net, fused=True)
    net = custom_layers.dropout_with_noise(net)
    end_points['ssd_block6'] = net
    # block 7: 1x1 conv
    net = slim.conv2d(net, 1024, [1, 1], scope='conv7')
    net = slim.batch_norm(net, fused=True)
    net = custom_layers.dropout_with_noise(net)
    end_points['ssd_block7'] = net
    # blocks 8/9/10: 1x1 and 3x3 convolutions with stride 2 (except the last)
    end_point = 'ssd_block8'
    with tf.variable_scope(end_point):
        net = slim.conv2d(net, 256, [1, 1], scope='conv1x1')
        net = slim.batch_norm(net, fused=True)
        net = custom_layers.dropout_with_noise(net)
        net = custom_layers.pad2d(net, pad=(1, 1))
        net = slim.conv2d(net, 512, [3, 3], stride=2, scope='conv3x3', padding='VALID')
        net = slim.batch_norm(net, fused=True)
        net = custom_layers.dropout_with_noise(net)
        end_points[end_point] = net
    end_point = 'ssd_block9'
    with tf.variable_scope(end_point):
        net = slim.conv2d(net, 128, [1, 1], scope='conv1x1')
        net = slim.batch_norm(net, fused=True)
        net = custom_layers.dropout_with_noise(net)
        net = custom_layers.pad2d(net, pad=(1, 1))
        net = slim.conv2d(net, 256, [3, 3], stride=2, scope='conv3x3', padding='VALID')
        net = slim.batch_norm(net, fused=True)
        net = custom_layers.dropout_with_noise(net)
        end_points[end_point] = net
    end_point = 'ssd_block10'
    with tf.variable_scope(end_point):
        net = slim.conv2d(net, 128, [1, 1], scope='conv1x1')
        net = slim.batch_norm(net, fused=True)
        net = custom_layers.dropout_with_noise(net)
        net = slim.conv2d(net, 256, [3, 3], scope='conv3x3', padding='VALID')
        net = slim.batch_norm(net, fused=True)
        net = custom_layers.dropout_with_noise(net)
        end_points[end_point] = net
    return net, end_points
The shapes of the layers used from Resnet and SSD are:
Resnet_v2_50/block2/unit_3/bottleneck_v2/conv3 (16, 41, 41, 512)
ssd_block6 (16, 11, 11, 1024)
ssd_block7 (16, 11, 11, 1024)
ssd_block8 (16, 6, 6, 512)
ssd_block9 (16, 3, 3, 256)
ssd_block10 (16, 1, 1, 256)
Thank you a lot for helping out. I appreciate it.
@Cuky88 You're right, it is quite hard to make the output shapes match. First of all, I again want to apologize for a mistake in my previous comment. I got it exactly reversed: the deconvolutional module outputs feature maps the size of the input to the ssd path. Therefore, the deconvolutional path needs to upsample to twice the size of its input in order for the multiplication at the end to work out and for the module to output the right dimensions. Take a look at this illustration from the original paper so you'll know what I mean.
Now, that only works if the input to the ssd path is exactly double the size of the input to the deconvolutional path and if the stride of Conv2DTranspose() is 2 (I updated my previous reply to meet these criteria). There are probably other ways to make these two paths output feature maps of the same size (by means of zero padding, maybe), but this seemed to me like the cleanest way. There is also no guidance on this in the paper and no implementation that I can refer to. So I would suggest making sure your input arguments meet the "twice the size" criterion, for instance,
block_6 (32x32xD)
block_8_2 (16x16xD)
block_9_2 (8x8xD)
block_10_2 (4x4xD)
block_11_2 (2x2xD)
block_12_2 (1x1xD)
you can do this by padding them just before the deconvolutional module is applied to them, as in the sketch below.
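For example, with the shapes you posted, ssd_block7 is (11, 11) but would need to be (12, 12) to be exactly twice ssd_block8's (6, 6). In Tensorflow you could pad it asymmetrically right before the module (a sketch against your shapes, not tested in your pipeline):

# hypothetical pre-padding: (11, 11) -> (12, 12); paddings are [batch, height, width, channels]
end_points['ssd_block7'] = tf.pad(end_points['ssd_block7'],
                                  paddings=[[0, 0], [0, 1], [0, 1], [0, 0]])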
Also, your ssd_block9 has a (3x3) feature map that, when passed through the first convolution of the ssd path, is reduced to a (1x1) feature map. When the second convolution is applied, it raises a ValueError, since a 3x3 convolution can't be performed on a 1-pixel feature map; that's why you get a negative dimension. But don't worry about it too much: now that we know that the ssd path performs convolutions with padding="same", that shouldn't happen anymore. Shape mismatch problems will arise from the deconvolutional path :(
Sorry that I can't be of more help. I'm not sure myself how to design the network exactly, given only some descriptions and illustrations from the paper. However, I can recommend this great guide on the arithmetic of convolutions in deep learning. It helped me figure out output shapes.
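As a quick reference, here are the two output-size rules from that guide that matter in this thread (a summary sketch for 'valid' padding; not something from the DSSD paper):

# convolution: out = floor((in - kernel) / stride) + 1
def conv_out(i, k, s=1):
    return (i - k) // s + 1

# transposed convolution: out = (in - 1) * stride + kernel
def deconv_out(i, k, s=1):
    return (i - 1) * s + k

print(conv_out(3, 3))         # 1  -> ssd_block9's 3x3 map collapses to 1x1
print(conv_out(1, 3))         # -1 -> the "negative dimension" ValueError above
print(deconv_out(5, 2, s=2))  # 10 -> a 2x2 deconv with stride 2 exactly doubles its input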
Let me know if you manage to make it work. Good luck!
@natlachaman No need to apologize for anything, your help was awesome, more than I asked for :) Everything is working fine now; I had to adjust all the layer sizes. But now SSD seems to learn better, and the berHu loss converges well.
Would you mind having a quick look at my berHu implementation here? I'm not sure if it's correct.
Thank you for taking the time.
Cheers
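For context, the berHu (reverse Huber) loss discussed here is usually defined as in Laina et al., "Deeper Depth Prediction with Fully Convolutional Residual Networks": L1 for small residuals and a scaled L2 beyond a threshold c. A rough sketch of that definition (not necessarily @Cuky88's exact implementation, which was linked above):

import tensorflow as tf

def berhu_loss(y_true, y_pred):
    # reverse Huber: |r| if |r| <= c, else (r^2 + c^2) / (2c);
    # c is conventionally set to 1/5 of the max absolute residual in the batch
    r = tf.abs(y_true - y_pred)
    c = 0.2 * tf.reduce_max(r)
    return tf.reduce_mean(tf.where(r <= c, r, (tf.square(r) + tf.square(c)) / (2.0 * c)))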
This paper looks interesting too: https://arxiv.org/pdf/1709.05054.pdf
@natlachaman Hi, have you finished your training, and how are the results?
@qimw Hi! I have indeed finished. It worked out quite well in the end. Here is a link to my Masters' thesis where you can find the results.
I am trying to set up SSD on my own dataset. The dataset is composed of images of particles with a size of 2048x1536, and the objects range in size from 25x25 to 60x60. It has only one class.
I tried to fine-tune the given model on my new objects (I generated XML annotations for my dataset in the same format as Pascal VOC). The training runs, and the loss stays between 100 and 1, but it never really converges, even after 500,000 steps. Then, when I evaluate the model, the mAP is 0. I tried it on one image from the training set with the notebook code and it really finds nothing (even with a threshold of 0.01).
I understand that during SSD preprocessing, images are resized to 300x300 or 512x512, depending on the model. In my case, this resizing shrinks the objects considerably: their size is reduced by a factor of 5-7 (2048/300 ≈ 6.8 and 1536/300 ≈ 5.1, so a 25x25 object ends up at roughly 4x5 pixels). So I thought that, because of the resizing, the objects were too small to be detected.
So I tried splitting each image from one big image into 35 smaller ones of size 300x300. And the results are still the same.
Has anyone tried to use SSD on small objects? I know the original paper mentions that small objects are harder to detect, but if anyone has an idea, they are welcome.