MaybeShewill-CV / CRNN_Tensorflow

Convolutional Recurrent Neural Networks(CRNN) for Scene Text Recognition
MIT License
1.03k stars 388 forks source link

请问下有没有改成批量预测的思路 #316

Closed LXLun closed 5 years ago

MaybeShewill-CV commented 5 years ago

@LXLun 你自己放一个batch的数据进去就好了:)

LXLun commented 5 years ago

每次都必须在sess启动前加载固定的input数据,就意味着不能灵活的只是实例化一个sess然后input是变化的?

LXLun commented 5 years ago

有没有什么方案

LXLun commented 5 years ago

我正在把它做成一个服务

MaybeShewill-CV commented 5 years ago

@LXLun 不明白你说的变化是指什么 是图像的数量还是图像的尺寸还是什么==!

LXLun commented 5 years ago

都会变化,比如图片的高度都resize成32,然后宽度会resize成当前batch内宽度最大的图片的尺寸,比宽度最大图片小的用补空白的方法,然后每次batch都不一样,就是说代码里的inputdata每次不一样,然后我只想在服务启动的时候实例化模型,但是启动的时候要加载进inputdata,后面如果重新改变inputdata放入sess会出现

Traceback (most recent call last): File "/home/luoxilun/lxl_new/train_venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1092, in _run subfeed, allow_tensor=True, allow_operation=False) File "/home/luoxilun/lxl_new/train_venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3490, in as_graph_element return self._as_graph_element_locked(obj, allow_tensor, allow_operation) File "/home/luoxilun/lxl_new/train_venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3569, in _as_graph_element_locked raise ValueError("Tensor %s is not an element of this graph." % obj) ValueError: Tensor Tensor("input:0", shape=(1, 32, 216, 3), dtype=float32) is not an element of this graph.

During handling of the above exception, another exception occurred:

Traceback (most recent call last): File "inference_crnn.py", line 282, in rets = instance.predict_batch(image_list) File "inference_crnn.py", line 219, in predict_batch texts = self.predict_batch_from_left_to_right(batch_images) File "inference_crnn.py", line 199, in predict_batch_from_left_to_right preds = self.session_hor.run(decodes, feed_dict={inputdata: [batch_array[0]]}) File "/home/luoxilun/lxl_new/train_venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 929, in run run_metadata_ptr) File "/home/luoxilun/lxl_new/train_venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1095, in _run 'Cannot interpret feed_dict key as Tensor: ' + e.args[0]) TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("input:0", shape=(1, 32, 216, 3), dtype=float32) is not an element of this graph.

zonasw commented 5 years ago

你可以把inputdata定义成这样: inputdata = tf.placeholder(dtype=tf.float32, shape=[1, None, None, CFG.ARCH.INPUT_CHANNELS], name='input') 另外,注释掉crnn_net.py中第154行的 shape[1] == 1

MaybeShewill-CV commented 5 years ago

@LXLun 那你用placeholder占位符吧:)

LXLun commented 5 years ago

你可以把inputdata定义成这样: inputdata = tf.placeholder(dtype=tf.float32, shape=[1, None, None, CFG.ARCH.INPUT_CHANNELS], name='input') 另外,注释掉crnn_net.py中第154行的 shape[1] == 1

我inputdata是这样的

class OCREngine(object): def init(self, weights_path):

