microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

Drastically Different Result Across Multiple Languages Except Python #17829

Closed MikuAuahDark closed 1 year ago

MikuAuahDark commented 1 year ago

Describe the issue

I made a model that tries to classify gender based on the input name alone. For reference, here's the model in PyTorch:

import torch


class GolModel(torch.nn.Module):
    def __init__(self, nlen: int, nbits: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.upper_layer = nlen * 2
        self.lstm = torch.nn.LSTM(
            input_size=nbits,
            hidden_size=self.upper_layer,
            num_layers=1,
            bidirectional=False,
            batch_first=True,
        )
        self.fc1 = torch.nn.Linear(self.upper_layer * nlen, self.upper_layer)
        self.result = torch.nn.Linear(self.upper_layer, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        lstm, _ = self.lstm(x)
        lstm_r = lstm.reshape(lstm.size(0), -1)
        fc1 = self.fc1(lstm_r)
        result = torch.nn.functional.softmax(self.result(fc1), 1)
        return result

model = GolModel(64, 21)

(As an aside, I just threw some random operators together, so please do not comment on my choice of network architecture.)
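For readers following the layer sizes, here is a small sketch of the tensor shapes through the forward pass of the model above. It only does the arithmetic and does not run PyTorch; the batch size `n = 1` is an assumption chosen for illustration.

```python
# Shape walk-through for GolModel(nlen=64, nbits=21).
# n is a hypothetical batch size used only for illustration.
nlen, nbits, n = 64, 21, 1

upper_layer = nlen * 2                    # LSTM hidden size: 128
input_shape = (n, nlen, nbits)            # (1, 64, 21), because batch_first=True
lstm_out_shape = (n, nlen, upper_layer)   # one hidden vector per time step
flat_shape = (n, nlen * upper_layer)      # lstm.reshape(lstm.size(0), -1)
fc1_shape = (n, upper_layer)              # Linear(upper_layer * nlen, upper_layer)
output_shape = (n, 2)                     # Linear(upper_layer, 2), then softmax

print(input_shape, lstm_out_shape, flat_shape, fc1_shape, output_shape)
```

The flattened size of 8192 is why `fc1` takes `self.upper_layer * nlen` input features.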

The model expects n x 64 x 21 tensors, where n is the batch size (dimension 1; it must always be present, so use 1 for a single input). The input is a string with a maximum length of 64 characters (dimension 2), and each character is converted to the binary representation of its code point in reversed order, from bit 0 to bit 20 (dimension 3).

For example, the text "a" has a code point of 97, so it is converted to [1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]. Roughly, the preprocessing code in Python is as follows:

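import numpy

MAX_CODEPOINT_LEN = 64  # maximum characters per name (dimension 2)
NUMBER_OF_BITS = 21  # bits per code point (dimension 3)

def convert_name_to_numpy(names: list[str]):
    tensor = numpy.zeros((len(names), MAX_CODEPOINT_LEN, NUMBER_OF_BITS), numpy.float32)

    for n, name in enumerate(names):
        for i in range(min(len(name), MAX_CODEPOINT_LEN)):
            intval = ord(name[i])
            for j in range(NUMBER_OF_BITS):
                # 1.0 if bit j of the code point is set, otherwise 0.0
                tensor[n, i, j] = bool(intval & (1 << j))

    return tensor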

(To elaborate, the same code was then rewritten 1:1 in JavaScript and Java.)
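When porting the preprocessing to another language, a dependency-free reference can be handy for checking a single name by hand. The following is a minimal sketch in plain Python, not the code from the reproduction archive; `encode_name` is a hypothetical helper, and the two constants are taken from the n x 64 x 21 shape described above.

```python
MAX_CODEPOINT_LEN = 64  # assumed from the n x 64 x 21 input shape
NUMBER_OF_BITS = 21     # assumed from the n x 64 x 21 input shape

def encode_name(name: str) -> list[list[float]]:
    """Encode one name as a MAX_CODEPOINT_LEN x NUMBER_OF_BITS grid of 0.0/1.0."""
    grid = [[0.0] * NUMBER_OF_BITS for _ in range(MAX_CODEPOINT_LEN)]
    for i, ch in enumerate(name[:MAX_CODEPOINT_LEN]):
        code = ord(ch)
        for j in range(NUMBER_OF_BITS):
            # Bit j of the code point, least significant bit first.
            grid[i][j] = float((code >> j) & 1)
    return grid

first_row = [int(v) for v in encode_name("a")[0]]
print(first_row)
```

The first row for "a" should match the 21-element example given earlier, and every row past the end of the name stays all zeros.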

The tensor is then passed for inference. Now the issue is that while the model gives almost the same results as PyTorch when run in Python, it returns completely different results in Java and JavaScript (see below).

Note: I'm using ONNX Runtime 1.16.0 on Python, Java, and JavaScript. The language selector doesn't allow me to pick multiple languages.

To reproduce

Model and code can be downloaded here: complete_reproduction.zip

In there, you can find these files:

Here's the ONNX Python inference output:

D:\omitted>python hitit_v8_fastinfer.py N Na Nav Navi Navia
(debug tensor output omitted)
N
Male: 67.00568795204163 %
Female: 32.99431502819061 %
Guessed: Unisex

Na
Male: 28.068840503692627 %
Female: 71.93116545677185 %
Guessed: Unisex

Nav
Male: 82.9777717590332 %
Female: 17.022231221199036 %
Guessed: Male

Navi
Male: 41.178399324417114 %
Female: 58.821600675582886 %
Guessed: Unisex

Navia
Male: 4.214639961719513 %
Female: 95.7853615283966 %
Guessed: Female

(Python 3.11.4)

Here's the JavaScript inference output while typing "Navia", logged character by character using console.log() (gender hints added for clarity):

output Float32Array [ 0.6700569987297058, 0.3299430012702942 ] namegender.html:79:13 (UNISEX)
output Float32Array [ 0.6402597427368164, 0.359740287065506 ] namegender.html:79:13 (UNISEX)
output Float32Array [ 0.8092355132102966, 0.19076451659202576 ] namegender.html:79:13 (MALE)
output Float32Array [ 0.7302872538566589, 0.26971277594566345 ] namegender.html:79:13 (MALE)
output Float32Array [ 0.8071557283401489, 0.19284430146217346 ] namegender.html:79:13 (MALE)

(Mozilla Firefox 118.0.1)

And here's the output in Java:

D:\omitted>gradlew run --args "N Na Nav Navi Navia"
To honour the JVM settings for this build a single-use Daemon process will be forked. See https://docs.gradle.org/7.3/userguide/gradle_daemon.html#sec:disabling_the_daemon.
Daemon will be stopped at the end of the build

> Task :app:run
N
Male: 50.220966339111 %
Female: 49.779033660889 %
Guessed: Unisex

Na
Male: 15.308277308941 %
Female: 84.691721200943 %
Guessed: Female

Nav
Male: 1.0724958963692 %
Female: 98.927503824234 %
Guessed: Female

Navi
Male: 95.465511083603 %
Female: 4.5344825834036 %
Guessed: Male

Navia
Male: 98.757290840149 %
Female: 1.2427097186446 %
Guessed: Male

BUILD SUCCESSFUL in 6s
3 actionable tasks: 1 executed, 2 up-to-date

(JVM 17.0.6 Microsoft OpenJDK)

And finally, here's the output when running inference in PyTorch directly. It should be used as the baseline reference:

D:\omitted>python hitit_v8.py infer N Na Nav Navi Navia
N
Male: 67.00567603111267 %
Female: 32.99432694911957 %
Guessed: Unisex

Na
Male: 28.06881070137024 %
Female: 71.93118929862976 %
Guessed: Unisex

Nav
Male: 82.9777479171753 %
Female: 17.022255063056946 %
Guessed: Male

Navi
Male: 41.17844998836517 %
Female: 58.82155895233154 %
Guessed: Unisex

Navia
Male: 4.214644432067871 %
Female: 95.78534960746765 %
Guessed: Female

(PyTorch 2.1.0 + Python 3.11.4)

Output TL;DR: The ONNX Runtime output in Python matches the PyTorch output. In the other languages, however, the results are vastly different, and they don't even agree with each other.
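To make the comparison concrete, here is a small sketch that checks the "Male" probabilities from the Python ONNX Runtime log against the JavaScript log. The numbers are copied from the outputs above; the tolerance of 1e-4 is an arbitrary assumption for "close enough" float32 agreement.

```python
import math

names = ["N", "Na", "Nav", "Navi", "Navia"]
# "Male" probabilities copied from the logs above (Python values are percentages).
python_male = [67.00568795204163, 28.068840503692627, 82.9777717590332,
               41.178399324417114, 4.214639961719513]
js_male = [0.6700569987297058, 0.6402597427368164, 0.8092355132102966,
           0.7302872538566589, 0.8071557283401489]

# Collect the inputs whose probabilities disagree beyond the tolerance.
mismatched = [
    name
    for name, py, js in zip(names, python_male, js_male)
    if not math.isclose(py / 100.0, js, abs_tol=1e-4)
]
print(mismatched)
```

Only the single-character input "N" agrees between the two runs. That pattern suggests the inputs fed to the model may differ, so dumping and comparing the input tensors from each language is a reasonable first step.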

Urgency

It's a hobby project, so it's not urgent. However, I'd love to have it resolved as soon as possible.

Platform

Windows

OS Version

10.0.22621.2361

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.16.0

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

-

Craigacp commented 1 year ago

The Java code doesn't produce a binary array; it writes the position integer into the tensor rather than setting a bit. You want dest[i][j] = ((intval & (1 << j)) != 0) ? 1 : 0; which is the equivalent of passing it through Python's bool function, as hitit_v8_fastinfer.py does.
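As an illustration of this class of bug, the sketch below contrasts storing the raw masked value with storing a 0/1 flag. It is written in Python to match the reference preprocessing. The "buggy" variant is a hypothetical reconstruction, since the original Java source is not shown in this thread.

```python
NUMBER_OF_BITS = 21
intval = ord("a")  # 97 == 0b1100001

# Hypothetical buggy variant: stores the masked value (0 or 2**j), not a 0/1 flag.
masked = [intval & (1 << j) for j in range(NUMBER_OF_BITS)]

# Intended variant: 1 when bit j is set, else 0 (what bool(...) yields in Python).
bits = [int(bool(intval & (1 << j))) for j in range(NUMBER_OF_BITS)]

print(masked[:7])  # [1, 0, 0, 0, 0, 32, 64]
print(bits[:7])    # [1, 0, 0, 0, 0, 1, 1]
```

Any value other than 0 or 1 in the input lies outside what the model saw during training, which would be enough to change the predictions.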

You might also want to check the character encoding. I think in Python you're processing it in UTF-8, but Java Strings are UTF-16.

In JavaScript, I think your indexing into the 1-D array is broken. It probably should be floatArray[i * NUMBER_OF_BITS + j], but I think there are other problems in that code too.
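The row-major indexing suggested here can be sketched as follows, again in Python for consistency with the reference code. `flat_index` is a hypothetical helper, and the buffer holds a single name (batch size 1).

```python
MAX_CODEPOINT_LEN = 64
NUMBER_OF_BITS = 21

def flat_index(i: int, j: int) -> int:
    """Row-major offset of bit j of character i within one name's buffer."""
    return i * NUMBER_OF_BITS + j

# Flat buffer for a single name, equivalent to a 64 x 21 array.
flat = [0.0] * (MAX_CODEPOINT_LEN * NUMBER_OF_BITS)
for i, ch in enumerate("Na"):
    for j in range(NUMBER_OF_BITS):
        flat[flat_index(i, j)] = float((ord(ch) >> j) & 1)

# Character 1 ("a") occupies offsets 21..41.
print([int(v) for v in flat[21:28]])  # [1, 0, 0, 0, 0, 1, 1]
```

For a batch of several names, the offset would additionally include `n * MAX_CODEPOINT_LEN * NUMBER_OF_BITS` for the n-th name.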

MikuAuahDark commented 1 year ago

Hello, thank you. I double-checked and applied the changes, and indeed they now produce consistent results.

As for the character encoding, it's not an issue. I use ord() in Python and String.codePointAt() in JavaScript and Java. The latter may not be able to handle surrogate pairs correctly. However, to keep the model simple, I trained it on the ASCII character range only (which means NUMBER_OF_BITS could be reduced to 7).

Sorry for the inconvenience.
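The code-point versus UTF-16 distinction raised in this thread can be checked quickly in Python. This is a minimal sketch using only the standard library; the sample character is an arbitrary choice from outside the Basic Multilingual Plane.

```python
ch = "\U0001F600"  # one code point outside the Basic Multilingual Plane

code_points = [ord(c) for c in ch]
utf16_units = len(ch.encode("utf-16-le")) // 2  # 2 bytes per UTF-16 code unit

print(len(code_points))             # 1: Python strings index by code point
print(utf16_units)                  # 2: a surrogate pair in UTF-16
print(code_points[0].bit_length())  # 17
print((0x10FFFF).bit_length())      # 21: enough for any Unicode code point
print(max(ord(c) for c in "Navia").bit_length())  # 7: ASCII fits in 7 bits
```

For ASCII-only input every character is a single UTF-16 code unit, so indexing by position gives the same result in Python, Java, and JavaScript.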