solivr / tf-crnn

TensorFlow convolutional recurrent neural network (CRNN) for text recognition
GNU General Public License v3.0
292 stars 98 forks source link

using TFRecords #14

Closed WenmuZhou closed 7 years ago

WenmuZhou commented 7 years ago

I'm trying to use TFRecord as the data input for the program, but I have encountered some problems。 the code used to convert img to TFRecord. I have test this code and work well

# -*- coding: utf-8 -*-
# @Time    : 2017/11/27 16:00
# @Author  : zhoujun
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
from PIL import Image
from tqdm import tqdm

def getFileName(path):
    return path.split('/')[-1]

def readLines(file_path):
    with open(file_path, 'r') as T:
        lines = T.readlines()
    return lines

def split_lines(src):
    lines = src
    label_record = {}
    for line in lines:
        name = line.split(' ')[0]
        label = line.split(' ')[1]
        label = label.split('\n')[0]
        label_record[name] = label
    return label_record

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def recordsCreater(label_file, dst_records):
    writer = tf.python_io.TFRecordWriter(dst_records)

    lines = readLines(label_file)
    label_record = split_lines(lines)
    index = 0

    pbar = tqdm(total=len(lines))
    for file_path, label in label_record.items():
        index = index + 1
        img = Image.open(file_path)
        image_raw = img.tobytes()

        cols = img.size[0]
        rows = img.size[1]
        depth = 3 if img.mode is 'RGB' else 1

        example = tf.train.Example(features=tf.train.Features(feature={
            'height': _int64_feature(rows),
            'width': _int64_feature(cols),
            'depth': _int64_feature(depth),
            'label': _bytes_feature(bytes(label, encoding = "utf8")  ),
            'image_raw': _bytes_feature(image_raw)}))
        writer.write(example.SerializeToString())
        writer.flush()
        pbar.update(1)
    print("done!")
    writer.close()
    pbar.close()

# 读取二进制数据
def recordsReader(filename):
    # 创建文件队列,不限读取的数量
    filename_queue = tf.train.string_input_producer([filename])
    # create a reader from file queue
    reader = tf.TFRecordReader()
    # reader从文件队列中读入一个序列化的样本
    _, serialized_example = reader.read(filename_queue)
    # get feature from serialized example
    # 解析符号化的样本
    features = tf.parse_single_example(
        serialized_example,
        features={
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'depth': tf.FixedLenFeature([], tf.int64),
            'label': tf.FixedLenFeature([], tf.string),
            'image_raw': tf.FixedLenFeature([], tf.string)
        }
    )
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)
    depth = tf.cast(features['depth'], tf.int32)
    image = tf.reshape(image, [height,width,depth])
    label = tf.cast(features['label'], tf.string)
    return image, label

def test_reader(recordsFile):
    image, label = recordsReader(recordsFile)
    with tf.Session() as sess:  # 开始一个会话
        init_op = tf.initialize_all_variables()
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for i in range(1):
            image, label = sess.run([image, label])  # 在会话中取出image和label
            print(label)
            # img = Image.fromarray(example, 'RGB')  # 如果img是RGB图像
            # img = Image.fromarray(example)
            #
            # img.save('./' + '_'+'Label_' + str(l) + '.jpg')  # 存下图片
            Image._show(Image.fromarray(image))
        coord.request_stop()
        coord.join(threads)

if __name__ == '__main__':
    test_label_file, test_dst_records = "E:\\val1.csv", "E:\\val1.tfrecords"
    # train_label_file, train_dst_records = "../MNIST_data/mnist_train/train.txt", "../MNIST_data/mnist_train.tfrecords"
    # recordsCreater(test_label_file, test_dst_records)
    # recordsCreater(test_label_file, test_dst_records)
    test_reader(test_dst_records)

the code used as imput_fn

# -*- coding: utf-8 -*-
# @Time    : 2017/11/27 19:12
# @Author  : zhoujun
import tensorflow as tf
from src.data_handler import padding_inputs_width, augment_data
from src.config import CONST,Params,Alphabet

