wushilian / CRNN_Attention_OCR_Chinese

CRNN with attention for OCR, with Chinese recognition added
334 stars 114 forks source link

训练结果总是 '的' #15

Closed badbubble closed 6 years ago

badbubble commented 6 years ago

您好, 我按照train.txt生成了数据,然后修改了一下代码,但是发现预测结果总是 '的', 是因为训练的次数不够多 还是我改的代码有问题?

[+] epoch:0, batch:371, loss:3.7565720081329346, acc:0.0,
 train_decode:[',的的的的的的的的的', ',,的的的的的的的的', ',,的的的的的的的的', ',,的的的的的的的的', ',的的的的的的的的的'], 
 val_decode:[',的的的的的的的的的', ',的的的的的的的的的', ',的的的的的的的的的', ',的的的的的', ',,的的的的的的的的'], 
 val_truth:['管理我们要质疑!当然', '那大仙才能保持人体健', '种色彩因其波长不同而', '命和民主共和制的产生', '盟者将按照年营业额的']

config.py 主要是加了一个数据批量读入的功能 还把图片尺寸改成280*32了 不知道有没有影响

import numpy as np
import cv2
import os

# Optimizer hyper-parameters.
learning_rate = 0.001
momentum = 0.9

# Ids of the reserved tokens: sequence start, sequence end, unknown character.
START_TOKEN, END_TOKEN, UNK_TOKEN = 0, 1, 2

# token -> id. Only the four reserved tokens are listed here; '<PAD>' is id 3.
VOCAB = {
    '<GO>': START_TOKEN,
    '<EOS>': END_TOKEN,
    '<UNK>': UNK_TOKEN,
    '<PAD>': 3,
}
# id -> token; starts out empty.
VOC_IND = dict()

# charset='0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
def get_class(path):
    """Load a character table that stores one entry per line.

    Every line is stripped of surrounding whitespace and numbered from 0 in
    file order. Reading stops at end of file or at the first line that is
    empty after stripping, so entries that follow a blank line are ignored.

    :param path: path of a UTF-8 encoded text file
    :return: tuple ``(class2int, int2class)``: entry -> index and
        index -> entry dictionaries
    """
    class2int = {}
    int2class = {}
    # The context manager closes the file even when reading or decoding fails.
    with open(path, 'r', encoding='UTF-8') as f:
        for i, raw_line in enumerate(f):
            entry = raw_line.strip()
            if entry == '':
                break
            class2int[entry] = i
            int2class[i] = entry
    return class2int, int2class

# Only the index -> character mapping is needed from the character table.
_, charset = get_class('char_std_5990.txt')

# Character ids start at 4, right after the four reserved tokens.
for i in charset:
    VOCAB[charset[i]] = 4 + i
# Fill the reverse (id -> token) table from the completed vocabulary.
for key in list(VOCAB):
    VOC_IND[VOCAB[key]] = key

# --- batching / vocabulary ---
NUM_BATCHES = 0  # placeholder; batch_iter() overwrites it once it starts running
BATCH_SIZE = 320
VOCAB_SIZE = len(VOCAB)
MAX_LEN_WORD = 20  # maximum label length; shorter labels are filled up with PAD
MAXIMUM__DECODE_ITERATIONS = 20

# --- model / input geometry ---
RNN_UNITS = 256
IMAGE_WIDTH, IMAGE_HEIGHT = 280, 32

# --- training schedule and reporting ---
EPOCH = 10000
DISPLAY_STEPS = 20
is_restore = True

# --- filesystem locations ---
LOGS_PATH = 'log'
CKPT_DIR = 'save_model'
train_dir = '/media/lhj/0C3313300C331330/Images'
val_dir = 'data/data'

