mvoelk / ssd_detectors

SSD-based object and text detection with Keras, SSD, DSOD, TextBoxes, SegLink, TextBoxes++, CRNN

Using SL_train.py to train with my own dataset (VOC format) #29

Open wander1985 opened 5 years ago

wander1985 commented 5 years ago

Hi, I am using SL_train.ipynb to train on my own VOC-format dataset on Windows 10. I used LabelImg to create the ground-truth annotations and data_voc.py to generate the pickle file. I've only used 5 images (3 for training, 1 for validation, 1 for testing) and set the batch size to 1. However, the training process keeps raising the following InvalidArgumentError after passing through the first image. Can you help? Thanks.

1/4 [======>.......................] - ETA: 1:29 - loss: 20.9251 - seg_conf_loss: 3.7705 - seg_loc_loss: 10.3475 - link_conf_loss: 6.8071 - num_pos_seg: 28.0000 - num_neg_seg: 84.0000 - pos_seg_conf_loss: 3.3756 - neg_seg_conf_loss: 3.9021 - pos_link_conf_loss: 2.0223 - neg_link_conf_loss: 8.4020 - seg_precision: 0.0000e+00 - seg_recall: 0.0000e+00 - seg_accuracy: 0.0000e+00 - seg_fmeasure: 0.0000e+00 - link_precision: 0.0000e+00 - link_recall: 0.0000e+00 - link_accuracy: 0.0000e+00 - link_fmeasure: 0.0000e+00
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-14-29316809ab04> in <module>()
     46         workers=1,
     47         #use_multiprocessing=False,
---> 48         initial_epoch=initial_epoch,
     49         #pickle_safe=False, # will use threading instead of multiprocessing, which is lighter on memory use but slower
     50         )

d:\Anaconda3\envs\text_detection\lib\site-packages\keras\legacy\interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name +
     90                               '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

d:\Anaconda3\envs\text_detection\lib\site-packages\keras\engine\training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   1413             use_multiprocessing=use_multiprocessing,
   1414             shuffle=shuffle,
-> 1415             initial_epoch=initial_epoch)
   1416 
   1417     @interfaces.legacy_generator_methods_support

d:\Anaconda3\envs\text_detection\lib\site-packages\keras\engine\training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
    211                 outs = model.train_on_batch(x, y,
    212                                             sample_weight=sample_weight,
--> 213                                             class_weight=class_weight)
    214 
    215                 outs = to_list(outs)

d:\Anaconda3\envs\text_detection\lib\site-packages\keras\engine\training.py in train_on_batch(self, x, y, sample_weight, class_weight)
   1213             ins = x + y + sample_weights
   1214         self._make_train_function()
-> 1215         outputs = self.train_function(ins)
   1216         return unpack_singleton(outputs)
   1217 

d:\Anaconda3\envs\text_detection\lib\site-packages\keras\backend\tensorflow_backend.py in __call__(self, inputs)
   2664                 return self._legacy_call(inputs)
   2665 
-> 2666             return self._call(inputs)
   2667         else:
   2668             if py_any(is_tensor(x) for x in inputs):

d:\Anaconda3\envs\text_detection\lib\site-packages\keras\backend\tensorflow_backend.py in _call(self, inputs)
   2634                                 symbol_vals,
   2635                                 session)
-> 2636         fetched = self._callable_fn(*array_vals)
   2637         return fetched[:len(self.outputs)]
   2638 

d:\Anaconda3\envs\text_detection\lib\site-packages\tensorflow\python\client\session.py in __call__(self, *args)
   1452         else:
   1453           return tf_session.TF_DeprecatedSessionRunCallable(
-> 1454               self._session._session, self._handle, args, status, None)
   1455 
   1456     def __del__(self):

d:\Anaconda3\envs\text_detection\lib\site-packages\tensorflow\python\framework\errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
    517             None, None,
    518             compat.as_text(c_api.TF_Message(self.status.status)),
--> 519             c_api.TF_GetCode(self.status.status))
    520     # Delete the underlying status object from memory otherwise it stays alive
    521     # as there is a reference to status from this from the traceback due to

InvalidArgumentError: Reshape cannot infer the missing input size for an empty tensor unless all specified input sizes are non-zero
     [[Node: training_1/Adam/gradients/loss_1/predictions_loss/TopKV2_grad/Reshape = Reshape[T=DT_INT32, Tshape=DT_INT32, _class=["loc:@train...rseToDense"], _device="/job:localhost/replica:0/task:0/device:CPU:0"](loss_1/predictions_loss/TopKV2:1, training_1/Adam/gradients/loss_1/predictions_loss/TopKV2_grad/stack)]]
mvoelk commented 5 years ago

What does your GTUtility class look like?

You may also want to read the other dataset-related issues here...

wander1985 commented 5 years ago

The following is the code of my GTUtility class. I tried to modify it based on your reply to another issue, https://github.com/mvoelk/ssd_detectors/issues/12#issuecomment-485686377.

import os
import numpy as np
from xml.etree import ElementTree

from ssd_data import BaseGTUtility  # base class from the ssd_detectors repo


class GTUtility(BaseGTUtility):
    def __init__(self, data_path, polygon=True):
        self.data_path = data_path
        self.image_path = os.path.join(data_path, 'JPEGImages')
        self.gt_path = gt_path = os.path.join(self.data_path, 'Annotations')
        self.classes = ['Background', 'Text']
        classes_lower = [s.lower() for s in self.classes]

        self.image_names = []
        self.data = []
        for filename in os.listdir(gt_path):
            tree = ElementTree.parse(os.path.join(gt_path, filename))
            root = tree.getroot()
            boxes = []
            size_tree = root.find('size')
            img_width = float(size_tree.find('width').text)
            img_height = float(size_tree.find('height').text)
            image_name = root.find('filename').text
            for object_tree in root.findall('object'):
                class_name = object_tree.find('name').text
                class_idx = classes_lower.index(class_name)  # 1 for 'text'
                for box in object_tree.iter('bndbox'):
                    # normalize pixel coordinates to [0, 1]
                    xmin = float(box.find('xmin').text) / img_width
                    ymin = float(box.find('ymin').text) / img_height
                    xmax = float(box.find('xmax').text) / img_width
                    ymax = float(box.find('ymax').text) / img_height
                    if polygon:
                        # axis-aligned box as a 4-point polygon plus class label
                        box = [xmin, ymin, xmin, ymax, xmax, ymax, xmax, ymin, 1]
                    else:
                        box = [xmin, ymin, xmax, ymax, 1]
                    boxes.append(box)
            boxes = np.asarray(boxes)
            self.image_names.append(image_name)
            self.data.append(boxes)

        self.init()

if __name__ == '__main__':
    gt_util = GTUtility('data/VOC2007')
    print(gt_util.classes)
    gt = gt_util.data
    print(gt)

    import pickle
    file_name = 'gt_util_voc2007.pkl'
    print('save to %s...' % file_name)
    pickle.dump(gt_util, open(file_name, 'wb'))
    print('done')
mvoelk commented 5 years ago
inputs, data = gt_util.sample_batch(1, 0)

What are the shapes you get?
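
A minimal way to run this check, assuming the pickle written by the script above (file name taken from it; the sample_batch arguments are exactly those quoted here):

import pickle

# load the GTUtility pickled by the script above
with open('gt_util_voc2007.pkl', 'rb') as f:
    gt_util = pickle.load(f)

inputs, data = gt_util.sample_batch(1, 0)
print(inputs.shape)  # image batch
print(data)          # list of ground-truth box arrays, one per image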

wander1985 commented 5 years ago

inputs.shape is (1, 512, 512, 3). data is [array([[0.2275 , 0.44125, 0.2275 , 0.55875, 0.77125, 0.55875, 0.77125, 0.44125, 1. ]])]. data.shape raises AttributeError: 'list' object has no attribute 'shape'.
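
Side note: data is a plain Python list with one box array per image, so per-image shapes have to be read from the entries:

# each entry is an array of shape (num_boxes, 9): 8 polygon coords + class label
print([d.shape for d in data])  # e.g. [(1, 9)] for the single box above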

mvoelk commented 5 years ago

Seems okay... Does it work with a larger batch size and more samples?

What is the shape of the model input? Should be (None, 512, 512, 3).

wander1985 commented 5 years ago

I changed the batch size to 2 with 100 samples, but still got a similar invalid argument error: InvalidArgumentError: Reshape cannot infer the missing input size for an empty tensor unless all specified input sizes are non-zero [[Node: training_4/Adam/gradients/loss_4/predictions_loss/TopKV2_grad/Reshape = Reshape[T=DT_INT32, Tshape=DT_INT32, _class=["loc:@train...rseToDense"], _device="/job:localhost/replica:0/task:0/device:CPU:0"](loss_4/predictions_loss/TopKV2:1, training_4/Adam/gradients/loss_4/predictions_loss/TopKV2_grad/stack)]]

How do I get the shape of the model input? I'm sorry, I am a newbie to this field.

mvoelk commented 5 years ago

model.input_shape

A piece of code would also be helpful.
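
For example, printed right after the model is built; the module path sl_model is an assumption based on the repo layout:

from sl_model import DSODSL512

# SegLink + DenseNet
model = DSODSL512()
print(model.input_shape)  # should be (None, 512, 512, 3)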

wander1985 commented 5 years ago

Thanks. I put print(model.input_shape) right below

# SegLink + DenseNet
model = DSODSL512()

The shape is (None, 512, 512, 3), as it should be.

mvoelk commented 5 years ago

It probably has something to do with the negative samples in the hard negative mining of the SegLinkLoss. You may not get any negative samples in the local ground truth at all. Did you try SegLinkFocalLoss?

By code, I meant: what does your SL_train.ipynb look like?
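
A sketch of the suggested loss swap in SL_train.ipynb; the module path sl_training and the compile pattern are assumptions based on the repo's training notebooks:

from sl_training import SegLinkFocalLoss

# the focal loss weighs all samples instead of relying on hard negative mining,
# so it does not break when a batch yields no negative segments
loss = SegLinkFocalLoss(lambda_segments=1.0, lambda_offsets=1.0, lambda_links=1.0)

# optimizer as already defined in the notebook
model.compile(loss=loss.compute, optimizer=optimizer, metrics=loss.metrics)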

wander1985 commented 5 years ago

I changed the SegLinkLoss to loss = SegLinkFocalLoss(lambda_segments=1.0, lambda_offsets=1.0, lambda_links=1.0) and it works like magic. Thanks so much.

But why didn't I get any negative samples in the local ground truth? Is it because I didn't label my ground truth correctly? The following is how I label my ground truth with LabelImg; am I doing something wrong? For example, given the image below, I draw a box around each piece of text ("Campus" and "Shop" in this example) and give them all the same class name "text", as shown in the screenshot.

[screenshot: LabelImg with the boxes around "Campus" and "Shop" labeled "text"]

But I am wondering where to label the exact letters in the text (i.e. "Campus" and "Shop"). Or do I not need to label the letters at all?

mvoelk commented 5 years ago

The sequence label is placed in the text attribute of the GTUtility and is only required for the recognition stage. For more details, see data_svt.py.
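
For illustration, a sketch of how per-box transcriptions could be collected for the text attribute, modeled on data_svt.py; the <text> element below is hypothetical, since LabelImg's VOC XML stores no transcriptions:

from xml.etree import ElementTree

def read_transcriptions(xml_path):
    # collect one transcription string per bounding box in an annotation file;
    # the <text> element is hypothetical and would have to be added by hand
    root = ElementTree.parse(xml_path).getroot()
    texts = []
    for object_tree in root.findall('object'):
        node = object_tree.find('text')
        texts.append(node.text if node is not None else '')
    return texts

# in GTUtility.__init__, per annotation file:
#     self.text.append(read_transcriptions(os.path.join(gt_path, filename)))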