def input_fn(filename, is_training, params, batch_size=1, num_epochs=1):
    """A simple input_fn using the tf.data input pipeline."""

    def example_parser(serialized_example):
        """Parses a single tf.Example into image and label tensors."""
        features = tf.parse_single_example(
            serialized_example,
            features={
                'height': tf.FixedLenFeature([], tf.int64),
                'width': tf.FixedLenFeature([], tf.int64),
                'depth': tf.FixedLenFeature([], tf.int64),
                'label': tf.FixedLenFeature([], tf.string),
                'image_raw': tf.FixedLenFeature([], tf.string)
            }
        )
        image = tf.decode_raw(features['image_raw'], tf.uint8)
        height = tf.cast(features['height'], tf.int32)
        width = tf.cast(features['width'], tf.int32)
        depth = tf.cast(features['depth'], tf.int32)
        image = tf.reshape(image, [height, width, depth])
        label = tf.cast(features['label'], tf.string)

        # Normalize the values of the image from the range [0, 255] to [-0.5, 0.5]
        # image = tf.cast(image, tf.float32) / 255 - 0.5
        # Data augmentation
        if is_training:
            image = augment_data(image)

        image, width = padding_inputs_width(image, params.input_shape, increment=CONST.DIMENSION_REDUCTION_W_POOLING)

        return {'images': image, 'images_widths': width, 'labels': label}, label

    dataset = tf.data.TFRecordDataset([filename])

    # Apply dataset transformations
    if is_training:
        # When choosing shuffle buffer sizes, larger sizes result in better
        # randomness, while smaller sizes have better performance. Because MNIST is
        # a small dataset, we can easily shuffle the full epoch.
        dataset = dataset.shuffle(buffer_size=1000)

    dataset = dataset.repeat(num_epochs)

    # Map example_parser over dataset, and batch results by up to batch_size
    dataset = dataset.map(example_parser).prefetch(batch_size)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    image, label = iterator.get_next()

    return image, label

if __name__ == '__main__':
    parameters = Params(eval_batch_size=128,
                        input_shape=(32, 304),
                        digits_only=False,
                        alphabet=Alphabet.CHINESECHAR_LETTERS_DIGITS_EXTENDED,
                        alphabet_decoding='same',
                        )

    next_batch = input_fn(filename='E:\\val1.tfrecords.', is_training=False,params=parameters,batch_size=2)

    # Now let's try it out, retrieving and printing one batch of data.
    # Although this code looks strange, you don't need to understand
    # the details.
    with tf.Session() as sess:
        first_batch = sess.run(next_batch)
    print(first_batch)

when test this code the output is

({'images': array([[[[ 252.        ,  252.        ,  252.        ],
         [  72.777771  ,   72.777771  ,   72.777771  ],
         [ 247.74073792,  247.74073792,  247.74073792],
         ..., 
         [ 250.14813232,  250.14813232,  250.14813232],
         [ 254.29629517,  254.29629517,  254.29629517],
         [ 252.77780151,  252.77780151,  252.77780151]],

        [[ 251.25      ,  251.25      ,  251.25      ],
         [ 237.222229  ,  237.222229  ,  237.222229  ],
         [ 252.40740967,  252.40740967,  252.40740967],
         ..., 
         [ 254.66670227,  254.66670227,  254.66670227],
         [ 248.66665649,  248.66665649,  248.66665649],
         [ 253.75      ,  253.75      ,  253.75      ]],

        [[ 246.5       ,  246.5       ,  246.5       ],
         [ 253.05555725,  253.05555725,  253.05555725],
         [ 248.53703308,  248.53703308,  248.53703308],
         ..., 
         [ 253.07406616,  253.07406616,  253.07406616],
         [ 255.        ,  255.        ,  255.        ],
         [ 254.722229  ,  254.722229  ,  254.722229  ]],

        ..., 
        [[ 251.75      ,  251.75      ,  251.75      ],
         [ 251.        ,  251.        ,  251.        ],
         [ 250.66667175,  250.66667175,  250.66667175],
         ..., 
         [ 254.61114502,  254.61114502,  254.61114502],
         [ 245.16665649,  245.16665649,  245.16665649],
         [ 247.58332825,  247.58332825,  247.58332825]],

        [[ 251.5       ,  251.5       ,  251.5       ],
         [ 249.85185242,  249.85185242,  249.85185242],
         [ 248.777771  ,  248.777771  ,  248.777771  ],
         ..., 
         [ 254.79632568,  254.79632568,  254.79632568],
         [ 250.53703308,  250.53703308,  250.53703308],
         [ 251.88890076,  251.88890076,  251.88890076]],

        [[ 255.        ,  255.        ,  255.        ],
         [ 246.87037659,  246.87037659,  246.87037659],
         [ 251.8888855 ,  251.8888855 ,  251.8888855 ],
         ..., 
         [ 249.12036133,  249.12036133,  249.12036133],
         [ 251.43518066,  251.43518066,  251.43518066],
         [ 251.44442749,  251.44442749,  251.44442749]]],

       [[[ 255.        ,  255.        ,  255.        ],
         [ 255.        ,  255.        ,  255.        ],
         [ 255.        ,  255.        ,  255.        ],
         ..., 
         [ 255.        ,  255.        ,  255.        ],
         [ 255.        ,  255.        ,  255.        ],
         [ 255.        ,  255.        ,  255.        ]],

        [[ 255.        ,  255.        ,  255.        ],
         [ 255.        ,  255.        ,  255.        ],
         [ 255.        ,  255.        ,  255.        ],
         ..., 
         [ 255.        ,  255.        ,  255.        ],
         [ 255.        ,  255.        ,  255.        ],
         [ 255.        ,  255.        ,  255.        ]],

        [[ 255.        ,  255.        ,  255.        ],
         [ 255.        ,  255.        ,  255.        ],
         [ 255.        ,  255.        ,  255.        ],
         ..., 
         [ 255.        ,  255.        ,  255.        ],
         [ 255.        ,  255.        ,  255.        ],
         [ 255.        ,  255.        ,  255.        ]],

        ..., 
        [[ 250.        ,  250.        ,  250.        ],
         [ 254.88095093,  254.88095093,  254.88095093],
         [ 254.73809814,  254.73809814,  254.73809814],
         ..., 
         [ 255.        ,  255.        ,  255.        ],
         [ 253.86904907,  253.86904907,  253.86904907],
         [ 250.50003052,  250.50003052,  250.50003052]],

        [[ 254.        ,  254.        ,  254.        ],
         [ 254.16665649,  254.16665649,  254.16665649],
         [ 253.16665649,  253.16665649,  253.16665649],
         ..., 
         [ 252.        ,  252.        ,  252.        ],
         [ 250.23809814,  250.23809814,  250.23809814],
         [ 255.        ,  255.        ,  255.        ]],

        [[ 254.25      ,  254.25      ,  254.25      ],
         [ 253.45237732,  253.45237732,  253.45237732],
         [ 251.5952301 ,  251.5952301 ,  251.5952301 ],
         ..., 
         [ 253.5       ,  253.5       ,  253.5       ],
         [ 253.75      ,  253.75      ,  253.75      ],
         [ 249.42855835,  249.42855835,  249.42855835]]]], dtype=float32), 'images_widths': array([108,  84]), 'labels': array([b'2760$854$1429$1224', b'1105$1232$1560'], dtype=object)}, array([b'2760$854$1429$1224', b'1105$1232$1560'], dtype=object))

