microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

Output mismatch of duplicate torch.Tensor.to nodes after optimization #18211

Open Azyka opened 1 year ago

Azyka commented 1 year ago

Describe the issue

ONNX opset version: 14. When two duplicate torch.Tensor.to nodes (nodes with the same inputs and outputs) are defined, the model produces wrong results after ORT optimization.

To reproduce

Model code:

import onnxruntime as ort
import onnx
import numpy as np
from numpy import testing
import torch

# Model0: the int64 cast of tan is duplicated; to_1 chains a bool cast
# off the duplicate (to_e) rather than the first cast (to).
class Model0(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, *args):
        getitem = args[0]
        tan = torch.tan(getitem)
        to = tan.to(dtype=torch.int64)
        to_e = tan.to(dtype=torch.int64)  # duplicate of to
        to_1 = to_e.to(dtype=torch.bool)
        return (to, to_1)

model_0 = Model0()
output_names_0 = ['v4_0', 'v3_0']
input_data_0 = np.array(3.645, dtype=np.float32)
input_dict_0 = {'v5_0':input_data_0}
inputs_0 = tuple(torch.from_numpy(v).to('cpu') for _, v in input_dict_0.items())
torch.onnx.export(model_0, inputs_0, '0.onnx', verbose=False, input_names=['v5_0'], output_names=output_names_0, opset_version=14, do_constant_folding=False)

# Model1: semantically equivalent to Model0, but the bool cast chains off
# the first int64 cast (to) and the duplicate (to_e) is returned directly.
class Model1(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, *args):
        getitem = args[0]
        tan = torch.tan(getitem)
        to = tan.to(dtype=torch.int64)
        to_1 = to.to(dtype=torch.bool)
        to_e = tan.to(dtype=torch.int64)  # duplicate of to
        return (to_e, to_1)

model_1 = Model1()
output_names_1 = ['v5_0', 'v7_0']
input_dict_1 = {'v0_0':input_data_0}
inputs_1 = tuple(torch.from_numpy(v).to('cpu') for _, v in input_dict_1.items())
torch.onnx.export(model_1, inputs_1, '1.onnx', verbose=False, input_names=['v0_0'], output_names=output_names_1, opset_version=14, do_constant_folding=False)

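# Run both models with all graph optimizations enabled.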
sess_options_0 = ort.SessionOptions()
sess_options_0.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_0 = ort.InferenceSession('0.onnx', providers=['CPUExecutionProvider'], sess_options=sess_options_0)
sess_res_0 = sess_0.run(output_names_0, input_dict_0)
output_0 = dict(zip(output_names_0, sess_res_0))

sess_options_1 = ort.SessionOptions()
sess_options_1.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_1 = ort.InferenceSession('1.onnx', providers=['CPUExecutionProvider'], sess_options=sess_options_1)
sess_res_1 = sess_1.run(output_names_1, input_dict_1)
output_1 = dict(zip(output_names_1, sess_res_1))
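# Map each Model0 output name to the semantically equivalent Model1 output name.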
output_name_dict = {'v4_0': 'v5_0', 'v3_0': 'v7_0'}

print('=========================')
try:
    for tensor_name_0, tensor_name_1 in output_name_dict.items():
        print(tensor_name_0, tensor_name_1)
        testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1])
    print("onnxruntime_enable_opt does not trigger assertion")
except AssertionError as e:
    print("onnxruntime_enable_opt triggers assertion")
    print(e)
print('=========================')

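# Run both models again with all graph optimizations disabled.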
sess_options_0 = ort.SessionOptions()
sess_options_0.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess_0 = ort.InferenceSession('0.onnx', providers=['CPUExecutionProvider'], sess_options=sess_options_0)
sess_res_0 = sess_0.run(output_names_0, input_dict_0)
output_0 = dict(zip(output_names_0, sess_res_0))

sess_options_1 = ort.SessionOptions()
sess_options_1.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
sess_1 = ort.InferenceSession('1.onnx', providers=['CPUExecutionProvider'], sess_options=sess_options_1)
sess_res_1 = sess_1.run(output_names_1, input_dict_1)
output_1 = dict(zip(output_names_1, sess_res_1))

print('=========================')
try:
    for tensor_name_0, tensor_name_1 in output_name_dict.items():
        testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1])
    print("onnxruntime_disable_opt does not trigger assertion")
except AssertionError as e:
    print("onnxruntime_disable_opt triggers assertion")
    print(e)
print('=========================')
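
To see what the optimizer actually rewrote, the optimized graph can be dumped to disk for inspection. A minimal sketch using the same session options (0_opt.onnx is an arbitrary output path):

sess_options_dump = ort.SessionOptions()
sess_options_dump.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
# Writing out the optimized model allows diffing it against the original 0.onnx.
sess_options_dump.optimized_model_filepath = '0_opt.onnx'
ort.InferenceSession('0.onnx', providers=['CPUExecutionProvider'], sess_options=sess_options_dump)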

Output:

=========================
onnxruntime_enable_opt triggers assertion

Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 1 / 1 (100%)
 x: array(True)
 y: array(False)
=========================
=========================
onnxruntime_disable_opt does not trigger assertion
=========================

After optimization, Model0 produces array(True) for v3_0, where array(False) is expected.
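
The expected result can be checked with plain NumPy: tan(3.645) is roughly 0.55, which casts to int64 as 0, and 0 casts to bool as False. If the chained casts were fused into a single float-to-bool cast (a plausible but unconfirmed explanation for the mismatch), the result would instead be True, since 0.55 != 0:

import numpy as np

t = np.tan(np.array(3.645, dtype=np.float32))  # ~0.55
print(t.astype(np.int64).astype(np.bool_))     # False: expected two-step cast
print(t.astype(np.bool_))                      # True: what a fused cast would give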

Urgency

This is a correctness bug in the optimizer. It may cause severe bugs in systems built on top of ORT.

Platform

Linux

OS Version

Ubuntu 22.04.3 LTS (x86_64)

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.15.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

github-actions[bot] commented 11 months ago

This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details.

thiagocrepaldi commented 11 months ago

@Azyka this might be fixed after https://github.com/pytorch/pytorch/pull/96320 if you set torch.onnx.export(..., keep_initializers_as_inputs=True).

Try it out and let us know how it goes
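
With that flag, the export call from the repro would become (a sketch, all other arguments unchanged):

torch.onnx.export(model_0, inputs_0, '0.onnx', verbose=False,
                  input_names=['v5_0'], output_names=output_names_0,
                  opset_version=14, do_constant_folding=False,
                  keep_initializers_as_inputs=True)  # suggested flag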

Azyka commented 11 months ago

@Azyka this might be fixed after pytorch/pytorch#96320 if you set torch.onnx.export(..., keep_initializers_as_inputs=True).

Try it out and let us know how it goes

@thiagocrepaldi I tried keep_initializers_as_inputs=True with the latest torch version, and the error still exists.

=========================
onnxruntime_enable_opt triggers assertion

Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 1 / 1 (100%)
 x: array(True)
 y: array(False)
=========================
=========================
onnxruntime_disable_opt does not trigger assertion
=========================
github-actions[bot] commented 10 months ago

This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details.