DetectionTeamUCAS / FPN_Tensorflow

This is a tensorflow re-implementation of Feature Pyramid Networks for Object Detection.
https://github.com/DetectionTeamUCAS/FPN_Tensorflow
MIT License

totalLoss and rpnLocLoss become NaN from the second training step onward #121

Closed: khinggan closed this issue 4 years ago

khinggan commented 4 years ago

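I have a question for the author. I am training on the IAM handwriting dataset. After converting the original annotations, my ground-truth files are .txt files, for example: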

165.0,661.5,215.5,681.5
166.5,388.0,242.0,412.5
167.0,481.5,282.5,502.5
170.5,567.5,173.0,594.0

Each line is x1, y1, x2, y2. (x1, y1) is the top-left corner of the bounding box and (x2, y2) is the bottom-right corner.
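
Here is a minimal sketch of parsing one such line. It assumes each line holds exactly four comma-separated values. The helper name `parse_box_line` is hypothetical and is not part of FPN_Tensorflow. The helper skips blank lines, which would otherwise crash `float('')`. It also reorders corners that were written in the wrong order so that xmin <= xmax and ymin <= ymax. If you would rather catch such annotation mistakes, raise an error at that point instead.

```python
def parse_box_line(line):
    """Hypothetical helper: parse one 'x1,y1,x2,y2' line into integer corners.

    Returns None for blank lines (e.g. a trailing newline at end of file).
    """
    line = line.strip()
    if not line:
        return None
    values = [float(v) for v in line.split(',')]
    if len(values) != 4:
        raise ValueError('expected 4 values, got {}: {!r}'.format(len(values), line))
    x1, y1, x2, y2 = values
    # Normalize corner order so xmin <= xmax and ymin <= ymax.
    xmin, xmax = sorted((x1, x2))
    ymin, ymax = sorted((y1, y2))
    # int() truncates toward zero, like int(float(coord)) in the script below.
    return [int(xmin), int(ymin), int(xmax), int(ymax)]


sample = """165.0,661.5,215.5,681.5
166.5,388.0,242.0,412.5
167.0,481.5,282.5,502.5
170.5,567.5,173.0,594.0
"""
boxes = [b for b in map(parse_box_line, sample.splitlines()) if b is not None]
print(boxes)
```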

I modified the XML parsing in data/io/convert_data_to_tfrecord.py to parse the .txt files instead, and generated a new .tfrecord file. Contents of convert_iam_data_to_tfrecord.py:

import glob
import os

import cv2
import numpy as np
import tensorflow as tf

# FLAGS, NAME_LABEL_MAP, _bytes_feature, _int64_feature, mkdir and view_bar are
# defined/imported as in the original data/io/convert_data_to_tfrecord.py.


def read_txt_gtbox_and_label(label_path):
    box_list = []
    with open(label_path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:  # skip blank lines, e.g. a trailing newline
                continue
            coords = line.split(',')  # x1,y1,x2,y2
            tmp_box = [int(float(coord)) for coord in coords]
            tmp_box.append(NAME_LABEL_MAP['text'])
            box_list.append(tmp_box)

    # Each row is already [xmin, ymin, xmax, ymax, label];
    # reshape keeps a (0, 5) shape when a file has no boxes.
    gtbox_label = np.array(box_list, dtype=np.int32).reshape(-1, 5)

    return gtbox_label


def convert_iam_to_tfrecord():
    label_path = FLAGS.iam_dir + FLAGS.label_dir
    image_path = FLAGS.iam_dir + FLAGS.image_dir
    save_path = FLAGS.save_dir + FLAGS.dataset + '_' + FLAGS.save_name + '.tfrecord'
    mkdir(FLAGS.save_dir)

    # writer_options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB)
    # writer = tf.python_io.TFRecordWriter(path=save_path, options=writer_options)
    writer = tf.python_io.TFRecordWriter(path=save_path)
    txt_files = glob.glob(label_path + '/*.txt')  # list the files once
    for count, txt in enumerate(txt_files):
        # to avoid path errors on different platforms
        txt = txt.replace('\\', '/')

        img_name = txt.split('/')[-1].split('.')[0] + FLAGS.img_format
        img_path = image_path + '/' + img_name

        if not os.path.exists(img_path):
            print('{} does not exist!'.format(img_path))
            continue

        gtbox_label = read_txt_gtbox_and_label(txt)

        img = cv2.imread(img_path)
        if img is None:
            print('{} could not be read!'.format(img_path))
            continue
        img = img[:, :, ::-1]  # OpenCV loads BGR; convert to RGB
        # Take the size from the image itself instead of hard-coding 1233 x 1762,
        # so img_height * img_width * 3 always matches the stored pixel bytes.
        img_height, img_width = img.shape[:2]

        feature = tf.train.Features(feature={
            # Python 3 strings need encode(); Python 2 str is already bytes
            'img_name': _bytes_feature(img_name.encode('utf-8')),
            'img_height': _int64_feature(img_height),
            'img_width': _int64_feature(img_width),
            'img': _bytes_feature(img.tobytes()),  # same bytes as the deprecated tostring()
            'gtboxes_and_label': _bytes_feature(gtbox_label.tobytes()),
            'num_objects': _int64_feature(gtbox_label.shape[0])
        })

        example = tf.train.Example(features=feature)
        writer.write(example.SerializeToString())

        view_bar('Conversion progress', count + 1, len(txt_files))

    writer.close()  # flush buffered records and close the file
    print('\nConversion is complete!')


if __name__ == '__main__':
    convert_iam_to_tfrecord()
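
Before writing a record, it may help to reject boxes that would give a zero or negative width or height, or that extend beyond the actual image. Below is a minimal sketch. It assumes rows in the [xmin, ymin, xmax, ymax, label] layout that the script writes. `find_bad_boxes` and its `min_size` threshold are hypothetical. The label value 1 in the example is only a placeholder for NAME_LABEL_MAP['text']. In the script, it could be called with `gtbox_label.tolist()` and the image's own `img.shape[:2]`.

```python
def find_bad_boxes(boxes, img_height, img_width, min_size=1):
    """Hypothetical check: return (index, box, reason) for each box that is
    degenerate or lies outside the image. Rows are [xmin, ymin, xmax, ymax, label]."""
    problems = []
    for i, (xmin, ymin, xmax, ymax, _label) in enumerate(boxes):
        box = (xmin, ymin, xmax, ymax)
        if xmax - xmin < min_size or ymax - ymin < min_size:
            # zero or negative width/height: log(w / w_a) becomes -inf or nan
            problems.append((i, box, 'degenerate'))
        elif xmin < 0 or ymin < 0 or xmax > img_width or ymax > img_height:
            problems.append((i, box, 'outside image'))
    return problems


boxes = [
    [165, 661, 215, 681, 1],   # fine
    [215, 681, 165, 661, 1],   # corners swapped: negative width and height
    [170, 567, 173, 1800, 1],  # ymax beyond a 1762-pixel-tall image
]
for problem in find_bad_boxes(boxes, img_height=1762, img_width=1233):
    print(problem)
```

A box flagged this way can be fixed in the annotation file or skipped rather than written to the .tfrecord.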

However, from the second logged step (step20) onward, total_loss and rpn_loc_loss are NaN. Here is the log:

2019-12-10 21:44:04: step10 image_name:b'a01-011u.png' | rpn_loc_loss:0.8125103712081909 | rpn_cla_loss:0.6350864768028259 | rpn_total_loss:1.447596788406372 | fast_rcnn_loc_loss:0.5327670574188232 | fast_rcnn_cla_loss:0.5486091375350952 | fast_rcnn_total_loss:1.0813761949539185 | total_loss:3.1570582389831543 | per_cost_time:2.8225905895233154s
/home/khinggan/work/FPN_Tensorflow/libs/box_utils/encode_and_decode.py:156: RuntimeWarning: invalid value encountered in log
  t_w = np.log(w/reference_w)
/home/khinggan/work/FPN_Tensorflow/libs/box_utils/encode_and_decode.py:157: RuntimeWarning: invalid value encountered in log
  t_h = np.log(h/reference_h)
2019-12-10 21:44:12: step20 image_name:b'a01-030.png' | rpn_loc_loss:nan | rpn_cla_loss:0.7283592224121094 | rpn_total_loss:nan | fast_rcnn_loc_loss:0.727269172668457 | fast_rcnn_cla_loss:0.5448232889175415 | fast_rcnn_total_loss:1.2720924615859985 | total_loss:nan | per_cost_time:0.5033695697784424s
2019-12-10 21:44:17: step30 image_name:b'a01-049x.png' | rpn_loc_loss:nan | rpn_cla_loss:0.6896220445632935 | rpn_total_loss:nan | fast_rcnn_loc_loss:0.0 | fast_rcnn_cla_loss:1.060610294342041 | fast_rcnn_total_loss:1.060610294342041 | total_loss:nan | per_cost_time:0.5133085250854492s
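
The RuntimeWarning lines point at the width/height regression targets, t_w = log(w / w_a) and t_h = log(h / h_a), where w_a and h_a are the reference (anchor) sizes. NumPy returns nan with "invalid value encountered in log" when the argument is negative. A zero argument gives -inf with a "divide by zero" warning instead. One plausible hypothesis, then, is a ground-truth box with xmax < xmin or ymax < ymin. The sketch below reproduces only that log step. `encode_wh` is a hypothetical helper, not the repo's encode_and_decode.py, and the anchor and box values are illustrative.

```python
import numpy as np


def encode_wh(gt_boxes, anchors):
    """Hypothetical helper: t_w = log(w / w_a), t_h = log(h / h_a) for
    [xmin, ymin, xmax, ymax] boxes. Only the log step from the warning is shown."""
    gt = np.asarray(gt_boxes, dtype=np.float64)
    an = np.asarray(anchors, dtype=np.float64)
    w = gt[:, 2] - gt[:, 0]
    h = gt[:, 3] - gt[:, 1]
    reference_w = an[:, 2] - an[:, 0]
    reference_h = an[:, 3] - an[:, 1]
    t_w = np.log(w / reference_w)  # nan wherever w / reference_w < 0
    t_h = np.log(h / reference_h)
    return t_w, t_h


# Illustrative anchors and ground-truth boxes.
anchors = [[160, 650, 224, 690],
           [160, 650, 224, 690]]
gt = [[165, 661, 215, 681],   # valid box
      [215, 661, 165, 681]]   # xmin > xmax, so the width is negative

with np.errstate(invalid='ignore'):  # silence the RuntimeWarning for the demo
    t_w, t_h = encode_wh(gt, anchors)

print(np.isnan(t_w).tolist())  # [False, True]
print(np.isnan(t_h).tolist())  # [False, False]
```

A single nan target like this is enough to turn the summed localization loss, and therefore total_loss, into nan.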

I read the fix suggested in issue #75. Is my conversion script wrong? Debugging shows that the output format is [xmin, ymin, xmax, ymax, label]. Do you have any suggestions?
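
One way to double-check the format end to end is to round-trip the bytes. This minimal sketch assumes, without verification here, that the repo's TFRecord reader decodes gtboxes_and_label as raw int32 values and reshapes them to [num_objects, 5]. `tf.decode_raw` reinterprets raw bytes much as `np.frombuffer` does. If the dtype used for writing differs from the one used for reading, the boxes are silently corrupted:

```python
import numpy as np

# Two boxes in the [xmin, ymin, xmax, ymax, label] layout; label 1 is a
# placeholder for NAME_LABEL_MAP['text'].
gtbox_label = np.array([[165, 661, 215, 681, 1],
                        [166, 388, 242, 412, 1]], dtype=np.int32)
num_objects = gtbox_label.shape[0]

raw = gtbox_label.tobytes()  # the bytes stored under 'gtboxes_and_label'

# Reading back with the same dtype and shape recovers the boxes exactly.
decoded = np.frombuffer(raw, dtype=np.int32).reshape(num_objects, 5)
print(decoded.tolist())  # [[165, 661, 215, 681, 1], [166, 388, 242, 412, 1]]

# Reading with a different dtype reinterprets the same 40 bytes:
# ten int32 values become five int64 values, which cannot be reshaped
# to (num_objects, 5) and no longer mean anything as boxes.
wrong = np.frombuffer(raw, dtype=np.int64)
print(wrong.size)  # 5
```

The script builds gtbox_label with dtype=np.int32, so the write side is consistent as long as the reader also decodes int32.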