hezarai / hezar

The all-in-one AI library for Persian, supporting a wide variety of tasks and modalities!
https://hezarai.github.io/hezar/
Apache License 2.0

convert model to ONNX file #160

Closed davoodap closed 4 months ago

davoodap commented 4 months ago

Hello, I want to download the license plate recognition model, convert the model file to ONNX format, and use it locally. Is this possible? Please guide me if so. Thanks to the creators of this useful library.

arxyzan commented 4 months ago

Hi @davoodap, thanks for your feedback. It's definitely possible to convert Hezar models to ONNX. All models in Hezar are PyTorch nn.Module subclasses, so you can convert them as you would any PyTorch model. The only challenge is that when you call the .predict(**inputs) method in Hezar, a three-step pipeline is executed:

  1. Model.preprocess(inputs, **kwargs)
  2. Model.forward(inputs, **kwargs)
  3. Model.postprocess(inputs, **kwargs)

ONNX can only serialize the forward method, so you need to implement the preprocessing and post-processing yourself. You can inspect these methods in the model file (hezar.models.image2text.crnn.crnn_image2text.py in your case). It's not that complicated, but it means you lose the .predict() abstraction (and other functionality specific to Hezar's Model class) once the model is converted to ONNX.
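To make the split concrete, here is a minimal sketch of the three-step pattern in plain Python. `ToyModel`, its methods, and `exported_forward` are hypothetical stand-ins rather than Hezar's actual classes. The sketch only illustrates that an exported graph replaces the forward step, while the other two steps stay ordinary Python that you call yourself.

```python
class ToyModel:
    """Hypothetical stand-in that mimics a preprocess -> forward -> postprocess pipeline."""

    def preprocess(self, inputs):
        # Turn raw input (here: a string of digits) into numeric features
        return [int(ch) for ch in inputs]

    def forward(self, inputs):
        # The only step an ONNX export would capture: pure numeric computation
        return [x * 2 for x in inputs]

    def postprocess(self, outputs):
        # Turn raw model outputs back into a human-readable result
        return "-".join(str(x) for x in outputs)

    def predict(self, inputs):
        # What a .predict() style helper does: chain the three steps
        return self.postprocess(self.forward(self.preprocess(inputs)))


def exported_forward(inputs):
    # Stand-in for an exported graph: same math as ToyModel.forward, no class around it
    return [x * 2 for x in inputs]


model = ToyModel()
full_result = model.predict("123")

# With only the exported forward step, pre- and post-processing are called by hand
manual_result = model.postprocess(exported_forward(model.preprocess("123")))
```

Both paths give the same result here; the second one is the shape any ONNX-based inference script ends up taking.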

davoodap commented 4 months ago

@arxyzan Thank you for your quick response. Can you show me the steps for this conversion? I am new to this field.

arxyzan commented 4 months ago

@davoodap Yes, of course. I will write an end-to-end script for this, but I'm afraid you'll have to give me some time, as I'm really loaded with work right now. I'll try to do it in a day or two, if that's okay with you.

davoodap commented 4 months ago

@arxyzan OK, I'll wait for your response.

arxyzan commented 4 months ago

@davoodap You can follow the two steps below:

Step 1: Convert to ONNX

import numpy as np
import torch
import onnxruntime
from hezar.models import Model

model_id = "hezarai/crnn-fa-64x256-license-plate-recognition"
onnx_path = "crnn-fa-alpr.onnx"

model = Model.load(model_id)
model.eval()

dummy_inputs = torch.randn(1, 1, 64, 256)

with torch.inference_mode():
    outputs = model(dummy_inputs)

torch.onnx.export(
    model,
    dummy_inputs,
    onnx_path,
    export_params=True,
    input_names=["inputs"],
    output_names=["outputs"],
)

# Make sure the saved ONNX model gives the same results
ort_session = onnxruntime.InferenceSession(onnx_path)

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: dummy_inputs.numpy()}
ort_outputs = ort_session.run(None, ort_inputs)

# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(outputs["logits"].numpy(), ort_outputs[0], rtol=1e-03, atol=1e-05)
print("Exported model has been tested with ONNXRuntime, and the result looks good!")
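For reference, the tolerance check above passes when every element satisfies |actual - desired| <= atol + rtol * |desired|. The following is a small pure-Python sketch of that rule; `within_tolerance` and the sample values are made up for illustration and are not real model outputs.

```python
def within_tolerance(actual, desired, rtol=1e-03, atol=1e-05):
    # Element-wise rule used by np.testing.assert_allclose:
    # |actual - desired| <= atol + rtol * |desired|
    return all(abs(a - d) <= atol + rtol * abs(d) for a, d in zip(actual, desired))


# Made-up numbers standing in for PyTorch logits and ONNX Runtime logits
pytorch_logits = [1.0000, -2.5000, 0.0100]
onnx_logits = [1.0004, -2.5010, 0.01001]

close = within_tolerance(pytorch_logits, onnx_logits)
far = within_tolerance(pytorch_logits, [1.01, -2.5, 0.01])
```

Small floating-point differences between the two runtimes are expected, which is why the comparison uses tolerances instead of exact equality.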

Step 2: Run inference and validate the results

import numpy as np
import torch
import onnxruntime
from hezar.models import ModelConfig
from hezar.models.model_outputs import Image2TextOutput
from hezar.models.image2text.crnn.crnn_decode_utils import ctc_decode
from hezar.preprocessors import ImageProcessor
from hezar.utils import reverse_string_digits

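# Same values as in Step 1, redefined so this script can run on its own
model_id = "hezarai/crnn-fa-64x256-license-plate-recognition"
onnx_path = "crnn-fa-alpr.onnx"

model_config = ModelConfig.load(model_id, filename="model_config.yaml")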
input_image = "../assets/license_plate_ocr_example.jpg"

# Preprocess input
preprocessor = ImageProcessor.load(model_id)
inputs = preprocessor([input_image])["pixel_values"]

model = onnxruntime.InferenceSession(onnx_path)

# ONNX inference
inputs = {model.get_inputs()[0].name: inputs.numpy()}
logits = torch.tensor(model.run(None, inputs)[0])

# Generation
generated_ids = ctc_decode(logits, blank=model_config.blank_id)
probs, values = logits.permute(1, 0, 2).softmax(2).max(2)
scores = probs.mean(1)

# Post-process
outputs = []
generated_ids = generated_ids.cpu().numpy().tolist()
scores = scores.cpu().numpy().tolist()
for decoded_ids, score in zip(generated_ids, scores):
    chars = [model_config.id2label[id_] for id_ in decoded_ids]
    text = "".join(chars)
    if model_config.reverse_output_digits:
        text = reverse_string_digits(text)
    outputs.append(Image2TextOutput(text=text, score=score))

print(outputs)

Note that this is about the most minimal code for this task. It might seem a little long, but it's actually what Hezar does under the hood to keep everything abstract. Once you convert to ONNX, all the preprocessing and post-processing must be handled explicitly, and there's no way around it :(
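For readers unfamiliar with the generation step above, here is a rough sketch of the standard greedy CTC rule: take the most likely id per frame, collapse consecutive repeats, then drop the blank id. This is an assumption about what `ctc_decode` does conceptually, not a copy of Hezar's implementation, and `greedy_ctc_decode`, the frame ids, and the label map are all hypothetical.

```python
def greedy_ctc_decode(frame_ids, blank=0):
    """Collapse consecutive repeats, then drop blank tokens (standard greedy CTC rule)."""
    decoded = []
    previous = None
    for token_id in frame_ids:
        # Keep a token only if it differs from the previous frame and is not blank
        if token_id != previous and token_id != blank:
            decoded.append(token_id)
        previous = token_id
    return decoded


# Hypothetical per-frame argmax ids and label map (not the real model's vocabulary)
frame_ids = [0, 3, 3, 0, 3, 5, 5, 0, 0, 7]
id2label = {3: "1", 5: "2", 7: "B"}

decoded_ids = greedy_ctc_decode(frame_ids, blank=0)
text = "".join(id2label[i] for i in decoded_ids)
```

The blank between the two runs of 3 is what lets the same character appear twice in a row in the decoded text.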

davoodap commented 4 months ago

@arxyzan Thank you for helping me. Both scripts ran on Windows 10 with Python 3.11 without any problems. Thank you again for your time.

arxyzan commented 4 months ago

@davoodap Nice😉