Sierkinhane / CRNN_Chinese_Characters_Rec

(CRNN) Chinese Characters Recognition.
1.81k stars 537 forks source link

用作者的模型在360万数据集的测试集的36万多张图片上测试,val accuracy只有78.4% #136

Open Cocoalate opened 5 years ago

Cocoalate commented 5 years ago

由于作者的代码train和validation是在一起的,我把validation的部分提出来,单独用作者训练好的模型对360万数据集的36万多张图片做validation,但是效果并不好,val accuracy只有78.4% image 作者说的验证准确率可以finetune到97.7%,是指我需要在mixed_second_finetune_acc97p7.pth模型的基础上再进行finetune,val accuracy才能从78.4%变成97.7%么?或者是我的validation出了问题?

from __future__ import print_function
from torch.utils.data import DataLoader
import argparse
import random
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
import numpy as np
# from warpctc_pytorch import CTCLoss
import os
import utils
# import dataset
import models.crnn as crnn
import re
import params
from dataset_v2 import baiduDataset

# def init_args():
#     args = argparse.ArgumentParser()
#     args.add_argument('--trainroot', help='path to dataset', default='./to_lmdb/train')
#     args.add_argument('--valroot', help='path to dataset', default='./to_lmdb/train')
#     args.add_argument('--cuda', action='store_true', help='enables cuda', default=False)

#     return args.parse_args()

# custom weights initialization called on crnn

def weights_init(m):
    """Custom weight init applied via ``model.apply``: N(0, 0.02) for conv
    layers, N(1, 0.02) weights and zero bias for batch-norm layers."""
    layer_type = m.__class__.__name__
    if 'Conv' in layer_type:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in layer_type:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

def val(net, val_dataset, criterion):
    """Run one full validation pass of ``net`` over ``val_dataset``.

    Computes sequence-level accuracy (a sample counts as correct only when
    the decoded string matches the label exactly) and the average per-sample
    CTC loss. Relies on module-level globals ``device``, ``converter`` and
    ``params``. Returns the accuracy as a float.
    """
    val_loader = DataLoader(val_dataset, batch_size=params.val_batchSize, shuffle=True, num_workers=params.workers)
    print('Start val')
    # BUG FIX: the original froze and ran the module-level global ``crnn``
    # instead of the ``net`` argument, silently ignoring the model passed in.
    for p in net.parameters():
        p.requires_grad = False
    net.eval()
    n_correct = 0
    loss_avg = utils.averager()

    sim_preds, label = [], []  # keep last batch around for the sample printout
    with torch.no_grad():  # evaluation only — no gradient bookkeeping needed
        for i_batch, (image, index) in enumerate(val_loader):
            image = image.to(device)
            label = utils.get_batch_label(val_dataset, index)
            preds = net(image)
            batch_size = image.size(0)
            text, length = converter.encode(label)
            # every sample in the batch produces the same number of time steps
            preds_size = torch.IntTensor([preds.size(0)] * batch_size)
            # criterion uses reduction='sum'; divide for a per-sample loss
            cost = criterion(preds, text, preds_size, length) / batch_size
            loss_avg.add(cost)
            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            sim_preds = converter.decode(preds.data, preds_size.data, raw=False)

            for pred, target in zip(sim_preds, label):
                if pred == target:
                    n_correct += 1
            if (i_batch + 1) % 1000 == 0:
                print('[%d/%d]' % (i_batch, len(val_loader)))

    # show a few raw (pre-CTC-collapse) predictions from the last batch
    raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:params.n_test_disp]
    for raw_pred, pred, gt in zip(raw_preds, sim_preds, label):
        print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))
    print(n_correct)
    print(len(val_dataset))
    accuracy = n_correct / float(len(val_dataset))
    # also fixes the 'accuray' typo in the original log message
    print('Test loss: %f, accuracy: %f' % (loss_avg.val(), accuracy))
    return accuracy

def main(crnn, val_dataset, criterion):
    """Move the model and loss to the active device, validate once, and
    return the resulting accuracy."""
    net = crnn.to(device)
    loss_fn = criterion.to(device)
    return val(net, val_dataset, loss_fn)