config tf session

    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TEST.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TEST.TF_ALLOW_GROWTH

    self.graph_hor = tf.Graph()
    self.session_hor = tf.Session(graph=self.graph_hor, config=sess_config)
    with self.session_hor.as_default():
        with self.session_hor.graph.as_default():
            # definite the compute graph
            self.inputdata = tf.placeholder(
                dtype=tf.float32,
                shape=[None, 32, None, CFG.ARCH.INPUT_CHANNELS],
                name='input'
            )

            self.codec = tf_io_pipline_fast_tools.CrnnFeatureReader(
                char_dict_path=char_dict_path,
                ord_map_dict_path=ord_map_dict_path
            )

            self.net = crnn_net.ShadowNet(
                phase='test',
                hidden_nums=CFG.ARCH.HIDDEN_UNITS,
                layers_nums=CFG.ARCH.HIDDEN_LAYERS,
                num_classes=CFG.ARCH.NUM_CLASSES
            )

            self.inference_ret = self.net.inference(
                inputdata=self.inputdata,
                name='shadow_net',
                reuse=False
            )

            saver_hor = tf.train.Saver()
            saver_hor.restore(sess=self.session_hor, save_path=weights_path)

def test_predict(self,images):
    # pdf_image_text_areas = []
    # new_heigth = 32
    # print(images[0].shape[1])

    pdf_image_text_area = np.array(images[0], np.float32) / 127.5 - 1.0

    inputdata = tf.placeholder(
        dtype=tf.float32,
        shape=[len(images), 32, images[0].shape[1], CFG.ARCH.INPUT_CHANNELS],
        name='input'
    )

    net = crnn_net.ShadowNet(
        phase='test',
        hidden_nums=CFG.ARCH.HIDDEN_UNITS,
        layers_nums=CFG.ARCH.HIDDEN_LAYERS,
        num_classes=CFG.ARCH.NUM_CLASSES
    )

    inference_ret = net.inference(
        inputdata=inputdata,
        name='shadow_net',
        reuse=False
    )

    decodes, _ = tf.nn.ctc_beam_search_decoder(
        inputs=inference_ret,
        sequence_length=int(images[0].shape[1] / 4) * np.ones(1) * np.ones(1),
        merge_repeated=False,
        beam_width=1
    )

    preds = self.session_hor.run(decodes, feed_dict={inputdata: [pdf_image_text_area]})
    print(preds)

然后随便用一张图片测试test_predict方法会出现

Traceback (most recent call last): File "/home/luoxilun/lxl_new/train_venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1092, in _run subfeed, allow_tensor=True, allow_operation=False) File "/home/luoxilun/lxl_new/train_venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3490, in as_graph_element return self._as_graph_element_locked(obj, allow_tensor, allow_operation) File "/home/luoxilun/lxl_new/train_venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3569, in _as_graph_element_locked raise ValueError("Tensor %s is not an element of this graph." % obj) ValueError: Tensor Tensor("input:0", shape=(1, 32, 216, 3), dtype=float32) is not an element of this graph.

During handling of the above exception, another exception occurred:

Traceback (most recent call last): File "inference_crnn.py", line 281, in rets = instance.test_predict(image_list) File "inference_crnn.py", line 262, in test_predict preds = self.session_hor.run(decodes, feed_dict={inputdata: [pdf_image_text_area]}) File "/home/luoxilun/lxl_new/train_venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 929, in run run_metadata_ptr) File "/home/luoxilun/lxl_new/train_venv/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1095, in _run 'Cannot interpret feed_dict key as Tensor: ' + e.args[0]) TypeError: Cannot interpret feed_dict key as Tensor: Tensor Tensor("input:0", shape=(1, 32, 216, 3), dtype=float32) is not an element of this graph.

MaybeShewill-CV commented 5 years ago

@LXLun 我建议你先熟悉一下tensorflow. 你这inputdata op都不在你定义的计算图中肯定会出错了:)

LXLun commented 5 years ago

可能是我改糊涂了,看了下,改好了,被前面的错误绕进去了

LXLun commented 5 years ago

我后面可以提供个批量识别的代码,需要合并进你的分支里么

LXLun commented 5 years ago

@LXLun 我建议你先熟悉一下tensorflow. 你这inputdata op都不在你定义的计算图中肯定会出错了:)

谢谢点醒了我,哈哈哈

MaybeShewill-CV commented 5 years ago