def label2int(label):  # label shape (num,len)
    """Encode text labels as padded id matrices for the decoder.

    For every label two rows of width ``MAX_LEN_WORD`` are produced, both
    pre-filled with the ``<PAD>`` id:

    * ``target_input``: ``<GO>`` followed by the character ids
    * ``target_out``:   the character ids followed by ``<EOS>``

    Characters that are not in ``VOCAB`` are encoded as ``UNK_TOKEN`` instead
    of raising ``KeyError``.

    :param label: sequence of strings, each shorter than ``MAX_LEN_WORD``
    :return: tuple ``(target_input, target_out)`` of float32 arrays with
        shape ``(len(label), MAX_LEN_WORD)``
    :raises IndexError: if a label has ``MAX_LEN_WORD`` or more characters,
        because the extra ``<GO>`` / ``<EOS>`` slot would not fit
    """
    pad_id = VOCAB['<PAD>']
    shape = (len(label), MAX_LEN_WORD)
    # float32 is kept from the original implementation.
    # NOTE(review): presumably the dtype the model's inputs expect -- confirm.
    target_input = np.full(shape, pad_id, dtype=np.float32)
    target_out = np.full(shape, pad_id, dtype=np.float32)
    for i, text in enumerate(label):
        if len(text) >= MAX_LEN_WORD:
            # Fail before writing anything, with a message naming the label.
            raise IndexError(
                'label {!r} has {} characters, at most {} are supported'.format(
                    text, len(text), MAX_LEN_WORD - 1))
        ids = [VOCAB.get(ch, UNK_TOKEN) for ch in text]
        target_input[i, 0] = START_TOKEN
        target_input[i, 1:len(ids) + 1] = ids
        target_out[i, :len(ids)] = ids
        target_out[i, len(ids)] = END_TOKEN
    return target_input, target_out

def int2label(decode_label):
    """Convert a 2-D array of token ids into one string per row.

    Each row is read left to right: decoding of the row stops at the first
    id that maps to '<EOS>', ids equal to 3 (padding) are skipped, and every
    other id contributes its ``VOC_IND`` entry to the row's string.

    :param decode_label: 2-D array-like of token ids exposing ``.shape``
    :return: list with one decoded string per row
    """
    label = []
    for row in range(decode_label.shape[0]):
        pieces = []
        for col in range(decode_label.shape[1]):
            token = decode_label[row][col]
            symbol = VOC_IND[token]
            if symbol == '<EOS>':
                break
            if token != 3:
                pieces.append(symbol)
        label.append(''.join(pieces))
    return label

def batch_iter(data_dir, file):
    """Yield ``(images, labels)`` batches described by an annotation file.

    Every non-blank line of *file* has the form
    ``<image name> <idx> <idx> ...`` (single-space separated), where each
    ``idx`` is an index into ``charset``. Images are read as grayscale,
    resized to ``IMAGE_WIDTH`` x ``IMAGE_HEIGHT`` when they have a different
    size, and transposed so the first axis is the width.

    Side effect: the module-level ``NUM_BATCHES`` is set to the number of
    batches. Because this is a generator, that only happens once the first
    batch is requested.

    :param data_dir: directory that the image names are relative to
    :param file: path of the UTF-8 annotation file
    :return: generator of ``(np.ndarray, list of str)``; the array has shape
        ``(batch, IMAGE_WIDTH, IMAGE_HEIGHT, 1)`` and the last batch may hold
        fewer than ``BATCH_SIZE`` samples
    :raises IOError: if an image listed in *file* cannot be read
    """
    global NUM_BATCHES
    # Close the annotation file as soon as it is read; skip blank lines so
    # they are not mistaken for samples.
    with open(file, 'r', encoding='UTF-8') as f:
        lines = [ln for ln in f.read().split('\n') if ln.strip()]

    data_len = len(lines)
    NUM_BATCHES = (data_len + BATCH_SIZE - 1) // BATCH_SIZE  # ceil division

    print("[+] You have {} batches!".format(NUM_BATCHES))

    for start_id in range(0, data_len, BATCH_SIZE):
        image = []
        labels = []
        for line in lines[start_id:start_id + BATCH_SIZE]:
            s = line.strip().split(' ')
            image_name = os.path.join(data_dir, s[0])
            im = cv2.imread(image_name, 0)  # 0 -> read as a grayscale image
            if im is None:
                # cv2.imread signals failure by returning None, not by raising.
                raise IOError('cannot read image: {}'.format(image_name))
            # ndarray.shape is a tuple of (rows, cols), so compare to a tuple.
            if im.shape != (IMAGE_HEIGHT, IMAGE_WIDTH):
                im = cv2.resize(im, (IMAGE_WIDTH, IMAGE_HEIGHT))
            img = im.swapaxes(0, 1)
            image.append(img[:, :, np.newaxis])
            labels.append(''.join(charset[int(idx)] for idx in s[1:]))
        yield np.array(image), labels

def cal_acc(pred, label):
    """Return the fraction of predictions that exactly equal their label.

    Predictions are compared with labels position by position. The result is
    divided by ``len(pred)``, so a prediction without a matching label (when
    *label* is shorter than *pred*) counts as wrong.

    :param pred: sequence of predicted strings
    :param label: sequence of ground-truth strings
    :return: accuracy as a float in ``[0, 1]``; ``0.0`` when *pred* is empty
    """
    total = len(pred)
    if total == 0:
        # Nothing to score; avoid dividing by zero.
        return 0.0
    num = sum(1 for p, t in zip(pred, label) if p == t)
    return num * 1.0 / total

