roatienza / deep-text-recognition-benchmark

PyTorch code of my ICDAR 2021 paper Vision Transformer for Fast and Efficient Scene Text Recognition (ViTSTR)
Apache License 2.0
293 stars, 59 forks

ONNX #35

Open centurions opened 2 years ago

centurions commented 2 years ago

Hi, thank you for your great work. Would you please add code for converting the .pth checkpoint to ONNX?

roatienza commented 2 years ago

There is a good guide for converting PyTorch models to ONNX: https://pytorch.org/docs/stable/onnx.html

centurions commented 2 years ago

Thank you. I have successfully converted the .pth checkpoint to ONNX, but I get the following error when running this inference code. Would you please help me fix it?

Error

Traceback (most recent call last):
  File "inference_onnx.py", line 45, in <module>
    pred_str = converter.decode(pred_index, length_for_pred)
  File "E:\new2\text-recognition-wii-main\preprocess\converter.py", line 63, in decode
    text = ''.join([self.character[i] for i in text_index[index, :]])
  File "E:\new2\text-recognition-wii-main\preprocess\converter.py", line 63, in <listcomp>
    text = ''.join([self.character[i] for i in text_index[index, :]])
TypeError: list indices must be integers or slices, not numpy.float32
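The last line says that `decode` tried to index the character list with a float32 score instead of an integer index. A minimal reproduction, using a hypothetical five-entry character list and made-up scores in place of the ONNX output, shows where such floats come from: NumPy's `ndarray.max(axis)` returns the largest values, while `ndarray.argmax(axis)` returns their positions as integers.

```python
import numpy as np

characters = ['[GO]', '[s]', 'a', 'b', 'c']  # hypothetical label set
# Made-up scores with shape (seq_len, num_classes).
scores = np.array([[0.1, 0.2, 0.7, 0.0, 0.0],
                   [0.0, 0.1, 0.1, 0.8, 0.0]], dtype=np.float32)

best_scores = scores.max(1)      # highest score per position (float32)
best_indices = scores.argmax(1)  # position of that score (integer dtype)

try:
    characters[best_scores[0]]   # indexing a list with a float32 fails
except TypeError as err:
    print(err)  # list indices must be integers or slices, not numpy.float32

print([characters[i] for i in best_indices])  # ['a', 'b']
```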

Inference code

# %%
import onnxruntime as ort
from PIL import Image
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import base64
import io
import torch
from preprocess.converter import NormalizePAD, TokenLabelConverter
from models.model import Model

# %%
def preprocess_img(path):
    data_transforms = NormalizePAD((1, 224, 224))
    img = Image.open(path).convert('L')
    img = img.resize((224, 224), Image.BICUBIC)
    img = data_transforms(img)
    img = img.unsqueeze(0)
    return img


def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


def predict_onnx(sess, path):
    input_name = sess.get_inputs()[0].name
    print(input_name)
    img = preprocess_img(path)
    preds = sess.run(None, {input_name: to_numpy(img)})
    preds = np.squeeze(preds)
    return preds

# %%
if __name__ == '__main__':
    sess = ort.InferenceSession('last_model.onnx')
    converter = TokenLabelConverter()
    path = 'examples/1.jpg'
    preds = predict_onnx(sess, path)
    print(preds)

# %%
print(preds.shape)
pred_index = preds.max(1)
pred_index = pred_index.reshape(1, 25)
print(pred_index)
length_for_pred = np.array([25 - 1])
pred_str = converter.decode(pred_index, length_for_pred)
print(pred_str)
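The failing line in this code is `pred_index = preds.max(1)`: it keeps the best scores (float32) rather than their class indices, so `converter.decode` receives floats. The one-line change is `pred_index = preds.argmax(1)`. A fuller greedy decoder is sketched below as a hypothetical helper, `greedy_decode`. It assumes the ViTSTR token convention in which output position 0 is the `[GO]` start token and `'[s]'` marks the end of the text; check both against the `TokenLabelConverter` actually in use.

```python
import numpy as np


def greedy_decode(logits, characters, eos_token='[s]'):
    """Turn (batch, seq_len, num_classes) scores into strings.

    Assumes position 0 holds the start token and that the text ends at
    the first end-of-sequence token.
    """
    logits = np.asarray(logits)
    if logits.ndim == 2:               # np.squeeze removed the batch axis
        logits = logits[np.newaxis]
    indices = logits.argmax(axis=-1)   # integer class ids, not float scores
    texts = []
    for row in indices[:, 1:]:         # skip the start-token position
        text = ''.join(characters[i] for i in row)
        end = text.find(eos_token)
        texts.append(text if end == -1 else text[:end])
    return texts


# Hypothetical label set: two special tokens followed by 0-9 and a-z.
characters = ['[GO]', '[s]'] + list('0123456789abcdefghijklmnopqrstuvwxyz')
# Fake one-hot scores: start token, 'h', 'i', then end tokens as padding.
ids = [0, characters.index('h'), characters.index('i'), 1, 1]
fake_logits = np.eye(len(characters), dtype=np.float32)[ids]  # shape (5, 38)
print(greedy_decode(fake_logits, characters))  # ['hi']
```

Truncating at the first `'[s]'` matters because the model emits a fixed number of positions (25 in the code above), and everything after the end token is padding.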