however when I use the input_fn as the input of estimator.evaluate, I meet some error and I don't know where is error.

the code is

#!/usr/bin/env python
__author__ = 'zj'

import argparse
import os
import sys
import numpy as np

try:
    import better_exceptions
except ImportError:
    pass
import tensorflow as tf
from src.model2 import crnn_fn
from src.data_handler import data_loader
from src.config import Params, Alphabet
from src.input_utils import input_fn

def main(unused_argv):
    models_path = FLAGS.input_model_dir
    if not os.path.exists(models_path):
        assert FileNotFoundError

    models_list = [os.path.join(models_path, x[:-5]) for x in os.listdir(models_path) if x.endswith('.meta')]

    # 输出路径不存在就创建
    if not os.path.exists(FLAGS.output_model_dir):
        os.makedirs(FLAGS.output_model_dir)

    parameters = Params(eval_batch_size=128,
                        input_shape=(32, 304),
                        digits_only=False,
                        alphabet=Alphabet.CHINESECHAR_LETTERS_DIGITS_EXTENDED,
                        alphabet_decoding='same',
                        csv_delimiter=' ',
                        csv_files_eval=FLAGS.csv_files_eval,
                        output_model_dir=FLAGS.output_model_dir,
                        gpu=FLAGS.gpu
                        )

    model_params = {
        'Params': parameters,
    }

    os.environ['CUDA_VISIBLE_DEVICES'] = parameters.gpu
    config_sess = tf.ConfigProto()
    config_sess.gpu_options.per_process_gpu_memory_fraction = 0.6

    # Config estimator
    est_config = tf.estimator.RunConfig()
    est_config = est_config.replace(session_config=config_sess,
                                    save_summary_steps=100,
                                    model_dir=parameters.output_model_dir)

    estimator = tf.estimator.Estimator(model_fn=crnn_fn,
                                       params=model_params,
                                       config=est_config,
                                       model_dir=parameters.output_model_dir,
                                       )
    # Count number of image filenames in csv
    n_samples = 0
    for file in parameters.csv_files_eval:
        with open(file, mode='r', encoding='utf8') as csvfile:
            n_samples += len(csvfile.readlines())
    print(n_samples, np.floor(n_samples / parameters.eval_batch_size), parameters.eval_batch_size)
    try:
        with open(FLAGS.output_file, encoding='utf-8', mode='w') as save_file:
            for model in models_list:
                # eval_results = estimator.evaluate(input_fn=data_loader(csv_filename=parameters.csv_files_eval,
                #                                                        params=parameters,
                #                                                        batch_size=parameters.eval_batch_size,
                #                                                        num_epochs=1),
                #                                   steps=np.floor(n_samples / parameters.eval_batch_size),
                #                                   checkpoint_path=model)
                eval_results = estimator.evaluate(input_fn=input_fn(filename=parameters.csv_files_eval,
                                                                    is_training=False,
                                                                    params=parameters,
                                                                    batch_size=parameters.eval_batch_size,
                                                                    num_epochs=1),
                                                  steps=np.floor(n_samples / parameters.eval_batch_size),
                                                  checkpoint_path=model)
                print('model: %s Evaluation results: %s' % (model, str(eval_results)))
                save_file.write(model + ' ' + str(eval_results) + '\n')

    except KeyboardInterrupt:
        print('Interrupted')

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-fe', '--csv_files_eval', required=True, type=str, help='CSV filename for evaluation',
                        nargs='*', default=None)
    parser.add_argument('-o', '--output_model_dir', required=True, type=str,
                        help='Directory for output', default='./estimator')
    parser.add_argument('-m', '--input_model_dir', required=True, type=str,
                        help='Directory for output', default='./estimator')
    parser.add_argument('-g', '--gpu', type=str, help="GPU 0,1 or '' ", default='0')
    parser.add_argument('-of', '--output_file', required=True, type=str, help="the log output file")

    tf.logging.set_verbosity(tf.logging.DEBUG)
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

