Closed: MikuAuahDark closed this issue 1 year ago
The Java code doesn't produce a binary array: it writes the raw masked value into the tensor rather than setting a bit. You want `dest[i][j] = ((intval & (1 << j)) != 0) ? 1 : 0;`, which is the equivalent of passing the value through Python's `bool` function, as hitit_v8_fastinfer.py does.
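A minimal Python sketch of the difference (the helper names here are mine, not from the attached code): the correct version stores a 0/1 per bit, while the buggy version stores the raw masked value itself.

```python
NUMBER_OF_BITS = 21

def binarize(codepoint, number_of_bits=NUMBER_OF_BITS):
    # Reversed binary representation: bit 0 first, each entry is 0 or 1.
    return [(codepoint >> j) & 1 for j in range(number_of_bits)]

def binarize_buggy(codepoint, number_of_bits=NUMBER_OF_BITS):
    # Stores the masked value itself (e.g. 32, 64) instead of 0/1.
    return [codepoint & (1 << j) for j in range(number_of_bits)]

print(binarize(97)[:7])        # -> [1, 0, 0, 0, 0, 1, 1]
print(binarize_buggy(97)[:7])  # -> [1, 0, 0, 0, 0, 32, 64]
```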
You might also want to check the character encoding: I think Python is processing the text as UTF-8, while Java Strings are UTF-16.
In JavaScript, I think your indexing into the 1-D array is broken; it should probably be `floatArray[i * NUMBER_OF_BITS + j]`, but I think there are other problems in that code too.
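For reference, the row-major flat-index arithmetic can be sanity-checked against NumPy (a sketch; `flat_index` is a hypothetical helper, and for a batched `n x 64 x 21` buffer the batch offset must be folded in as well):

```python
import numpy as np

MAX_LENGTH = 64
NUMBER_OF_BITS = 21

def flat_index(b, i, j):
    # Element (b, i, j) of an n x 64 x 21 tensor flattened row-major to 1-D.
    return (b * MAX_LENGTH + i) * NUMBER_OF_BITS + j

# Sanity check against NumPy's own row-major flattening:
tensor = np.arange(2 * MAX_LENGTH * NUMBER_OF_BITS).reshape(2, MAX_LENGTH, NUMBER_OF_BITS)
flat = tensor.ravel()
assert flat[flat_index(1, 10, 5)] == tensor[1, 10, 5]
```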
Hello, thank you. I double-checked, applied the change, and indeed they now produce consistent results.
As for the character encoding, it's not an issue. I use `ord()` in Python and `String.codePointAt()` in JavaScript and Java. The latter may not handle surrogate pairs correctly; however, for the sake of keeping the model simple, I trained it on the ASCII character range (which means `NUMBER_OF_BITS` can go as low as 7).
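A quick sanity check of that claim (a sketch): every ASCII code point is below 128, so bits 7 through 20 of the 21-bit encoding are always zero for ASCII-only input.

```python
# ASCII code points are 0..127, so 7 bits are enough:
assert all(cp.bit_length() <= 7 for cp in range(128))

# Bits 7..20 of the 21-bit encoding are therefore always zero for ASCII:
assert all((cp >> 7) == 0 for cp in range(128))
```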
Sorry for the inconvenience.
Describe the issue
I made a model that tries to classify gender based on input names alone. For reference, here's the model in PyTorch:
(that aside, I only bashed random operators together, so please don't comment on my choice of NNs.)
The model expects `n x 64 x 21` tensors, where `n` is the batch size (must be set, 1 if necessary; dimension 1). The input is a string with a maximum length of 64 characters (dimension 2), which is then converted to its (reversed) binary representation from bit 0 to bit 20 (dimension 3). For example, the text "a" has code point 97, so it is converted to `[1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]`. Roughly, the preprocessing code in Python is as follows (to elaborate, the same code is then rewritten exactly 1:1 in JavaScript and Java):
The tensor is then passed in for inference. Now the issue: while the model has almost the same accuracy when run in Python, it returns completely different results in Java and JavaScript (see below).
Note: I'm using ONNX Runtime 1.16.0 on Python, Java, and JavaScript. The language selector doesn't allow me to pick multiple languages.
To reproduce
Model and code can be downloaded here: complete_reproduction.zip
In there, you can find these files:
- App.java - Java version of the inference.
- hitit_v8.py - Complete Python definition of the model, how it's trained in PyTorch, and how it's exported to ONNX.
- hitit_v8_fastinfer.py - ONNX inference using Python.
- namegender.html - HTML + JavaScript realtime inference.
- modelgol2.onnx - The model.

Here's the ONNX Python inference output:
(Python 3.11.4)
Here's the JavaScript inference output while typing "Navia", logged word by word using `console.log()` (gender hint added for clarity):
(Mozilla Firefox 118.0.1)
And here's the output in Java:
(JVM 17.0.6 Microsoft OpenJDK)
And finally, here's the output when inferencing from PyTorch directly. Should be used for baseline reference:
(PyTorch 2.1.0 + Python 3.11.4)
Output TL;DR: The output in Python is the same as in PyTorch and ONNX. However, in the other languages the results are vastly different, and they don't even match each other.
Urgency
It's a hobby project, so it's not urgent. However, I'd love to have it resolved as soon as possible.
Platform
Windows
OS Version
10.0.22621.2361
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.16.0
ONNX Runtime API
Python
Architecture
X64
Execution Provider
Default CPU
Execution Provider Library Version
-