Closed: daoqinzi closed this issue 6 years ago.
import os

import torch
from torch.autograd import Variable
from PIL import Image

import utils
import dataset
import models.crnn as crnn

# Inputs for this single-image CRNN recognition demo.
model_path = './data/netCRNN_ch_nc_21_nh_128.pth'
img_path = './data/image33.jpg'
# Character set used both to size the network output (nclass below) and,
# further down, to build utils.strLabelConverter. Presumably it must match
# the training alphabet exactly (same characters, same order) -- confirm.
# NOTE(review): this string has 22 characters (nclass = 23), while the
# checkpoint file name says "nc_21". If load_state_dict reports a size
# mismatch in the output layer, the alphabet likely differs from the one
# used in training.
alphabet = u'\'ACIMRey万下依口哺摄次状璐癌草血运重'

# One class per character plus one extra class (presumably the CTC blank).
nclass = len(alphabet) + 1
# Build the network once and move it to the GPU only if one is available.
model = crnn.CRNN(32, 1, nclass, 128)
if torch.cuda.is_available():
    model = model.cuda()

# Always deserialize the checkpoint onto the CPU. Otherwise a checkpoint saved
# on a GPU id that is not visible here (e.g. cuda:1 on a one-GPU machine)
# fails to load. load_state_dict copies each tensor into the model's existing
# parameters, so a CUDA model still ends up with CUDA weights.
pre_model = torch.load(model_path, map_location=lambda storage, loc: storage)

print('loading pretrained model from %s' % model_path)
for k, v in pre_model.items():
    # Print the shape rather than len(v): len() raises TypeError on 0-d
    # tensors such as BatchNorm's num_batches_tracked entries.
    print(k, tuple(v.size()))
model.load_state_dict(pre_model)
# Maps between network output indices and characters of `alphabet`.
converter = utils.strLabelConverter(alphabet)

# Resize/normalize to (100, 32) -- presumably (width, height); confirm against
# dataset.resizeNormalize. The image is converted to single-channel
# grayscale ('L'), matching the 1 input channel given to crnn.CRNN.
transformer = dataset.resizeNormalize((100, 32))
with Image.open(img_path) as raw_image:
    image = raw_image.convert('L')
image = transformer(image)
if torch.cuda.is_available():
    image = image.cuda()

# Prepend a batch dimension of size 1.
image = image.view(1, *image.size())
image = Variable(image)
model.eval()
preds = model(image)

# Keep the arg-max class index along dim 2 (values are discarded).
# NOTE(review): presumably preds is (time, batch, nclass) -- confirm in models.crnn.
_, preds = preds.max(2)
# max(2) already removes dim 2 (keepdim=False by default), so preds is now
# 2-D. The former `preds.squeeze(2)` raised "dimension out of range" here.
preds = preds.transpose(1, 0).contiguous().view(-1)

preds_size = Variable(torch.IntTensor([preds.size(0)]))
# raw=True keeps the per-step output; raw=False is the cleaned-up string
# (presumably repeats/blanks collapsed by the converter -- confirm in utils).
raw_pred = converter.decode(preds.data, preds_size.data, raw=True)
sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
# Print the text directly: on Python 3, .encode('utf8') would print b'\x..'
# byte reprs instead of the recognized characters.
print('%-20s => %-20s' % (raw_pred, sim_pred))
@daoqinzi 请查看 crnn_main.py 中与 cuda 相关的判断(第 66 行),例如:`if torch.cuda.is_available() and not opt.cuda: print("WARNING: You have a CUDA device, so you should probably run with --cuda")`
# Snippet quoted from crnn_main.py (around line 66): print a warning when a
# CUDA device is available but opt.cuda is false. `opt` is not defined in
# this snippet -- presumably crnn_main.py's parsed command-line options.
if torch.cuda.is_available() and not opt.cuda: print("WARNING: You have a CUDA device, so you should probably run with --cuda")
# -*- coding: utf-8 -*-
import os

import torch
from torch.autograd import Variable
from PIL import Image

import utils
import dataset
import models.crnn as crnn

# Expose only physical GPU 1 unless the caller already chose devices via the
# environment. This only takes effect if set before CUDA is first initialized
# in this process. On a machine with a single GPU (id 0), "1" hides it, and
# torch.cuda.is_available() will then report False.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")

# Inputs for this single-image CRNN recognition demo.
model_path = './data/netCRNN_ch_nc_21_nh_128.pth'
img_path = './data/image33.jpg'
# Character set used both to size the network output (nclass below) and,
# further down, to build utils.strLabelConverter. Presumably it must match
# the training alphabet exactly (same characters, same order) -- confirm.
# NOTE(review): this string has 22 characters (nclass = 23), while the
# checkpoint file name says "nc_21". If load_state_dict reports a size
# mismatch in the output layer, the alphabet likely differs from the one
# used in training.
alphabet = u'\'ACIMRey万下依口哺摄次状璐癌草血运重'
print(alphabet)

# One class per character plus one extra class (presumably the CTC blank).
nclass = len(alphabet) + 1
# Check whether a GPU is available; build the network once and move it there if so.
model = crnn.CRNN(32, 1, nclass, 128)
if torch.cuda.is_available():
    model = model.cuda()

# Always deserialize the checkpoint onto the CPU. Otherwise a checkpoint saved
# on a GPU id that is not visible here (e.g. cuda:1 on a one-GPU machine)
# fails to load. load_state_dict copies each tensor into the model's existing
# parameters, so a CUDA model still ends up with CUDA weights.
pre_model = torch.load(model_path, map_location=lambda storage, loc: storage)

print('loading pretrained model from %s' % model_path)
for k, v in pre_model.items():
    # Print the shape rather than len(v): len() raises TypeError on 0-d
    # tensors such as BatchNorm's num_batches_tracked entries.
    print(k, tuple(v.size()))
model.load_state_dict(pre_model)
# Maps between network output indices and characters of `alphabet`.
converter = utils.strLabelConverter(alphabet)

# Resize/normalize to (100, 32) -- presumably (width, height); confirm against
# dataset.resizeNormalize. The image is converted to single-channel
# grayscale ('L'), matching the 1 input channel given to crnn.CRNN.
transformer = dataset.resizeNormalize((100, 32))
with Image.open(img_path) as raw_image:
    image = raw_image.convert('L')
image = transformer(image)

# Move the input to the GPU if one is present.
if torch.cuda.is_available():
    image = image.cuda()

# Prepend a batch dimension of size 1.
image = image.view(1, *image.size())
image = Variable(image)
model.eval()
preds = model(image)

# Keep the arg-max class index along dim 2 (values are discarded).
# NOTE(review): presumably preds is (time, batch, nclass) -- confirm in models.crnn.
_, preds = preds.max(2)
# max(2) already removes dim 2 (keepdim=False by default), so preds is now
# 2-D. The former `preds.squeeze(2)` raised "dimension out of range" here.
preds = preds.transpose(1, 0).contiguous().view(-1)

preds_size = Variable(torch.IntTensor([preds.size(0)]))
# raw=True keeps the per-step output; raw=False is the cleaned-up string
# (presumably repeats/blanks collapsed by the converter -- confirm in utils).
raw_pred = converter.decode(preds.data, preds_size.data, raw=True)
sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
# Print the text directly: on Python 3, .encode('utf8') would print b'\x..'
# byte reprs instead of the recognized characters.
print('%-20s => %-20s' % (raw_pred, sim_pred))