@LXLun 如果感兴趣可以pull request. 这个没有问题的话 我关闭了:)

LXLun commented 5 years ago

还是有一个问题,我放进preds = self.session_hor.run(decodes, feed_dict={self.inputdata: [images]})的数据实际上是(2, 32, 216, 3),但是一旦执行了preds = self.session_hor.run(decodes, feed_dict={self.inputdata: [images]})就会报错ValueError: Cannot feed value of shape (1, 2, 32, 216, 3) for Tensor 'input:0', which has shape '(?, 32, ?, 3)',这可能是什么原因?

LXLun commented 5 years ago

为何前面会加上个1的维度

MaybeShewill-CV commented 5 years ago

@LXLun 检查你的代码 那里增加过维度吧:)

LXLun commented 5 years ago
class OCREngine(object):
    def __init__(self, weights_path):
        # config tf session
        sess_config = tf.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TEST.GPU_MEMORY_FRACTION
        sess_config.gpu_options.allow_growth = CFG.TEST.TF_ALLOW_GROWTH

        self.graph_hor = tf.Graph()
        self.session_hor = tf.Session(graph=self.graph_hor, config=sess_config)
        with self.session_hor.as_default():
            with self.session_hor.graph.as_default():
                # definite the compute graph
                self.inputdata = tf.placeholder(
                    dtype=tf.float32,
                    shape=[None, 32, None, CFG.ARCH.INPUT_CHANNELS],
                    name='input'
                )

                self.codec = tf_io_pipline_fast_tools.CrnnFeatureReader(
                    char_dict_path=char_dict_path,
                    ord_map_dict_path=ord_map_dict_path
                )

                self.net = crnn_net.ShadowNet(
                    phase='test',
                    hidden_nums=CFG.ARCH.HIDDEN_UNITS,
                    layers_nums=CFG.ARCH.HIDDEN_LAYERS,
                    num_classes=CFG.ARCH.NUM_CLASSES
                )

                self.inference_ret = self.net.inference(
                    inputdata=self.inputdata,
                    name='shadow_net',
                    reuse=False
                )

                saver_hor = tf.train.Saver()
                saver_hor.restore(sess=self.session_hor, save_path=weights_path)

        def test_predict(self,images):
        # pdf_image_text_areas = []
        # new_heigth = 32
        # print(images[0].shape[1])

        # pdf_image_text_area = np.array(images[7], np.float32) / 127.5 - 1.0

        decodes, _ = tf.nn.ctc_beam_search_decoder(
            inputs=self.inference_ret,
            sequence_length=int(images[0].shape[1] / 4) * np.ones(2),
            merge_repeated=False,
            beam_width=1
        )

        preds = self.session_hor.run(decodes, feed_dict={self.inputdata: [images]})
        # if number_remained > 0 and step == number_batches:
        #     preds = preds[:number_remained]
        #     chars_count = chars_count[:number_remained]
        for i in range(len(preds)):
            # count = int(chars_count[i])
            pred = self.codec.sparse_tensor_to_str(preds[i])
            print(pred)

if __name__ == '__main__':
    instance = OCREngine('model/crnn_gag/shadownet_2019-07-10-10-42-12.ckpt-662000')
    img_dirs = 'data/test_data_2'
    image_path_list = os.listdir(img_dirs)
    image_list = []
    for img_p in image_path_list:
        img_path = os.path.join(img_dirs, img_p)
        image = cv2.imread(img_path, cv2.IMREAD_COLOR)
        image = np.array(image, np.float32) / 127.5 - 1.0
        image_list.append(image)
        image_list.append(image)
        break

    rets = instance.test_predict(image_list)

这里随便选了张图片测试

MaybeShewill-CV commented 5 years ago

@LXLun 你自己debug一下不是很容易就知道了么 preds = self.session_hor.run(decodes, feed_dict={self.inputdata: [images]}) self.inputdate: images 不是[images] ==!