tongpi / basicOCR

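BasicOCR is a project dedicated to research on text recognition algorithms for natural scene images. It was initiated and is maintained by the Tongpai (佟派) AI team of the Great Wall Digital Big Data Application Technology Research Institute.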
https://tongpi.github.io/basicOCR/
GNU General Public License v3.0

Modify the demo code so that it works on both CPU and GPU #25

Closed: daoqinzi closed this 6 years ago

daoqinzi commented 6 years ago

# coding: utf-8

import torch
from torch.autograd import Variable
import utils
import dataset
import os
from PIL import Image

import models.crnn as crnn

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

model_path = './data/netCRNN_ch_nc_21_nh_128.pth'
img_path = './data/image33.jpg'
alphabet = u'\'ACIMRey万下依口哺摄次状璐癌草血运重'

print(alphabet)

# One extra class for the CTC blank
nclass = len(alphabet) + 1

# Check whether a GPU is available
if torch.cuda.is_available():
    model = crnn.CRNN(32, 1, nclass, 128).cuda()
    pre_model = torch.load(model_path)
else:
    model = crnn.CRNN(32, 1, nclass, 128)
    # Remap tensors that were saved on the GPU onto the CPU
    pre_model = torch.load(model_path, map_location=lambda storage, loc: storage)

print('loading pretrained model from %s' % model_path)
for k, v in pre_model.items():
    print(k, len(v))
model.load_state_dict(pre_model)

converter = utils.strLabelConverter(alphabet)

transformer = dataset.resizeNormalize((100, 32))
image = Image.open(img_path).convert('L')

# Check whether a GPU is available
if torch.cuda.is_available():
    image = transformer(image).cuda()
else:
    image = transformer(image)

# Add a batch dimension: (C, H, W) -> (1, C, H, W)
image = image.view(1, *image.size())
image = Variable(image)

model.eval()
preds = model(image)

_, preds = preds.max(2)
# Assumes PyTorch < 0.2, where max() keeps the reduced dimension;
# on newer versions this squeeze is not needed and should be dropped
preds = preds.squeeze(2)
preds = preds.transpose(1, 0).contiguous().view(-1)

preds_size = Variable(torch.IntTensor([preds.size(0)]))
raw_pred = converter.decode(preds.data, preds_size.data, raw=True)
sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
# .encode('utf8') targets Python 2; on Python 3, print the strings directly
print('%-20s => %-20s' % (raw_pred.encode('utf8'), sim_pred.encode('utf8')))
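The last step of the demo turns the per-frame argmax indices into text: `raw_pred` presumably shows every frame, while `sim_pred` is the collapsed result. A minimal sketch of that greedy (best-path) CTC decoding, assuming index 0 is the CTC blank (consistent with `nclass = len(alphabet) + 1`) and that index `i` maps to `alphabet[i - 1]`. The function names `decode_raw` and `decode_collapsed` are hypothetical and are not the `utils.strLabelConverter` API:

```python
# Hypothetical sketch of greedy CTC decoding, assuming index 0 is the blank
# and index i (i >= 1) maps to alphabet[i - 1].

BLANK = 0


def decode_raw(indices, alphabet):
    """Map every frame index to a character; the blank is shown as '-'."""
    return ''.join('-' if i == BLANK else alphabet[i - 1] for i in indices)


def decode_collapsed(indices, alphabet):
    """Skip an index equal to the previous frame's, then drop blanks."""
    chars = []
    prev = None
    for i in indices:
        if i != BLANK and i != prev:
            chars.append(alphabet[i - 1])
        prev = i
    return ''.join(chars)


if __name__ == '__main__':
    alphabet = u'\'ACIMRey万下依口哺摄次状璐癌草血运重'
    # Made-up per-frame indices, for illustration only
    frame_indices = [0, 2, 2, 0, 3, 3, 3, 0, 0, 2]
    print(decode_raw(frame_indices, alphabet))
    print(decode_collapsed(frame_indices, alphabet))
```

Because repeats are merged before blanks are removed, a blank between two identical indices (for example `[2, 0, 2]`) is what lets a doubled character survive decoding.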

YoungMiao commented 6 years ago

@daoqinzi See the CUDA-related checks in crnn_main.py, line 66, e.g.:
if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")
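The check the reply points to can be sketched without PyTorch. The `--cuda` flag name follows the `opt.cuda` in the quoted line; `parse_args` and `choose_device` are hypothetical helpers, and the `cuda_available` argument stands in for `torch.cuda.is_available()`:

```python
import argparse


def parse_args(argv=None):
    """Parse a --cuda switch (hypothetical; mirrors the opt.cuda usage above)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--cuda', action='store_true', help='enables cuda')
    return parser.parse_args(argv)


def choose_device(use_cuda_flag, cuda_available):
    """Return 'cuda' or 'cpu'; cuda_available would come from torch.cuda.is_available()."""
    if cuda_available and not use_cuda_flag:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")
    # Use the GPU only when it was requested and actually exists
    return 'cuda' if use_cuda_flag and cuda_available else 'cpu'


if __name__ == '__main__':
    opt = parse_args()
    print(choose_device(opt.cuda, cuda_available=False))
```

On PyTorch 0.4 or later, the returned string could be passed as `torch.load(model_path, map_location=device)` and `model.to(device)`, which would replace the duplicated if/else branches in the demo with a single device choice.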