Hi @decadance-dance 👋, thanks for the detailed report.
First, a question: are the image names correct? It looks like the additional artefact in the middle of the page image appears in the docTR output but not in the OnnxTR output!?
The issue comes from the preprocessing. To explain:
docTR (PyTorch): read image (OpenCV) -> resize (torchvision, PIL under the hood, antialiasing by default)
OnnxTR: read image (OpenCV) -> resize (OpenCV; antialiasing is not available, only interpolation modes that come close)
We could test some interpolation modes to get a closer result, but as a side effect this would slightly slow down the pipeline :)
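To see why missing antialiasing matters, here is a minimal NumPy sketch on a synthetic image. Point sampling stands in for a resize without antialiasing and block averaging stands in for an antialiasing filter; both are simplified stand-ins, not the exact kernels used by OpenCV, PIL or torchvision, and the helper names are hypothetical.

```python
import numpy as np


def downscale_point(img: np.ndarray, factor: int) -> np.ndarray:
    """Keep every `factor`-th pixel in each direction (no antialiasing)."""
    return img[::factor, ::factor]


def downscale_box(img: np.ndarray, factor: int) -> np.ndarray:
    """Average each factor x factor block (a simple box-filter antialias)."""
    h = img.shape[0] // factor * factor
    w = img.shape[1] // factor * factor
    blocks = img[:h, :w].reshape(h // factor, factor, w // factor, factor)
    return blocks.mean(axis=(1, 3))


# A 1-pixel checkerboard is the worst case for aliasing.
yy, xx = np.indices((8, 8))
checker = ((yy + xx) % 2).astype(np.float32)

point = downscale_point(checker, 2)  # every kept pixel has y + x even -> all 0.0
box = downscale_box(checker, 2)      # each 2x2 block holds two 0s and two 1s -> all 0.5
print(np.abs(point - box).mean())    # 0.5
```

Real pages are far less extreme than a checkerboard, but sharp text edges are exactly the kind of high-frequency content where the two approaches diverge.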
OK, I also checked this by switching the preprocessing to torch, and it seems you are right: the pipeline is fine, and the difference comes from the converted model. I will check the diff again :+1:
Comparison using the ocr_predictor (screenshots: OnnxTR vs. docTR)
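from onnxtr.io import DocumentFile
from onnxtr.models import detection_predictor as onnxtr_detection_predictor
from onnxtr.models import fast_base as onnxtr_fast_base
from doctr.models import detection_predictor as doctr_detection_predictor
from doctr.models import fast_base as doctr_fast_base
import numpy as np
import torch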
doc = DocumentFile.from_images(["/home/felix/Desktop/doctr_test_data/test_page.jpg"])
# DOCTR
doctr_model = doctr_detection_predictor(arch="fast_base", pretrained=True)
res_doctr = doctr_model(doc, return_maps=True)[1][0]
print(res_doctr.shape)
# (1024, 1024, 1)
# ONNXTR
onnxtr_model = onnxtr_detection_predictor(arch="fast_base")
res_onnxtr = onnxtr_model(doc, return_maps=True)[1][0]
print(res_onnxtr.shape)
# (1024, 1024, 1)
# compute |mean(doctr - onnxtr)|: signed per-pixel errors can cancel out
mean_diff = np.abs(np.mean(res_doctr - res_onnxtr))
print(f"Mean difference between DOCTR and ONNXTR prob maps: {mean_diff:.8f}")
# Mean difference between DOCTR and ONNXTR prob maps: 0.07267220
# Model logits check
fast_onnxtr = onnxtr_fast_base()
fast_doctr = doctr_fast_base(pretrained=True, exportable=True).eval()
input_tensor = np.random.rand(1, 3, 1024, 1024).astype(np.float32)
torch_input_tensor = torch.from_numpy(input_tensor).float()
onnxtr_out = fast_onnxtr(input_tensor)
onnxtr_out = np.moveaxis(onnxtr_out["logits"], -1, 0) # logits output was added for testing purposes - not available in onnxtr
print(onnxtr_out.shape)
doctr_out = fast_doctr(torch_input_tensor)
doctr_out = doctr_out["logits"].detach().numpy()
# compute |mean(doctr - onnxtr)|: signed per-pixel errors can cancel out
mean_diff = np.abs(np.mean(doctr_out - onnxtr_out))
print(f"Mean difference between DOCTR and ONNXTR logits: {mean_diff:.8f}")
# Mean difference between DOCTR and ONNXTR logits: 0.00002960
I also compared other models:
db_resnet50
Mean difference between DOCTR and ONNXTR prob maps: 0.00546016
Mean difference between DOCTR and ONNXTR logits: 0.00000007
linknet_resnet18
Mean difference between DOCTR and ONNXTR prob maps: 0.00332537
Mean difference between DOCTR and ONNXTR logits: 0.00000158
Compared to the FAST models, the diff is much smaller :+1: I don't think I can do much about this yet, but I will keep an eye on it with the next model iteration, so let's keep this issue open :)
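The figures above are |mean(doctr - onnxtr)|, so positive and negative per-pixel errors can cancel. A minimal NumPy sketch of a slightly fuller comparison, assuming both maps are same-shaped NumPy arrays such as res_doctr and res_onnxtr; compare_maps and the 0.5 threshold are illustrative and not part of docTR or OnnxTR.

```python
import numpy as np


def compare_maps(a, b, thresh: float = 0.5) -> dict:
    """Summarise the gap between two same-shaped probability maps."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    diff = a - b
    return {
        # the metric printed in the snippets above: signed errors can cancel out
        "abs_of_mean": float(abs(diff.mean())),
        # typical per-pixel gap
        "mean_abs": float(np.abs(diff).mean()),
        # worst single pixel
        "max_abs": float(np.abs(diff).max()),
        # share of pixels on different sides of `thresh` (illustrative value)
        "flipped": float(((a >= thresh) != (b >= thresh)).mean()),
    }


# Errors of +/-0.5 that cancel exactly
a = np.array([[0.0, 1.0], [1.0, 0.0]])
b = np.full((2, 2), 0.5)
print(compare_maps(a, b))
# {'abs_of_mean': 0.0, 'mean_abs': 0.5, 'max_abs': 0.5, 'flipped': 0.5}
```

Since |mean(d)| never exceeds mean(|d|), the numbers reported in this thread are a lower bound on the typical per-pixel gap; calling compare_maps(res_doctr, res_onnxtr) would show the other views as well.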
I'm having issues with this as well. The results are vastly different, but I can confirm it's not due to the exported models: I have my own implementation for running docTR ONNX exports, and it works fine.
Is there a workaround for this?
Hi @milosacimovic :wave:,
Which models do you use? Exported from docTR v0.8.1 or v0.9.0a0 (main branch)?
They have been exported from docTR v0.8.1
It seems that the preprocessing in OnnxTR differs from docTR in three places.
docTR: (screenshot)
OnnxTR: (screenshot)
There might also be a bug or difference in how the initial resize (before padding) is done.
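To illustrate how small geometry choices in the resize-before-padding step can shift every pixel of the map, here is a hypothetical sketch of an aspect-preserving resize to 1024 x 1024. The rounding rule and the padding placement are assumptions chosen for illustration; this does not reproduce the actual code or bug of either library.

```python
import math


def resize_pad_geometry(h, w, target_h=1024, target_w=1024, symmetric=False, use_round=True):
    """Resized (h, w) and (top, bottom, left, right) padding for an aspect-preserving resize.

    `symmetric` and `use_round` are hypothetical knobs, not docTR/OnnxTR parameters.
    """
    scale = min(target_h / h, target_w / w)
    to_int = round if use_round else math.floor  # Python's round() sends halves to even
    new_h, new_w = int(to_int(h * scale)), int(to_int(w * scale))
    pad_h, pad_w = target_h - new_h, target_w - new_w
    top = pad_h // 2 if symmetric else 0
    left = pad_w // 2 if symmetric else 0
    return (new_h, new_w), (top, pad_h - top, left, pad_w - left)


# The same 2048 x 1003 page under four slightly different conventions
for symmetric in (False, True):
    for use_round in (True, False):
        print(symmetric, use_round, resize_pad_geometry(2048, 1003, symmetric=symmetric, use_round=use_round))
```

Here the width scales to 501.5, which becomes 502 with round() and 501 with floor, and symmetric padding moves the page 261 pixels to the right. Either difference misaligns the two probability maps, so a pixel-wise diff would be large even with identical model weights.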
Another difference I noticed between OnnxTR and docTR is the per-postprocessor bin_thresh and box_thresh. OnnxTR uses the default of 0.5 for both, while in docTR, e.g., LinkNetPostProcessor uses 0.1 as the default for both.
Hi @milosacimovic :wave:
First, you are right: there was a bug in the resizing. I opened a PR; would you like to test it? (https://github.com/felixdittrich92/OnnxTR/pull/22)
The bin_thresh and box_thresh values are set directly in the model files (same as in docTR): https://github.com/felixdittrich92/OnnxTR/blob/03a12772fa991557aeb151e93a4530c9153ef95a/onnxtr/models/detection/models/linknet.py#L60
:)
Hi @felixdittrich92, my bad about bin_thresh; you're right, they're fine.
Looks much closer with the PR now:
Mean difference between DOCTR and ONNXTR prob maps, before and after the PR (all maps have shape (1024, 1024, 1)):

| Det model | Before | After |
|---|---|---|
| db_resnet34 | 0.00266014 | 0.00000007 |
| db_resnet50 | 0.00546016 | 0.00005615 |
| db_mobilenet_v3_large | 0.00505121 | 0.00003883 |
| linknet_resnet18 | 0.00332537 | 0.00000575 |
| linknet_resnet34 | 0.00293290 | 0.00002059 |
| linknet_resnet50 | 0.00237311 | 0.00000474 |
| fast_tiny | 0.07307008 | 0.07032311 |
| fast_small | 0.07393986 | 0.07004898 |
| fast_base | 0.07267220 | 0.07088637 |
The diff for the FAST models seems to come from the export.
Unfortunately, this rules out dropping PIL as a dependency: before, it was only used to get the font, but now it's also needed in the preprocessing to overcome the interpolation difference between PIL and OpenCV (cv2) :sweat_smile:
@decadance-dance @milosacimovic Could you please test it again from the main branch? :)
As mentioned, the diff for FAST seems to come mainly from the export; unfortunately, I can't do anything about it before we train this model again.
pip3 install "onnxtr[cpu]@git+https://github.com/felixdittrich92/OnnxTR.git"
Works better, thanks @felixdittrich92! But I'm still seeing an issue, and I'm having trouble pinning down which part makes the difference: the inputs are exactly the same, but somehow I'm getting different results.
@milosacimovic great 👍 Which model architectures do you use?
Bug description
Converting to ONNX is expected not to affect the results. In my case (using fast_base) this does not hold: the ONNX model performs worse because the regions of its probability map are consistently narrower, which sometimes leads to artifacts.
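Why narrower probability regions turn into narrower boxes can be shown with a toy 1-D example: at a fixed bin_thresh, a profile with the same peak but a narrower shape yields fewer pixels above the threshold. The Gaussian-shaped profiles and the region_width helper are synthetic and hypothetical; they are not taken from either model.

```python
import numpy as np


def region_width(profile: np.ndarray, bin_thresh: float) -> int:
    """Number of positions where a 1-D probability profile reaches bin_thresh."""
    return int((profile >= bin_thresh).sum())


# Hypothetical cross-section through one text line (synthetic values)
x = np.arange(-10, 11)
wide = np.exp(-(x / 6.0) ** 2)    # broader response
narrow = np.exp(-(x / 4.0) ** 2)  # same peak, narrower response

for t in (0.1, 0.3, 0.5):
    print(t, region_width(wide, t), region_width(narrow, t))
```

At every threshold the narrow profile gives a smaller binarized region (for example 19 vs. 13 positions at 0.1), so boxes built from it shrink accordingly.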
Code snippet to reproduce the bug
Image:![1](https://github.com/felixdittrich92/OnnxTR/assets/86170544/56c24b47-71aa-46e1-9e63-06881039fc1a)
Error traceback
Here are the differing results at both the probability-map and bounding-box levels:
onnx bboxes:
doctr bboxes:
![doctr](https://github.com/felixdittrich92/OnnxTR/assets/86170544/eba17e0c-c5f0-4ad7-b84b-a2dc54cfc053)
onnx prob map:
doctr prob map:
![doctr_probmap](https://github.com/felixdittrich92/OnnxTR/assets/86170544/5cd6fb7f-dba4-44f1-a16d-a614e7f06fe3)
Environment
Python 3.10.13 Ubuntu 20.04.5 LTS
onnxruntime 1.17.0 onnxruntime-gpu 1.17.1 onnxtr 0.1.2a0