microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

Incorrect result for converted FP16 model with Conv Op when run on arm64 Linux with onnxruntime >= 1.15.0 #18992

Closed. jasonkit closed this issue 2 months ago.

jasonkit commented 9 months ago

Describe the issue

An ONNX model exported from PyTorch with nn.Conv2d and then converted to FP16 does not give the correct result during inference.

This issue is not observed with the original exported FP32 ONNX model, nor with onnxruntime 1.13 or 1.14; I first observed it with onnxruntime >= 1.15.0. It is also only observed on arm64 Linux (specifically, in Docker running on an M1 Mac). It works fine natively on macOS with an M1 CPU, and on Linux with an Intel CPU.
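The affected configuration described above can be summarized as a simple predicate. The sketch below is a hypothetical helper (matches_reported_conditions is not part of onnxruntime). It encodes only what this report states, namely Linux on arm64/aarch64 with onnxruntime >= 1.15.0, and it deliberately has no upper version bound, because the thread does not name the release that contains the fix.

```python
import platform
import re


def matches_reported_conditions(system: str, machine: str, ort_version: str) -> bool:
    """Hypothetical helper: True if the environment matches the conditions in this report.

    Reported as affected: Linux on arm64/aarch64 with onnxruntime >= 1.15.0.
    Reported as unaffected: macOS on M1 ("Darwin"/"arm64") and Linux on Intel ("x86_64").
    """
    # platform.machine() reports "aarch64" on arm64 Linux, matching the `uname -a` output
    is_arm64_linux = system == "Linux" and machine.lower() in ("aarch64", "arm64")

    # Compare only the leading major.minor numbers, so nightly or local suffixes
    # such as ".dev20240103001" or "+cu121" are ignored
    m = re.match(r"(\d+)\.(\d+)", ort_version)
    if m is None:
        raise ValueError(f"unrecognised version string: {ort_version!r}")
    major_minor = (int(m.group(1)), int(m.group(2)))

    # No upper bound: the fixed release is not named in this thread
    return is_arm64_linux and major_minor >= (1, 15)


if __name__ == "__main__":
    # Check the current interpreter's platform against a given onnxruntime version
    print(matches_reported_conditions(platform.system(), platform.machine(), "1.15.0"))
```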

To reproduce

On arm64 Linux (or using the python:3.10-bullseye Docker image), run the following code with onnxruntime >= 1.15.0:

import torch
from torch import nn

import onnx
from onnxconverter_common import float16
import onnxruntime as ort
import numpy as np

class ModelUnderTest(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Conv2d(1, 1, 1)
        nn.init.constant_(self.model.weight.data, 0.5)
        if self.model.bias is not None:
            # It works fine for this test case if bias is initialised to 0
            nn.init.constant_(self.model.bias.data, 0.5)

    def forward(self, x):
        return self.model(x)

if __name__ == "__main__":
    m = ModelUnderTest()
    x = torch.ones(1, 1, 1)
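    # Name the graph input explicitly so it always matches the "input" feed used below
    torch.onnx.export(m, x, "m1.onnx", export_params=True, input_names=["input"])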

    model = onnx.load("m1.onnx")
    m_16 = float16.convert_float_to_float16(
        model,
        keep_io_types=True,
        # It works fine if we block Conv Op
        # op_block_list=float16.DEFAULT_OP_BLOCK_LIST + ["Conv"],
    )
    onnx.save(m_16, "m1_fp16.onnx")

    # ---

    session_option = ort.SessionOptions()
    session_option.log_severity_level = 3
    session_option.enable_cpu_mem_arena = False
    session_option.enable_mem_pattern = False
    session_option.enable_mem_reuse = False

    x = np.ones((1, 1, 1))
    session_fp32 = ort.InferenceSession("m1.onnx", session_option)
    y1 = session_fp32.run(None, {"input": x.astype(np.float32)})[0]
    print("fp32 output")
    print(y1)
    session_fp16 = ort.InferenceSession("m1_fp16.onnx", session_option)
    y2 = session_fp16.run(None, {"input": x.astype(np.float32)})[0]
    print("fp16 output")
    print(y2)

    y_diff = y1 - y2
    y_diff_2 = y_diff * y_diff
    print("SSD")
    print(np.sum(y_diff_2))

It prints

fp32 output
[[[1.]]]
fp16 output
[[[0.5]]]
SSD
0.25

However, the expected output should be

fp32 output
[[[1.]]]
fp16 output
[[[1.]]]
SSD
0.0
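With one input channel, one output channel and a 1x1 kernel, the Conv reduces to y = w * x + b, so the expected value can be checked by hand: 0.5 * 1 + 0.5 = 1.0. The NumPy sketch below illustrates the arithmetic only (it is not ONNX Runtime's Conv kernel). Because 0.5 and 1.0 are exactly representable in float16, FP16 rounding alone cannot account for the 0.25 SSD.

```python
import numpy as np

# Conv2d(1, 1, 1) with weight = bias = 0.5 applied to an all-ones input of shape (1, 1, 1).
x = np.ones((1, 1, 1), dtype=np.float32)
w, b = 0.5, 0.5

y_expected = w * x + b  # 0.5 * 1 + 0.5 = 1.0, computed in float32
y_fp16 = np.float16(w) * x.astype(np.float16) + np.float16(b)  # same math in float16


def ssd(a, c):
    """Sum of squared differences, as printed by the repro script."""
    d = a.astype(np.float32) - c.astype(np.float32)
    return float(np.sum(d * d))


y_observed = np.full((1, 1, 1), 0.5, dtype=np.float32)  # FP16 output reported above

print(ssd(y_expected, y_fp16))      # 0.0: float16 represents 0.5 and 1.0 exactly
print(ssd(y_expected, y_observed))  # 0.25 = (1.0 - 0.5) ** 2
```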

It gives the correct output after downgrading onnxruntime to 1.14.1.

Urgency

This appears to be a regression in onnxruntime, as it worked before 1.15.0. I can work around the issue by adding Conv to op_block_list when converting the model to FP16.
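When checking whether a workaround (such as adding Conv to op_block_list) restores correct results, comparing with a tolerance is more robust than requiring an SSD of exactly 0, because FP16 conversion generally introduces small rounding errors. outputs_match is a hypothetical helper, and rtol = atol = 1e-3 are assumed values chosen to sit near float16 machine epsilon (2**-10, about 9.8e-4); a missing bias term like the one in this report still fails the check.

```python
import numpy as np


def outputs_match(y_ref: np.ndarray, y_test: np.ndarray,
                  rtol: float = 1e-3, atol: float = 1e-3) -> bool:
    """Hypothetical helper: compare an FP16 model's output with the FP32 reference.

    The tolerances are assumptions near float16 machine epsilon, so ordinary
    rounding passes while a structural error (e.g. a dropped bias) does not.
    """
    if y_ref.shape != y_test.shape:
        return False
    # Upcast both sides so the comparison itself is done in float32
    return bool(np.allclose(y_ref.astype(np.float32), y_test.astype(np.float32),
                            rtol=rtol, atol=atol))
```

For example, outputs_match(y1, y2) with y1 and y2 from the repro script returns False for the 1.0 vs 0.5 result above, and should return True once the FP16 output is correct.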

Platform

Linux

OS Version

Debian Bullseye

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

>= 1.15.0

ONNX Runtime API

Python

Architecture

ARM64

Execution Provider

Default CPU

Execution Provider Library Version

No response

wschin commented 9 months ago

This problem has been fixed in the latest main branch. Please install the nightly version from this page.

wschin commented 9 months ago

Closing for now. Feel free to reopen. Thanks.

jasonkit commented 9 months ago

@wschin

I have tried my code snippet above with ort-nightly==1.17.0.dev20240103001. However, I am still getting the same incorrect output.

The following are the installed Python package versions:

root@a68dcd6fb452:/app# pip freeze
coloredlogs==15.0.1
filelock==3.13.1
flatbuffers==23.5.26
fsspec==2023.12.2
humanfriendly==10.0
Jinja2==3.1.2
MarkupSafe==2.1.3
mpmath==1.3.0
networkx==3.2.1
numpy==1.26.3
onnx==1.15.0
onnxconverter-common==1.14.0
ort-nightly==1.17.0.dev20240103001
packaging==23.2
protobuf==3.20.2
sympy==1.12
torch==2.1.2
typing_extensions==4.9.0

jasonkit commented 9 months ago

@wschin

From my observation, it looks like when running an FP16 model on arm64 Linux with onnxruntime >= 1.15.0 (even with the nightly build), the bias of the Conv op is ignored.

If I export the model with the Conv2d bias set to 0, the result of the FP16 model matches the FP32 one.
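The bias-ignored hypothesis also explains why a zero bias hides the problem: when b = 0, a kernel that drops the bias produces the same output as a correct one. In the NumPy sketch below, conv1x1 is a hypothetical reference function, and include_bias=False stands in for a kernel that drops the bias; it models the reported symptom, not ONNX Runtime's actual code path.

```python
import numpy as np


def conv1x1(x, w, b, include_bias=True):
    """Hypothetical reference for a 1x1, single-channel convolution: y = w * x (+ b)."""
    y = w * x
    return y + b if include_bias else y


x = np.ones((1, 1, 1), dtype=np.float16)
w = np.float16(0.5)

for bias in (np.float16(0.5), np.float16(0.0)):
    correct = conv1x1(x, w, bias, include_bias=True)
    # Simulates a kernel that silently ignores the bias input
    dropped = conv1x1(x, w, bias, include_bias=False)
    print(f"bias={float(bias)}: outputs agree = {np.array_equal(correct, dropped)}")
```

Only the non-zero bias exposes the difference, which matches the report that the FP16 result is correct when the bias is exported as 0.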

wschin commented 9 months ago

My pip freeze output:

astunparse==1.6.3
attrs==23.2.0
black==23.10.1
Cerberus==1.3.5
certifi==2023.11.17
charset-normalizer==3.3.2
clang-format==17.0.4
click==8.1.7
expecttest==0.2.1
filelock==3.13.1
flatbuffers==23.5.26
fsspec==2023.12.2
h5py==3.10.0
hypothesis==6.92.2
idna==3.6
isort==5.12.0
Jinja2==3.1.2
lintrunner==0.11.0
lintrunner-adapters==0.12.1
MarkupSafe==2.1.3
mpi4py @ file:///work/ci_py311/mpi4py_1676858691457/work
mpmath==1.3.0
mypy-extensions==1.0.0
networkx==3.2.1
numpy==1.26.2
onnx==1.15.0
onnxconverter-common==1.14.0
onnxruntime-training @ Debug/dist/onnxruntime_training-1.17.0%2Bcu121-cp311-cp311-linux_x86_64.whl
onnxscript==0.1.0.dev20240103
optree==0.10.0
packaging==23.2
pathspec==0.12.1
platformdirs==4.1.0
protobuf==3.20.2
psutil==5.9.7
PyYAML==6.0.1
requests==2.31.0
ruff==0.1.4
six==1.16.0
sortedcontainers==2.4.0
sympy==1.12
-e git+https://github.com/pytorch/pytorch.git@b18d8d4595aa6e0768eedd5fc7d4a4402c567181#egg=torch
types-dataclasses==0.6.6
typing_extensions==4.9.0
urllib3==2.1.0

FYI: I built PyTorch and ORT locally with a commit two days ago.

jasonkit commented 9 months ago

@wschin

Just to confirm: is your testing environment arm64/aarch64 Linux? There is no issue on Intel CPUs.

My uname -a is Linux 78b33872b873 6.4.16-linuxkit #1 SMP PREEMPT Thu Nov 16 10:49:20 UTC 2023 aarch64 GNU/Linux

I am actually testing in Docker, so you should be able to reproduce my environment with

docker run --platform linux/arm64 --rm -it python:3.10-bullseye bash

and the following requirements.txt:

--extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/
coloredlogs==15.0.1
filelock==3.13.1
flatbuffers==23.5.26
fsspec==2023.12.2
humanfriendly==10.0
Jinja2==3.1.2
MarkupSafe==2.1.3
mpmath==1.3.0
networkx==3.2.1
numpy==1.26.3
onnx==1.15.0
onnxconverter-common==1.14.0
packaging==23.2
protobuf==3.20.2
sympy==1.12
torch==2.1.2
typing_extensions==4.9.0
ort-nightly==1.17.0.dev20240103001

yihonglyu commented 3 months ago

@jasonkit, I tried to reproduce your issue on Windows 11 ARM64 but could not reproduce it. Here are my package versions:

(1.15.0) >conda list
# packages in environment at C:\Users\yilyu\.conda\envs\1.15.0:
#
# Name                    Version                   Build  Channel
black                     24.4.2                   pypi_0    pypi
ca-certificates           2024.3.11            haa95532_0
cerberus                  1.3.5                    pypi_0    pypi
click                     8.1.7                    pypi_0    pypi
colorama                  0.4.6                    pypi_0    pypi
coloredlogs               15.0.1                   pypi_0    pypi
coverage                  7.5.4                    pypi_0    pypi
exceptiongroup            1.2.1                    pypi_0    pypi
filelock                  3.15.4                   pypi_0    pypi
flatbuffers               24.3.25                  pypi_0    pypi
fsspec                    2024.6.1                 pypi_0    pypi
humanfriendly             10.0                     pypi_0    pypi
iniconfig                 2.0.0                    pypi_0    pypi
intel-openmp              2021.4.0                 pypi_0    pypi
isort                     5.13.2                   pypi_0    pypi
jinja2                    3.1.4                    pypi_0    pypi
joblib                    1.4.2                    pypi_0    pypi
libffi                    3.4.4                hd77b12b_1
markupsafe                2.1.5                    pypi_0    pypi
mkl                       2021.4.0                 pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
mypy-extensions           1.0.0                    pypi_0    pypi
networkx                  3.1                      pypi_0    pypi
numpy                     1.24.4                   pypi_0    pypi
onnx                      1.16.1                   pypi_0    pypi
onnxconverter-common      1.14.0                   pypi_0    pypi
onnxmltools               1.12.0                   pypi_0    pypi
onnxruntime               1.15.0                   pypi_0    pypi
openssl                   3.0.14               h827c3e9_0
packaging                 24.1                     pypi_0    pypi
pandas                    2.0.3                    pypi_0    pypi
parameterized             0.9.0                    pypi_0    pypi
pathspec                  0.12.1                   pypi_0    pypi
pip                       24.0             py38haa95532_0
platformdirs              4.2.2                    pypi_0    pypi
pluggy                    1.5.0                    pypi_0    pypi
protobuf                  3.20.2                   pypi_0    pypi
pydocstyle                6.3.0                    pypi_0    pypi
pyreadline3               3.4.1                    pypi_0    pypi
pytest                    8.2.2                    pypi_0    pypi
pytest-cov                5.0.0                    pypi_0    pypi
python                    3.8.19               h1aa4202_0
python-dateutil           2.9.0.post0              pypi_0    pypi
pytz                      2024.1                   pypi_0    pypi
scikit-learn              1.3.2                    pypi_0    pypi
scipy                     1.10.1                   pypi_0    pypi
setuptools                69.5.1           py38haa95532_0
six                       1.16.0                   pypi_0    pypi
snowballstemmer           2.2.0                    pypi_0    pypi
sqlite                    3.45.3               h2bbff1b_0
sympy                     1.12.1                   pypi_0    pypi
tbb                       2021.13.0                pypi_0    pypi
threadpoolctl             3.5.0                    pypi_0    pypi
tomli                     2.0.1                    pypi_0    pypi
torch                     2.3.1                    pypi_0    pypi
typing-extensions         4.12.2                   pypi_0    pypi
tzdata                    2024.1                   pypi_0    pypi
vc                        14.2                 h2eaa2aa_4
vs2015_runtime            14.29.30133          h43f2093_4
wheel                     0.43.0           py38haa95532_0

Could you reproduce the issue on Windows 11 ARM64?

jasonkit commented 3 months ago

@yihonglyu

Sorry, I don't have access to a Windows 11 ARM64 machine.

The issue I reported occurs on Linux ARM64, not on Windows 11.

As I mentioned in the issue description:

It works fine on macOS with M1 CPU, or Linux with intel CPU.

I suspect that the issue might not appear on Windows 11 ARM64 and might only be reproducible on Linux ARM64.

The environment I used to reproduce the issue is described in https://github.com/microsoft/onnxruntime/issues/18992#issuecomment-1878410088

You may try to reproduce the issue with that environment.

I have just re-run the test with the following packages:

torch==2.3.1
numpy==1.26.4
onnx==1.16.1
onnxconverter-common==1.14.0
onnxruntime==1.18.1

And the issue still exists.

If you don't have access to Docker but do have a Windows 11 ARM64 machine, could you try to reproduce the issue on WSL?

yihonglyu commented 2 months ago

@jasonkit This issue has been resolved.

jasonkit commented 2 months ago

@yihonglyu

Confirmed that this issue is resolved in the latest build, thanks!