if __name__ == '__main__':
    # Fix all RNG seeds so DataLoader shuffling / init are reproducible.
    manualSeed=10
    random.seed(manualSeed)
    np.random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    cudnn.benchmark = True  # let cuDNN auto-tune kernels for fixed input sizes
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # alphabet = alphabet = utils.to_alphabet("H:/DL-DATASET/BaiduTextR/train.list")

    # checkpoint directory (kept for parity with the training script)
    if not os.path.exists('./expr'):
        os.mkdir('./expr')
    # validation dataset: (image dir, label file, alphabet, is_train=False, (W, H))
    val_dataset = baiduDataset("./rar/images", "./crnn_labels/test.txt.bk", params.alphabet, False, (params.imgW, params.imgH))
    # val_dataset = baiduDataset("/mnt/data/ocr/car_plates/val/CCPD2", "./CCPD2_label.txt", params.alphabet, False, (params.imgW, params.imgH))
    # val_dataset = baiduDataset("/mnt/data/ocr/car_plates/val/white", "./white_label.txt", params.alphabet, False, (params.imgW, params.imgH))
    converter = utils.strLabelConverter(val_dataset.alphabet)
    nclass = len(params.alphabet) + 1  # +1 for the CTC blank symbol
    nc = 1  # single-channel (grayscale) input images
    # TODO why not mean — 'sum' is divided by batch size manually in val()
    criterion = torch.nn.CTCLoss(reduction='sum')
    # cnn and rnn: CRNN(img height, input channels, num classes, hidden units)
    crnn = crnn.CRNN(32, nc, nclass, params.nh)
    # NOTE(review): this random init is immediately overwritten by the
    # load_state_dict below, so it is redundant here.
    crnn.apply(weights_init)
    crnn.load_state_dict(torch.load("/home/keke/crnn_chinese_characters_rec/trained_models/mixed_second_finetune_acc97p7.pth"))
    # crnn.load_state_dict(torch.load("/home/keke/Lets_OCR/recognizer/crnn/w160_bs64_model/netCRNN_4_48000.pth"))
    # optional override: params.crnn, if set, replaces the hard-coded weights
    if params.crnn != '':
        print('loading pretrained model from %s' % params.crnn)
        crnn.load_state_dict(torch.load(params.crnn))

    main(crnn, val_dataset, criterion)
psnow commented 4 years ago

由于作者的代码train和validation是在一起的,我把validation的部分提出来,单独用作者训练好的模型对360万数据集的36万多张图片做validation,但是效果并不好,val accuracy只有78.4% image 作者说的验证准确率可以finetune到97.7%,是指我需要在mixed_second_finetune_acc97p7.pth模型的基础上再进行finetune,val accuracy才能从78.4%变成97.7%么?或者是我的validation出了问题?

from __future__ import print_function
from torch.utils.data import DataLoader
import argparse
import random
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
import numpy as np
# from warpctc_pytorch import CTCLoss
import os
import utils
# import dataset
import models.crnn as crnn
import re
import params
from dataset_v2 import baiduDataset

# def init_args():
#     args = argparse.ArgumentParser()
#     args.add_argument('--trainroot', help='path to dataset', default='./to_lmdb/train')
#     args.add_argument('--valroot', help='path to dataset', default='./to_lmdb/train')
#     args.add_argument('--cuda', action='store_true', help='enables cuda', default=False)

#     return args.parse_args()

# custom weights initialization called on crnn

def weights_init(m):
    """Custom weight init applied via ``model.apply``: N(0, 0.02) for conv
    layers, N(1, 0.02) weights and zero bias for batch-norm layers."""
    layer_type = m.__class__.__name__
    if 'Conv' in layer_type:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in layer_type:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

def val(net, val_dataset, criterion):
    """Run one full validation pass of ``net`` over ``val_dataset``.

    Computes sequence-level accuracy (a sample counts as correct only when
    the decoded string matches the label exactly) and the average per-sample
    CTC loss. Relies on module-level globals ``device``, ``converter`` and
    ``params``. Returns the accuracy as a float.
    """
    val_loader = DataLoader(val_dataset, batch_size=params.val_batchSize, shuffle=True, num_workers=params.workers)
    print('Start val')
    # BUG FIX: the original froze and ran the module-level global ``crnn``
    # instead of the ``net`` argument, silently ignoring the model passed in.
    for p in net.parameters():
        p.requires_grad = False
    net.eval()
    n_correct = 0
    loss_avg = utils.averager()

    sim_preds, label = [], []  # keep last batch around for the sample printout
    with torch.no_grad():  # evaluation only — no gradient bookkeeping needed
        for i_batch, (image, index) in enumerate(val_loader):
            image = image.to(device)
            label = utils.get_batch_label(val_dataset, index)
            preds = net(image)
            batch_size = image.size(0)
            text, length = converter.encode(label)
            # every sample in the batch produces the same number of time steps
            preds_size = torch.IntTensor([preds.size(0)] * batch_size)
            # criterion uses reduction='sum'; divide for a per-sample loss
            cost = criterion(preds, text, preds_size, length) / batch_size
            loss_avg.add(cost)
            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            sim_preds = converter.decode(preds.data, preds_size.data, raw=False)

            for pred, target in zip(sim_preds, label):
                if pred == target:
                    n_correct += 1
            if (i_batch + 1) % 1000 == 0:
                print('[%d/%d]' % (i_batch, len(val_loader)))

    # show a few raw (pre-CTC-collapse) predictions from the last batch
    raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:params.n_test_disp]
    for raw_pred, pred, gt in zip(raw_preds, sim_preds, label):
        print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))
    print(n_correct)
    print(len(val_dataset))
    accuracy = n_correct / float(len(val_dataset))
    # also fixes the 'accuray' typo in the original log message
    print('Test loss: %f, accuracy: %f' % (loss_avg.val(), accuracy))
    return accuracy

