microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License
14.59k stars 2.92k forks source link

Output mismatch of torch.Tensor.to due to an extra torch.Tensor.transpose node #18231

Open Azyka opened 12 months ago

Azyka commented 12 months ago

Describe the issue

ONNX opset version: 14

When an extra torch.Tensor.transpose node is added as a model output, the original output of torch.Tensor.to changes, both with and without graph optimization, so the outputs of the two models no longer match.

To reproduce

Sample code:

import onnxruntime as ort
import onnx
import numpy as np
from numpy import testing
import torch

# Weight shared by Model0 and Model1. torch.empty leaves the tensor uninitialized,
# so its values (and therefore the outputs printed below) differ from run to run.
p0 = torch.nn.Parameter(torch.empty([51, 1, 1, 13], dtype=torch.float16), requires_grad=False)

class Model0(torch.nn.Module):
    """Baseline model: multiply the input by the module-level ``p0`` and cast to float64."""

    def __init__(self):
        super().__init__()
        self.v6_0 = p0

    def forward(self, *args):
        """Return ``args[0] * self.v6_0`` converted to ``torch.float64``."""
        product = torch.mul(args[0], self.v6_0)
        return product.to(dtype=torch.float64)

# Instantiate the baseline model and name its single output.
model_0 = Model0()
output_names_0 = ['v4_0']

# One float16 input of shape (1, 1, 6, 1), fed under the name 'v5_0'.
input_data_0 = np.array(
    [[[[3.021], [6.277], [5.758], [3.828], [3.314], [6.816]]]],
    dtype=np.float16,
)
input_dict_0 = {'v5_0': input_data_0}
inputs_0 = tuple(torch.from_numpy(arr).to('cpu') for arr in input_dict_0.values())

# Export to '0.onnx' at opset 14 with constant folding switched off.
torch.onnx.export(
    model_0,
    inputs_0,
    '0.onnx',
    verbose=False,
    input_names=['v5_0'],
    output_names=output_names_0,
    opset_version=14,
    do_constant_folding=False,
)

class Model1(torch.nn.Module):
    """Variant model: the same multiply and float64 cast as Model0, plus the transposed product as an extra output."""

    def __init__(self):
        super().__init__()
        self.v2_0 = p0

    def forward(self, *args):
        """Return ``(product.transpose(0, 1), product.to(torch.float64))`` where ``product = args[0] * self.v2_0``."""
        product = torch.mul(args[0], self.v2_0)
        swapped = product.transpose(0, 1)
        widened = product.to(dtype=torch.float64)
        return swapped, widened

# Instantiate the variant model; it exposes the transposed product and the float64 cast.
model_1 = Model1()
output_names_1 = ['v5_0', 'v9_0']

# Reuse the same input array, this time under the input name 'v0_0'.
input_dict_1 = {'v0_0': input_data_0}
inputs_1 = tuple(torch.from_numpy(arr).to('cpu') for arr in input_dict_1.values())

# Export to '1.onnx' with the same settings used for '0.onnx'.
torch.onnx.export(
    model_1,
    inputs_1,
    '1.onnx',
    verbose=False,
    input_names=['v0_0'],
    output_names=output_names_1,
    opset_version=14,
    do_constant_folding=False,
)

# Pass 1: run both exported models on the CPU provider with all graph optimizations enabled.
sess_options_0 = ort.SessionOptions()
sess_options_0.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_0 = ort.InferenceSession(
    '0.onnx',
    providers=['CPUExecutionProvider'],
    sess_options=sess_options_0,
)
sess_res_0 = sess_0.run(output_names_0, input_dict_0)
output_0 = dict(zip(output_names_0, sess_res_0))

sess_options_1 = ort.SessionOptions()
sess_options_1.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_1 = ort.InferenceSession(
    '1.onnx',
    providers=['CPUExecutionProvider'],
    sess_options=sess_options_1,
)
sess_res_1 = sess_1.run(output_names_1, input_dict_1)
output_1 = dict(zip(output_names_1, sess_res_1))

# Maps each output of model 0 to the output of model 1 that should be numerically equal.
output_name_dict = {'v4_0': 'v9_0'}

print('=========================')
try:
    for tensor_name_0, tensor_name_1 in output_name_dict.items():
        testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1])
except AssertionError as err:
    print("onnxruntime_enable_opt triggers assertion")
    print(err)
else:
    print("onnxruntime_enable_opt does not trigger assertion")
print('=========================')

# Pass 2: repeat the comparison with every graph optimization disabled.
sess_options_0 = ort.SessionOptions()
sess_options_0.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess_0 = ort.InferenceSession(
    '0.onnx',
    providers=['CPUExecutionProvider'],
    sess_options=sess_options_0,
)
sess_res_0 = sess_0.run(output_names_0, input_dict_0)
output_0 = dict(zip(output_names_0, sess_res_0))

sess_options_1 = ort.SessionOptions()
sess_options_1.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess_1 = ort.InferenceSession(
    '1.onnx',
    providers=['CPUExecutionProvider'],
    sess_options=sess_options_1,
)
sess_res_1 = sess_1.run(output_names_1, input_dict_1)
output_1 = dict(zip(output_names_1, sess_res_1))

print('=========================')
try:
    for tensor_name_0, tensor_name_1 in output_name_dict.items():
        testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1])
except AssertionError as err:
    print("onnxruntime_disable_opt triggers assertion")
    print(err)
else:
    print("onnxruntime_disable_opt does not trigger assertion")
print('=========================')

Output (varies between runs because the torch.nn.Parameter is created with torch.empty):

=========================
onnxruntime_enable_opt triggers assertion

Not equal to tolerance rtol=1e-07, atol=0

x and y +inf location mismatch:
 x: array([[[[1.983302e+04, 1.362085e+04, 2.593361e-02, ..., 2.434301e-02,
          2.434301e-02, 9.180478e+04],
         [4.120448e+04, 2.829827e+04, 5.387887e-02, ..., 5.057430e-02,...
 y: array([[[[1.984000e+04, 1.362400e+04, 2.593994e-02, ..., 2.433777e-02,
          2.433777e-02,          inf],
         [4.121600e+04, 2.830400e+04, 5.389404e-02, ..., 5.056763e-02,...
=========================
=========================
onnxruntime_disable_opt triggers assertion

Not equal to tolerance rtol=1e-07, atol=0

x and y +inf location mismatch:
 x: array([[[[1.983302e+04, 1.362085e+04, 2.593361e-02, ..., 2.434301e-02,
          2.434301e-02, 9.180478e+04],
         [4.120448e+04, 2.829827e+04, 5.387887e-02, ..., 5.057430e-02,...
 y: array([[[[1.984000e+04, 1.362400e+04, 2.593994e-02, ..., 2.433777e-02,
          2.433777e-02,          inf],
         [4.121600e+04, 2.830400e+04, 5.389404e-02, ..., 5.056763e-02,...
=========================

Urgency

This is a functional correctness bug. It may cause severe bugs in systems built on top of ORT.

Platform

Linux

OS Version

Ubuntu 22.04.3 LTS (x86_64)

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.15.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

github-actions[bot] commented 11 months ago

This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details.