microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

SplitToSequence cannot support float16 as input/output #16006

Open xiaowuhu opened 1 year ago

xiaowuhu commented 1 year ago

Describe the issue

The SplitToSequence operator does not support float16 as input, even though the ONNX documentation lists float16 as supported. This affects the fp16 converter: because the output is a sequence, the user has to convert each element of the output sequence to fp16 one by one, and the performance is bad.
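The per-element workaround described above can be sketched in plain numpy (this is only an illustration: `np.array_split` stands in for SplitToSequence, it is not the ORT kernel):

```python
import numpy as np

# Workaround sketch: since the fp16 kernel is missing, cast the input up to
# float32, split, then cast every sequence element back one by one.
x16 = np.ones((5, 5), dtype=np.float16)
x32 = x16.astype(np.float32)                   # cast before the split
seq = np.array_split(x32, 2, axis=0)           # SplitToSequence stand-in
seq16 = [t.astype(np.float16) for t in seq]    # per-element cast back (the slow part)
print([t.shape for t in seq16])                # -> [(3, 5), (2, 5)]
```

The per-element cast is what makes this costly when the sequence is long.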

To reproduce

  1. prepare a [5,5] input array with dtype = np.float16
  2. call op.SplitToSequence(input, dim=0, num_outputs=2)
  3. the expected output is a sequence of tensors with shapes [3,5] and [2,5].

It works on float32.
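The float32 behaviour in the steps above can be checked with numpy's `array_split`, which mirrors the op's splitting rule (a stand-in for the operator, not the ORT implementation):

```python
import numpy as np

x = np.arange(25, dtype=np.float32).reshape(5, 5)
# Splitting 5 rows into 2 parts gives one chunk of 3 rows and one of 2.
parts = np.array_split(x, 2, axis=0)
print([p.shape for p in parts])  # -> [(3, 5), (2, 5)]
```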

Urgency

No response

Platform

Windows

OS Version

11

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.14

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

No response

justinchuby commented 1 year ago

The error persists in 1.15.1: onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Type Error: Type (seq(tensor(float16))) of output arg (output0) of node () does not match expected type (seq(tensor(float))).

justinchuby commented 1 year ago

Summary

ORT raises [ONNXRuntimeError] : 1 : FAIL : Type Error: Type (seq(tensor(float16))) of output arg (_val_1) of node (_0x79a4c20_n23) does not match expected type (seq(tensor(float))). when executing test ops_test.TestOutputConsistencyFullGraphCPU.test_output_match_opinfo__chunk_cpu_float16 in ONNX Script TorchLib.

To recreate this report, use

CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k test_output_match_opinfo__chunk_cpu_float16

To reproduce

import onnx
import onnxruntime as ort
import numpy as np
from numpy import array, float16, float32, float64, int32, int64

onnx_model_text = """
<
   ir_version: 8,
   opset_import: ["pkg.onnxscript.torch_lib" : 1, "" : 18],
   producer_name: "pytorch",
   producer_version: "2.1.0"
>
torch_jit (float16[5,5,5] input_0) => (seq(float16[5,unk__4,5]) _val_1) {
   _val_1 = pkg.onnxscript.torch_lib.aten_chunk <chunks = 5, dim = 1> (input_0)
}
<
  domain: "pkg.onnxscript.torch_lib",
  opset_import: ["" : 18]
>
aten_chunk <chunks>(self) => (return_val)
{
   neg_1 = Constant <value_ints = [-1]> ()
   self_shape = Shape (self)
   dim = Constant <value_int: int = @dim> ()
   dim_size = Gather <axis = 0> (self_shape, dim)
   chunks = Constant <value_int: int = @chunks> ()
   chunks_cast = CastLike (chunks, dim_size)
   num_per_chunk = Div (dim_size, chunks_cast)
   chunks_0 = Constant <value_int: int = @chunks> ()
   chunks_0_cast = CastLike (chunks_0, dim_size)
   tmp = Mod (dim_size, chunks_0_cast)
   int64_0 = Constant <value = int64 int64_0 {0}> ()
   int64_0_cast = CastLike (int64_0, tmp)
   tmp_1 = Greater (tmp, int64_0_cast)
   tmp_2 = Cast <to = 7> (tmp_1)
   num_per_chunk_3 = Add (tmp_2, num_per_chunk)
   num_chunk = Div (dim_size, num_per_chunk_3)
   tmp_4 = Reshape (num_chunk, neg_1)
   list_split = Expand (num_per_chunk_3, tmp_4)
   remainder = Mod (dim_size, num_per_chunk_3)
   int64_0_5 = Constant <value = int64 int64_0_5 {0}> ()
   int64_0_5_cast = CastLike (int64_0_5, remainder)
   cond = Greater (remainder, int64_0_5_cast)
   list_split_9 = If (cond) <then_branch = thenGraph_19 () => ( list_split_7) {
      tmp_6 = Reshape (remainder, neg_1)
      list_split_7 = Concat <axis = 0> (list_split, tmp_6)
   }, else_branch = elseGraph_19 () => ( list_split_8) {
      list_split_8 = Identity (list_split)
   }>
   return_val = SplitToSequence <axis: int = @dim> (self, list_split_9)
}
"""

ort_inputs = {'input_0': array([[[-2.047  ,  8.09   , -7.1    , -2.293  , -7.355  ],
        [-5.668  , -4.676  ,  0.3516 ,  1.371  , -8.875  ],
        [ 7.137  , -8.44   ,  7.523  ,  7.367  , -4.43   ],
        [ 3.016  ,  1.125  ,  8.81   , -3.312  ,  4.14   ],
        [ 0.545  ,  1.213  ,  4.375  , -3.797  , -5.562  ]],

       [[ 5.92   , -5.33   , -6.47   , -5.68   ,  6.785  ],
        [ 4.297  , -6.977  , -0.06152, -8.65   ,  0.2373 ],
        [-7.82   , -7.242  ,  7.375  , -2.152  , -0.835  ],
        [ 0.2812 ,  0.413  , -4.586  , -5.43   ,  5.035  ],
        [ 6.39   , -1.934  , -8.14   , -2.996  ,  7.656  ]],

       [[-2.629  ,  8.664  ,  4.797  , -0.5625 ,  6.484  ],
        [-3.621  ,  8.28   ,  4.05   , -3.357  ,  6.75   ],
        [ 0.03516,  1.907  , -4.586  , -2.268  , -5.51   ],
        [-1.354  , -2.021  , -5.555  , -7.188  ,  0.12305],
        [ 7.53   ,  4.86   , -1.169  ,  4.043  ,  6.062  ]],

       [[ 5.047  , -1.415  , -2.479  ,  2.11   , -4.05   ],
        [-0.03516, -8.164  ,  7.902  ,  6.44   , -4.746  ],
        [ 3.568  , -6.977  , -2.426  ,  5.75   ,  1.494  ],
        [-4.254  ,  0.3076 , -4.395  ,  1.397  ,  7.367  ],
        [ 6.133  ,  2.127  ,  0.747  ,  6.99   , -3.93   ]],

       [[-3.016  , -4.113  ,  5.035  ,  8.37   , -0.10547],
        [-3.332  , -2.012  ,  0.2373 ,  5.44   ,  4.88   ],
        [ 2.98   ,  7.004  ,  8.414  ,  2.9    ,  5.617  ],
        [ 4.438  ,  8.914  ,  3.05   ,  4.15   ,  2.021  ],
        [-7.34   ,  1.969  ,  3.375  ,  3.305  ,  2.479  ]]],
      dtype=float16)}

session_options = ort.SessionOptions()
session_options.graph_optimization_level = (
    ort.GraphOptimizationLevel.ORT_DISABLE_ALL
)
onnx_model = onnx.parser.parse_model(onnx_model_text)

session = ort.InferenceSession(
    onnx_model.SerializeToString(), session_options, providers=("CPUExecutionProvider",)
)
ort_outputs = session.run(None, ort_inputs)
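For reference, the split-size arithmetic that the `aten_chunk` function body above computes with ONNX ops (Div, Mod, Greater, Concat) amounts to the following plain-Python sketch:

```python
def chunk_splits(dim_size: int, chunks: int) -> list:
    # Per-chunk size is ceil(dim_size / chunks); if the dimension does not
    # divide evenly, a smaller remainder chunk is appended at the end.
    num_per_chunk = dim_size // chunks + (1 if dim_size % chunks > 0 else 0)
    num_chunk = dim_size // num_per_chunk
    splits = [num_per_chunk] * num_chunk
    remainder = dim_size % num_per_chunk
    if remainder > 0:
        splits.append(remainder)
    return splits

print(chunk_splits(5, 5))  # chunks=5 along a dim of size 5 -> [1, 1, 1, 1, 1]
```

These sizes are then fed to SplitToSequence as its `split` input, which is where the float16 type check fails.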

Full error stack

Traceback (most recent call last):
  File "/home/justinchu/dev/onnx-script/onnxscript/tests/function_libs/torch_lib/ops_test_common.py", line 533, in _capture_graph_and_evaluate_torch_script_evaluator
    return _safe_ort_session_run(onnx_model.SerializeToString(), ort_inputs)
  File "/home/justinchu/dev/onnx-script/onnxscript/tests/function_libs/torch_lib/ops_test_common.py", line 349, in _safe_ort_session_run
    raise return_dict["error"]
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Type Error: Type (seq(tensor(float16))) of output arg (_val_1) of node (_0x79a4c20_n23) does not match expected type (seq(tensor(float))).

centwang commented 1 year ago

https://github.com/microsoft/onnxruntime/pull/17117 is trying to fix this.

fxmarty commented 1 year ago

Hi, I am facing the same issue for models using torch.repeat_interleave.

Edit: the issue is actually "fixed" upstream (the exporter no longer emits SplitToSequence) as of PyTorch 2.1, thanks to https://github.com/pytorch/pytorch/pull/100575

MaanavD commented 7 months ago

Fixed upstream. Closing issue. Thanks all.

justinchuby commented 7 months ago

@MaanavD we prob should keep any ort issues tagged with "dynamo" open

thiagocrepaldi commented 6 months ago

> @MaanavD we prob should keep any ort issues tagged with "dynamo" open

Was the issue fixed? If so, it is safe to close. Otherwise, we need to figure out whether this is an ORT or TorchLib issue and work to fix it.

justinchuby commented 6 months ago

This is an ORT issue. We can run the repro script again to validate with the latest version.