felixdittrich92 / OnnxTR

OnnxTR: a docTR (Document Text Recognition) library ONNX pipeline wrapper for seamless, high-performing & accessible OCR
https://github.com/mindee/doctr
Apache License 2.0

[Bug]: Different results between doctr and onnxtr models. #10

Closed: decadance-dance closed this issue 23 hours ago

decadance-dance commented 1 month ago

Bug description

Converting to ONNX is expected not to affect the results. In my case (using fast_base) this does not hold: the ONNX model performs worse because the regions of its probability mask are consistently narrower, which sometimes leads to artifacts.

Code snippet to reproduce the bug

import cv2
from matplotlib import pyplot as plt
from doctr.models import detection_predictor
from onnxtr.models import detection_predictor as detection_predictor_onnx


def plot(image, figsize=(12, 12)):
    """Display a single-channel probability map without axes."""
    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(image.squeeze(), cmap="gray")
    ax.axis("off")
    plt.show()


if __name__ == "__main__":
    # docTR (PyTorch) predictor on GPU in half precision (fp16)
    model = detection_predictor(
        arch="fast_base",
        pretrained=True,
        pretrained_backbone=True,
    ).cuda().half()

    # OnnxTR predictor with the same architecture
    model_onnx = detection_predictor_onnx(arch="fast_base")

    # OpenCV loads images as BGR; convert to RGB like DocumentFile does
    img = cv2.cvtColor(cv2.imread("1.jpg"), cv2.COLOR_BGR2RGB)

    # With return_maps=True the predictors return (boxes, maps); take page 0's map
    prob_map = model([img], return_maps=True)[1][0]
    prob_map_onnx = model_onnx([img], return_maps=True)[1][0]

    plot(prob_map)
    plot(prob_map_onnx)

Image: 1.jpg

Error traceback

Here are the differing results at both the probability-map and the bounding-box level (screenshots in the original issue):

ONNX bboxes: [image: onnx]
docTR bboxes: [image: doctr]

ONNX prob map: [image: onnx_probmap]
docTR prob map: [image: doctr_probmap]

Environment

Python 3.10.13
Ubuntu 20.04.5 LTS

onnxruntime 1.17.0
onnxruntime-gpu 1.17.1
onnxtr 0.1.2a0

felixdittrich92 commented 1 month ago

Hi @decadance-dance 👋, Thanks for the detailed report.

First, a question: are the image names correct? It looks like the additional artefact in the middle image of the page occurs in the docTR output but not in the OnnxTR output.

The issue comes from the preprocessing. To explain:

docTR (PyTorch): Read image (OpenCV) -> Resize (torchvision, PIL under the hood, antialiasing by default)

OnnxTR: Read image (OpenCV) -> Resize (OpenCV; antialiasing is not available, only interpolation modes that come close)

We could test some other interpolation modes to get a closer result, but as a side effect this would slightly slow down the pipeline :)
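To illustrate why the resize backend matters, here is a minimal NumPy sketch (not docTR or OnnxTR code; both helper names are hypothetical) that contrasts plain subsampling with a box-filter (area) downsample, a simple form of antialiasing. PIL's antialiased filters are more elaborate, so this only shows the effect qualitatively: fine detail that is averaged by an antialiased resize can vanish or alias without it.

```python
import numpy as np


def downsample_nearest(img: np.ndarray, factor: int) -> np.ndarray:
    """Downsample by keeping every `factor`-th pixel (no antialiasing)."""
    return img[::factor, ::factor]


def downsample_area(img: np.ndarray, factor: int) -> np.ndarray:
    """Downsample by averaging each factor x factor block (box-filter antialiasing)."""
    h, w = img.shape
    h, w = h - h % factor, w - w % factor  # crop to a multiple of factor
    blocks = img[:h, :w].reshape(h // factor, factor, w // factor, factor)
    return blocks.mean(axis=(1, 3))


# A 1-pixel checkerboard: the worst case for aliasing
y, x = np.indices((8, 8))
checker = ((x + y) % 2).astype(np.float32)

print(downsample_nearest(checker, 2))  # all zeros: the pattern aliases away
print(downsample_area(checker, 2))     # all 0.5: the average intensity is kept
```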

felixdittrich92 commented 1 month ago

OK, I also checked this by switching the preprocessing to torch, and it seems you are right: the pipeline is fine and the difference comes from the converted model. I will check the diff again :+1:

Using the ocr_predictor (screenshot from 2024-05-15; OnnxTR output on the left, docTR output on the right):

import numpy as np
import torch
from doctr.models import detection_predictor as doctr_detection_predictor
from doctr.models import fast_base as doctr_fast_base
from onnxtr.io import DocumentFile
from onnxtr.models import detection_predictor as onnxtr_detection_predictor
from onnxtr.models import fast_base as onnxtr_fast_base

doc = DocumentFile.from_images(["/home/felix/Desktop/doctr_test_data/test_page.jpg"])

# DOCTR
doctr_model = doctr_detection_predictor(arch="fast_base", pretrained=True)
res_doctr = doctr_model(doc, return_maps=True)[1][0]
print(res_doctr.shape)
# (1024, 1024, 1)

# ONNXTR
onnxtr_model = onnxtr_detection_predictor(arch="fast_base")
res_onnxtr = onnxtr_model(doc, return_maps=True)[1][0]
print(res_onnxtr.shape)
# (1024, 1024, 1)

# compute mean difference (absolute value of the mean signed difference)
mean_diff = np.abs(np.mean(res_doctr - res_onnxtr))
print(f"Mean difference between DOCTR and ONNXTR prob maps: {mean_diff:.8f}")
# Mean difference between DOCTR and ONNXTR prob maps: 0.07267220

# Model logits check: feed the same random input directly to both models
fast_onnxtr = onnxtr_fast_base()
fast_doctr = doctr_fast_base(pretrained=True, exportable=True).eval()

input_tensor = np.random.rand(1, 3, 1024, 1024).astype(np.float32)
torch_input_tensor = torch.from_numpy(input_tensor)

onnxtr_out = fast_onnxtr(input_tensor)
# logits output was added for testing purposes - not available in onnxtr
onnxtr_out = np.moveaxis(onnxtr_out["logits"], -1, 0)
print(onnxtr_out.shape)
doctr_out = fast_doctr(torch_input_tensor)
doctr_out = doctr_out["logits"].detach().numpy()

# compute mean difference
mean_diff = np.abs(np.mean(doctr_out - onnxtr_out))
print(f"Mean difference between DOCTR and ONNXTR logits: {mean_diff:.8f}")
# Mean difference between DOCTR and ONNXTR logits: 0.00002960
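The comparison above takes the absolute value of the *mean signed* difference, so positive and negative deviations can cancel out. A minimal sketch of a stricter comparison (the helper `compare_maps` is hypothetical, not part of docTR or OnnxTR) also reports the mean absolute and the maximum absolute difference:

```python
import numpy as np


def compare_maps(a: np.ndarray, b: np.ndarray) -> dict:
    """Hypothetical helper: summarize how much two probability maps differ."""
    diff = a.astype(np.float64) - b.astype(np.float64)
    return {
        "abs_of_mean": float(np.abs(diff.mean())),  # metric used in this thread
        "mean_abs": float(np.abs(diff).mean()),     # deviations cannot cancel
        "max_abs": float(np.abs(diff).max()),       # worst single pixel
    }


# Two maps that differ by +0.25 on one half and -0.25 on the other
a = np.full((4, 4), 0.5, dtype=np.float32)
b = a.copy()
b[:2] = 0.75
b[2:] = 0.25
print(compare_maps(a, b))
# {'abs_of_mean': 0.0, 'mean_abs': 0.25, 'max_abs': 0.25}
```

Here the maps clearly differ everywhere, yet the absolute value of the mean difference is zero.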
felixdittrich92 commented 1 month ago

Also compared other models:

DB resnet50
Mean difference between DOCTR and ONNXTR prob maps: 0.00546016
Mean difference between DOCTR and ONNXTR logits: 0.00000007

Linknet resnet18
Mean difference between DOCTR and ONNXTR prob maps: 0.00332537
Mean difference between DOCTR and ONNXTR logits: 0.00000158

Compared to the FAST models, the diff is much smaller :+1: I don't think I can do much about this yet, but I will keep an eye on it with the next model iteration, so let's keep this issue open :)

milosacimovic commented 4 days ago

I'm having issues with this as well. The results are vastly different, but I can confirm it's not due to the exported models: my own implementation for running ONNX docTR exports works fine.

milosacimovic commented 3 days ago

Is there a workaround for this?

felixdittrich92 commented 3 days ago

Hi @milosacimovic :wave:,

Which models do you use? Were they exported from docTR v0.8.1 or v0.9.0a0 (main branch)?

milosacimovic commented 3 days ago

They have been exported from docTR v0.8.1

milosacimovic commented 3 days ago

It seems that three things differ in OnnxTR's preprocessing.

docTR:

  1. One resize
  2. Checks for symmetric padding and adjusts for it

OnnxTR:

  1. Two resizes (with different interpolation)
  2. No check for symmetric padding (this impacts recognition, which does not use symmetric padding)

There might also be a bug or difference in how the initial resize (before padding) is done.
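Why the symmetric-padding check matters can be shown with a minimal sketch, assuming an aspect-ratio-preserving resize followed by zero padding to a square canvas. Both helpers are hypothetical and do not reproduce docTR's or OnnxTR's exact rounding: with symmetric padding the page is centred, so relative box coordinates predicted on the canvas must be shifted back by the padding before they are valid on the original page.

```python
def resize_and_pad_geometry(h, w, target_h, target_w, symmetric_pad):
    """Hypothetical helper: (new_h, new_w, pad_top, pad_left) for an
    aspect-ratio-preserving resize followed by zero padding to the target size."""
    scale = min(target_h / h, target_w / w)
    new_h, new_w = round(h * scale), round(w * scale)
    if symmetric_pad:
        # Centre the resized page: split the padding between both sides
        pad_top, pad_left = (target_h - new_h) // 2, (target_w - new_w) // 2
    else:
        pad_top, pad_left = 0, 0  # pad only at the bottom and right
    return new_h, new_w, pad_top, pad_left


def unpad_relative_box(box, new_h, new_w, pad_top, pad_left, target_h, target_w):
    """Map a relative (xmin, ymin, xmax, ymax) box on the padded canvas back to
    relative coordinates on the original page."""
    xmin, ymin, xmax, ymax = box
    xs = [(v * target_w - pad_left) / new_w for v in (xmin, xmax)]
    ys = [(v * target_h - pad_top) / new_h for v in (ymin, ymax)]
    return xs[0], ys[0], xs[1], ys[1]


# A 512 x 1024 (h x w) page placed on a 1024 x 1024 canvas
new_h, new_w, top, left = resize_and_pad_geometry(512, 1024, 1024, 1024, symmetric_pad=True)
box_on_canvas = (0.0, 0.25, 1.0, 0.75)  # the whole page, centred vertically
print(unpad_relative_box(box_on_canvas, new_h, new_w, top, left, 1024, 1024))
# (0.0, 0.0, 1.0, 1.0)
print(unpad_relative_box(box_on_canvas, new_h, new_w, 0, 0, 1024, 1024))
# (0.0, 0.5, 1.0, 1.5): ignoring the padding shifts the box off the page
```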

milosacimovic commented 1 day ago

Another difference I noticed between OnnxTR and docTR is the per-postprocessor bin_thresh and box_thresh. OnnxTR uses the default of 0.5 for both, while in docTR, e.g., LinkNetPostProcessor uses 0.1 as the default for both.
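The roles of the two thresholds can be sketched as follows. This is a simplified, hypothetical illustration, not docTR's implementation (the real postprocessors use OpenCV contour detection and further steps such as box expansion): `bin_thresh` binarizes the probability map, and `box_thresh` then drops candidate boxes whose mean probability is too low.

```python
from collections import deque

import numpy as np


def find_boxes(prob_map, bin_thresh, box_thresh):
    """Simplified sketch: binarize with bin_thresh, take 4-connected components,
    and keep the bounding box of each component whose mean probability inside
    the box is at least box_thresh. Returns (xmin, ymin, xmax, ymax, score)."""
    bitmap = prob_map > bin_thresh
    seen = np.zeros_like(bitmap)
    h, w = bitmap.shape
    boxes = []
    for sy in range(h):
        for sx in range(w):
            if not bitmap[sy, sx] or seen[sy, sx]:
                continue
            # Breadth-first flood fill of one connected component
            queue = deque([(sy, sx)])
            seen[sy, sx] = True
            ys, xs = [], []
            while queue:
                y, x = queue.popleft()
                ys.append(y)
                xs.append(x)
                for ny, nx in ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1)):
                    if 0 <= ny < h and 0 <= nx < w and bitmap[ny, nx] and not seen[ny, nx]:
                        seen[ny, nx] = True
                        queue.append((ny, nx))
            ymin, ymax, xmin, xmax = min(ys), max(ys), min(xs), max(xs)
            score = float(prob_map[ymin:ymax + 1, xmin:xmax + 1].mean())
            if score >= box_thresh:
                boxes.append((xmin, ymin, xmax, ymax, score))
    return boxes


prob = np.zeros((6, 10), dtype=np.float32)
prob[1:3, 1:4] = 0.9  # confident word
prob[4:5, 6:9] = 0.3  # faint word
print(len(find_boxes(prob, bin_thresh=0.5, box_thresh=0.5)))  # 1: only the confident word
print(len(find_boxes(prob, bin_thresh=0.1, box_thresh=0.1)))  # 2: both words
```

With lower thresholds, fainter regions survive, which is why mismatched threshold defaults would change the detected boxes.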

felixdittrich92 commented 1 day ago

Hi @milosacimovic :wave:

First, you are right: there was a bug in the resizing. I opened a PR; would you like to test it? (https://github.com/felixdittrich92/OnnxTR/pull/22)

The bin_thresh and box_thresh values are set directly in the model files (same as in docTR): https://github.com/felixdittrich92/OnnxTR/blob/03a12772fa991557aeb151e93a4530c9153ef95a/onnxtr/models/detection/models/linknet.py#L60

:)

milosacimovic commented 1 day ago

Hi @felixdittrich92, my bad about bin_thresh; you're right, they're fine.

felixdittrich92 commented 23 hours ago

Looks much closer with the PR now:

Overview (mean difference between docTR and OnnxTR prob maps; every map has shape (1024, 1024, 1)):

Det model                Before       After
db_resnet34              0.00266014   0.00000007
db_resnet50              0.00546016   0.00005615
db_mobilenet_v3_large    0.00505121   0.00003883
linknet_resnet18         0.00332537   0.00000575
linknet_resnet34         0.00293290   0.00002059
linknet_resnet50         0.00237311   0.00000474
fast_tiny                0.07307008   0.07032311
fast_small               0.07393986   0.07004898
fast_base                0.07267220   0.07088637

The diff for the fast models seems to come from the export.

felixdittrich92 commented 23 hours ago

Unfortunately, this ruins the plan to drop PIL as a dependency: before, it was only used to load the font, but now it's also needed in the preprocessing to overcome the interpolation difference between PIL and OpenCV :sweat_smile:

felixdittrich92 commented 23 hours ago

@decadance-dance @milosacimovic Could you please test it again from the main branch? :)

As mentioned, the diff with FAST seems to come mainly from the export; unfortunately, I can't do anything about it until we retrain this model.

pip3 install "onnxtr[cpu] @ git+https://github.com/felixdittrich92/OnnxTR.git"
milosacimovic commented 21 hours ago

Works better, thanks @felixdittrich92! But I'm still having an issue: I can't pin down which part is making the difference. The inputs are exactly the same, but somehow I'm getting different results.

felixdittrich92 commented 21 hours ago

@milosacimovic great 👍 Which model architectures do you use?