train.py 调了一下代码结构, 把批量读入的函数加进去了

from model import *
import config as cfg
import time
import os

# Training script: builds the network, optionally restores the latest
# checkpoint, then trains while periodically logging, decoding a sample and
# saving a checkpoint.
# `tf`, `np`, `build_network` and the feed targets `image`, `train_output`,
# `target_output`, `sample_rate` come from `from model import *`
# (presumably placeholders defined in model.py -- confirm there).

with tf.name_scope("optimizer") as scope:
    loss, train_decode_result, pred_decode_result = build_network(is_training=True)  # TODO
    optimizer = tf.train.MomentumOptimizer(learning_rate=cfg.learning_rate, momentum=cfg.momentum, use_nesterov=True)
    train_op = optimizer.minimize(loss)

with tf.name_scope('summaries'):
    saver = tf.train.Saver(max_to_keep=5)
    tf.summary.scalar("cost", loss)
    summary_op = tf.summary.merge_all()
    writer = tf.summary.FileWriter(cfg.LOGS_PATH)

# The initializer only covers variables that exist when it is created, so it
# has to come after the network and the optimizer (momentum slots) are built.
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    if cfg.is_restore:
        ckpt = tf.train.latest_checkpoint(cfg.CKPT_DIR)
        if ckpt:
            saver.restore(sess, ckpt)
            print('[*] restore from the checkpoint{0}'.format(ckpt))

    if not os.path.exists(cfg.CKPT_DIR):
        os.makedirs(cfg.CKPT_DIR)

    # Running step counter for summaries and checkpoints. cfg.NUM_BATCHES is
    # not usable for this: it stays 0 until a batch_iter generator has
    # started, and every batch_iter call (train or validation) overwrites it.
    global_step = 0
    for cur_epoch in range(cfg.EPOCH):
        train_dataloader = cfg.batch_iter(cfg.train_dir, 'data/train.txt')
        # NOTE(review): validation images are loaded from cfg.train_dir, not
        # cfg.val_dir -- confirm this is intended.
        test_dataloader = cfg.batch_iter(cfg.train_dir, 'data/my_test')

        val_img, val_label = next(test_dataloader)
        # Iterate the generator directly so the loop ends with the data
        # instead of relying on a hard-coded batch count.
        # NOTE(review): the last batch may be smaller than cfg.BATCH_SIZE;
        # verify that the graph accepts a variable batch size.
        for cur_batch, (batch_inputs, batch_label) in enumerate(train_dataloader):
            print("Batch Label: ", batch_label)
            batch_target_in, batch_target_out = cfg.label2int(batch_label)
            sess.run(train_op,
                     feed_dict={image: batch_inputs, train_output: batch_target_in, target_output: batch_target_out,
                                sample_rate: np.min([1., 0.2 * cur_epoch + 0.2])})

            # Evaluating and checkpointing on every batch dominates the
            # training time; do it once per DISPLAY_STEPS batches.
            if cur_batch % cfg.DISPLAY_STEPS == 0:
                eval_feed = {image: batch_inputs,
                             train_output: batch_target_in,
                             target_output: batch_target_out,
                             sample_rate: 1.}
                summary_loss, loss_result = sess.run([summary_op, loss], feed_dict=eval_feed)
                writer.add_summary(summary_loss, global_step)
                val_predict = sess.run(pred_decode_result, feed_dict={image: val_img[0:cfg.BATCH_SIZE]})
                train_predict = sess.run(pred_decode_result, feed_dict=eval_feed)
                predit = cfg.int2label(np.argmax(val_predict, axis=2))
                train_pre = cfg.int2label(np.argmax(train_predict, axis=2))
                gt = val_label[0:cfg.BATCH_SIZE]
                acc = cfg.cal_acc(predit, gt)

                print("[+] epoch:{}, batch:{}, loss:{}, acc:{},\n train_decode:{}, \n val_decode:{}, \n val_truth:{}".
                      format(cur_epoch, cur_batch,
                             loss_result, acc,
                             train_pre[0:5],
                             predit[0:5],
                             gt[0:5]))

                saver.save(sess, os.path.join(cfg.CKPT_DIR, 'attention_ocr.model'),
                           global_step=global_step)

            global_step += 1
wushilian commented 6 years ago

@ETCartman 训练前期会出现这种情况,而且这个误差还很大,一般误差要降到0.0x才算训练得比较好

badbubble commented 6 years ago

@wushilian 好的,谢谢 我继续训练

badbubble commented 6 years ago