def main(crnn, val_dataset, criterion):
    """Move the model and loss to the active device, validate once, and
    return the resulting accuracy."""
    net = crnn.to(device)
    loss_fn = criterion.to(device)
    return val(net, val_dataset, loss_fn)

if __name__ == '__main__':
    # Fix all RNG seeds so DataLoader shuffling / init are reproducible.
    manualSeed=10
    random.seed(manualSeed)
    np.random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    cudnn.benchmark = True  # let cuDNN auto-tune kernels for fixed input sizes
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # alphabet = alphabet = utils.to_alphabet("H:/DL-DATASET/BaiduTextR/train.list")

    # checkpoint directory (kept for parity with the training script)
    if not os.path.exists('./expr'):
        os.mkdir('./expr')
    # validation dataset: (image dir, label file, alphabet, is_train=False, (W, H))
    val_dataset = baiduDataset("./rar/images", "./crnn_labels/test.txt.bk", params.alphabet, False, (params.imgW, params.imgH))
    # val_dataset = baiduDataset("/mnt/data/ocr/car_plates/val/CCPD2", "./CCPD2_label.txt", params.alphabet, False, (params.imgW, params.imgH))
    # val_dataset = baiduDataset("/mnt/data/ocr/car_plates/val/white", "./white_label.txt", params.alphabet, False, (params.imgW, params.imgH))
    converter = utils.strLabelConverter(val_dataset.alphabet)
    nclass = len(params.alphabet) + 1  # +1 for the CTC blank symbol
    nc = 1  # single-channel (grayscale) input images
    # TODO why not mean — 'sum' is divided by batch size manually in val()
    criterion = torch.nn.CTCLoss(reduction='sum')
    # cnn and rnn: CRNN(img height, input channels, num classes, hidden units)
    crnn = crnn.CRNN(32, nc, nclass, params.nh)
    # NOTE(review): this random init is immediately overwritten by the
    # load_state_dict below, so it is redundant here.
    crnn.apply(weights_init)
    crnn.load_state_dict(torch.load("/home/keke/crnn_chinese_characters_rec/trained_models/mixed_second_finetune_acc97p7.pth"))
    # crnn.load_state_dict(torch.load("/home/keke/Lets_OCR/recognizer/crnn/w160_bs64_model/netCRNN_4_48000.pth"))
    # optional override: params.crnn, if set, replaces the hard-coded weights
    if params.crnn != '':
        print('loading pretrained model from %s' % params.crnn)
        crnn.load_state_dict(torch.load(params.crnn))

    main(crnn, val_dataset, criterion)

能分享一下你的测试代码吗?

Cocoalate commented 4 years ago

@psnow 你是说在整个test set上面进行验证的代码还是?我发的代码就是对整个test set进行验证的,你改一下后面的数据路径就可以了

ws-lin commented 4 years ago

我也测试了,准确率只有70%多,我怀疑是数据标签有问题。

yedaorman commented 4 years ago

我也测试了,只有70%多。但是使用作者给出的数据集进行训练与测试可以达到90%多。

yedaorman commented 4 years ago

不清楚是哪里的问题

stringk245 commented 4 years ago

不清楚是哪里的问题

作者的数据集精度能到90%多,自己的数据集只有70%多嘛 是 自己标注问题还是 网络问题?还是图片预处理导致? 有解决吗?

jasnei commented 3 years ago

@Cocoalate 从作者的代码看,validation是对测试集的1000张图像作测试,这个准确率也应该是这1000张图像的准确率,而非整个测试集的准确率。我也从0训练了10个epoch,就能达到96.59,但这个分数也只是测试集里1000张图像的准确率。