run script is

python3 test_model.py -fe /data/zj/data/val1.tfrecords -o result/model_test1 -m models_vgg_100K_no_eval/1 -g 3 -of test_model1.txt

the error is

INFO:tensorflow:Using config: {'_task_id': 0, '_num_worker_replicas': 1, '_num_ps_replicas': 0, '_session_config': gpu_options {
  per_process_gpu_memory_fraction: 0.6
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_master': '', '_tf_random_seed': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f5a50a2f898>, '_save_checkpoints_secs': 600, '_is_chief': True, '_save_checkpoints_steps': None, '_log_step_count_steps': 100, '_service': None, '_model_dir': 'result/model_test1', '_task_type': 'worker', '_save_summary_steps': 100}
10000 78.0 128
Traceback (most recent call last):
  File "test_model.py", line 103, in <module>
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
    │               │           │              └ []
    │               │           └ <module 'sys' (built-in)>
    │               └ <function main at 0x7f5a94cd5f28>
    └ <module 'tensorflow' from '/usr/local/lib/python3.5/dist-packages/tensorflow/__init__.py'>
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/platform/app.py", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
    │         │    │               └ []
    │         │    └ <module 'sys' (built-in)>
    │         └ <function main at 0x7f5a94cd5f28>
    └ <module 'sys' (built-in)>
  File "test_model.py", line 82, in main
    checkpoint_path=model)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py", line 355, in evaluate
    name=name)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py", line 808, in _evaluate_model
    input_fn, model_fn_lib.ModeKeys.EVAL)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py", line 577, in _get_features_and_labels_from_input_fn
    result = self._call_input_fn(input_fn, mode)
             │                   │         └ 'eval'
             │                   └ ({'images_widths': <tf.Tensor 'IteratorGetNext:1' shape=(?,) dtype=int32>, 'images': <tf.Tensor 'IteratorGetNext:0' shape=(?, 32...
             └ <tensorflow.python.estimator.estimator.Estimator object at 0x7f5a50a2f6d8>
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/estimator.py", line 656, in _call_input_fn
    input_fn_args = util.fn_args(input_fn)
                    │            └ ({'images_widths': <tf.Tensor 'IteratorGetNext:1' shape=(?,) dtype=int32>, 'images': <tf.Tensor 'IteratorGetNext:0' shape=(?, 32...
                    └ <module 'tensorflow.python.estimator.util' from '/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/util.py'>
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/estimator/util.py", line 57, in fn_args
    return tuple(tf_inspect.getargspec(fn).args)
                 │                     └ ({'images_widths': <tf.Tensor 'IteratorGetNext:1' shape=(?,) dtype=int32>, 'images': <tf.Tensor 'IteratorGetNext:0' shape=(?, 32...
                 └ <module 'tensorflow.python.util.tf_inspect' from '/usr/local/lib/python3.5/dist-packages/tensorflow/python/util/tf_inspect.py'>
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/util/tf_inspect.py", line 45, in getargspec
    if d.decorator_argspec is not None), _inspect.getargspec(target))
  File "/usr/lib/python3.5/inspect.py", line 1043, in getargspec
    getfullargspec(func)
    │              └ ({'images_widths': <tf.Tensor 'IteratorGetNext:1' shape=(?,) dtype=int32>, 'images': <tf.Tensor 'IteratorGetNext:0' shape=(?, 32...
    └ <function getfullargspec at 0x7f5a1855bf28>
  File "/usr/lib/python3.5/inspect.py", line 1095, in getfullargspec
    raise TypeError('unsupported callable') from ex
TypeError: unsupported callable
WenmuZhou commented 7 years ago

I have slove it

cipri-tom commented 6 years ago

you could also tell us how you solved it 😛

also, do you see any improvement from using tfrecord format ? is it running faster ?