@wushilian 训练好慢啊... 您之前有训练过吗? gtx1070训练了快6个小时 按照这个趋势1个epoch能降到2.XX 您也这样吗? loss

wushilian commented 6 years ago

@ETCartman 你的学习率是多少,我1080ti训了2,3天

badbubble commented 6 years ago

@wushilian 0.001

wushilian commented 6 years ago

@ETCartman 学习率有点大,可能会发散,改成1e-4吧

badbubble commented 6 years ago

@wushilian 好的

zhangtao22 commented 6 years ago

我也改的跟cartman差不多,用的训练集应该也是一个,训了两天,马上结束第二个epoch,loss在2.5左右就不降了,学习率0.0001

zhangtao22 commented 6 years ago

@wushilian 你最终的loss是多少,这个中文的5990

wushilian commented 6 years ago

@zhangtao22 你用的什么优化器,我最终的误差是0.0x

zhangtao22 commented 6 years ago

就是你的代码,优化器都没有改,就把学习率改成了0.0001.你也是训的5990这个中文嘛?

zhangtao22 commented 6 years ago

@ETCartman 你这个loss现在降到多少了?

wushilian commented 6 years ago

@zhangtao22 换成adam试试

zhangtao22 commented 6 years ago

你adam用的参数是多少,谢谢

wushilian commented 6 years ago

默认参数

zhangtao22 commented 6 years ago

learning_rate是0.0001?

zhangtao22 commented 6 years ago

这个infer的时候图片大小必须和训练时候一样尺寸嘛?我送272*32的进去报了一大堆错误 2018-04-23 15:11:40.398022: W tensorflow/core/framework/op_kernel.cc:1192] Invalid argument: ConcatOp : Dimensions of inputs should match: shape[0] = [64,6250] vs. shape[1] = [1,256] [[Node: decode_1/decoder/while/BasicDecoderStep/decoder/attention_wrapper/attention_wrapper/lstm_cell/lstm_cell/lstm_cell/concat = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/gpu:0"](decode_1/decoder/while/BasicDecoderStep/decoder/attention_wrapper/concat, decode_1/decoder/while/Identity_4, decode_1/decoder/while/BasicDecoderStep/decoder/attention_wrapper/attention_wrapper/lstm_cell/lstm_cell/lstm_cell/concat/axis)]] Traceback (most recent call last): File "infer.py", line 22, in val_predict = sess.run(pred_decode_result,feed_dict={image: val_img}) File "/home/zhangtao/.local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 895, in run run_metadata_ptr) File "/home/zhangtao/.local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1124, in _run feed_dict_tensor, options, run_metadata) File "/home/zhangtao/.local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1321, in _do_run options, run_metadata) File "/home/zhangtao/.local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1340, in _do_call raise type(e)(node_def, op, message) tensorflow.python.framework.errors_impl.InvalidArgumentError: ConcatOp : Dimensions of inputs should match: shape[0] = [64,6250] vs. 
shape[1] = [1,256] [[Node: decode_1/decoder/while/BasicDecoderStep/decoder/attention_wrapper/attention_wrapper/lstm_cell/lstm_cell/lstm_cell/concat = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/gpu:0"](decode_1/decoder/while/BasicDecoderStep/decoder/attention_wrapper/concat, decode_1/decoder/while/Identity_4, decode_1/decoder/while/BasicDecoderStep/decoder/attention_wrapper/attention_wrapper/lstm_cell/lstm_cell/lstm_cell/concat/axis)]] [[Node: decode_1/decoder/transpose/_157 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/gpu:0", send_device_incarnation=1, tensor_name="edge_849_decode_1/decoder/transpose", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

Caused by op u'decode_1/decoder/while/BasicDecoderStep/decoder/attention_wrapper/attention_wrapper/lstm_cell/lstm_cell/lstm_cell/concat', defined at:

badbubble commented 6 years ago

@zhangtao22 我没跑, 今天跑一下, 明天看一下结果

yourlovedu commented 6 years ago

@zhangtao22 你这个问题是由于forward的时候BATCH_SIZE还是64,但是你只输入了一张图片的原因,如果每次直接输入64张图片就不会报错了,会输出64个预测结果。但是我不知道如何只输入一张图片预测,我把问题贴在这里https://github.com/wushilian/CRNN_Attention_OCR_Chinese/issues/17 了,希望大家可以帮忙解答一下,谢谢

zhangtao22 commented 6 years ago

@yourlovedu 谢谢,这种。。。

yourlovedu commented 6 years ago

@zhangtao22 不客气,那么怎么单张图片预测呢?有没有什么好的办法 @wushilian