onnx / onnx-tensorrt

ONNX-TensorRT: TensorRT backend for ONNX

Assertion failure in fillShapeVector when using tf.image.crop_and_resize #917

Open vdel opened 1 year ago

vdel commented 1 year ago

Description

My attempts at running inference on a Faster-RCNN model lead to a segmentation fault of the Python process. The problem seems related to the tf.image.crop_and_resize operation. I can reproduce the issue with the following POC model (see below for the full reproduction code):

import tensorflow as tf

# Numbers are arbitrary
image = tf.keras.Input([52, 52, 10])
boxes = tf.constant(0, dtype=tf.float32, shape=[10, 4])
boxes_ind = tf.constant(0, dtype=tf.int32, shape=[10])
output = tf.image.crop_and_resize(image, boxes, boxes_ind, [15, 15])
model = tf.keras.Model(image, output)
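
As a quick sanity check, the model itself can be evaluated directly in TensorFlow before any conversion (a minimal sketch with a random input; one 15x15 crop per box is expected):

import numpy as np

# One random 52x52x10 image; crop_and_resize returns one crop per box.
dummy = np.random.rand(1, 52, 52, 10).astype("float32")
out = model(dummy)
print(out.shape)  # expected: (10, 15, 15, 10)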

Environment

TensorRT Version: 8.6.1
ONNX-TensorRT Version / Branch: release/8.6-GA
GPU Type: tested on a GeForce RTX 2060
Nvidia Driver Version: 525.105.17
CUDA Version: 12.0
CUDNN Version: 8.9.0
Operating System + Version: Ubuntu 20.04.5
Python Version (if applicable): 3.8.10
TensorFlow + TF2ONNX Version (if applicable): TF 2.12.0 / TF2ONNX 1.14.0
PyTorch Version (if applicable): N/A
Baremetal or Container (if container which image + tag): Container: nvcr.io/nvidia/tensorrt:23.04-py3

Relevant Files

See dockerfile.zip. The archive contains:

Steps To Reproduce

The poc.py is as follows:

import os
import tensorflow as tf

from tf2onnx import tf_loader
from tf2onnx import constants
from tf2onnx.convert import _convert_common

import onnx
import onnxruntime as ort
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

# Build and save the model
model_dir = 'repro'
image = tf.keras.Input([52, 52, 10])
boxes = tf.constant(0, dtype=tf.float32, shape=[10, 4])
boxes_ind = tf.constant(0, dtype=tf.int32, shape=[10])
output = tf.image.crop_and_resize(image, boxes, boxes_ind, [15, 15])
model = tf.keras.Model(image, output)
model.save(model_dir)

# Convert the TF model into ONNX
def convert(saved_model_dir, output_path, inputs=None, outputs=None, tag='serve', signature_def='serving_default', opset=17):
    graph_def, inputs, outputs, initialized_tables, tensors_to_rename = tf_loader.from_saved_model(
        saved_model_dir, inputs, outputs, tag, [signature_def], None,
        False, return_initialized_tables=True, return_tensors_to_rename=True,
        use_graph_names=False)
    with tf.device("/cpu:0"):
        model_proto, _ = _convert_common(
            graph_def,
            name=saved_model_dir,
            continue_on_error=False,
            target=",".join(constants.DEFAULT_TARGET),
            opset=opset,
            custom_op_handlers={},
            extra_opset=[],
            shape_override=None,
            input_names=inputs,
            output_names=outputs,
            inputs_as_nchw=None,
            outputs_as_nchw=None,
            large_model=False,
            tensors_to_rename=tensors_to_rename,
            ignore_default=None,
            use_default=None,
            tflite_path=None,
            dequantize=False,
            tfjs_path=None,
            initialized_tables=initialized_tables,
            output_frozen_graph=None,
            output_path=output_path)

onnx_model = os.path.join(model_dir, 'model.onnx')
convert(model_dir, onnx_model)

# Perform shape inference. Not doing this leads to a message saying:
# 2023-04-27 13:26:56.674926422 [E:onnxruntime:, inference_session.cc:1618 operator()] Exception during initialization: /code/onnxruntime/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:1037 
# SubGraphCollection_t onnxruntime::TensorrtExecutionProvider::GetSupportedList(SubGraphCollection_t, int, int, const onnxruntime::GraphViewer&, bool*) const [ONNXRuntimeError] : 1 : FAIL : 
# TensorRT input: Slice_b:0 has no shape specified. Please run shape inference on the onnx model first. Details can be found in https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#shape-inference-for-tensorrt-subgraphs
onnx_model_with_shapes = 'repro/model_with_shapes.onnx'
model = SymbolicShapeInference.infer_shapes(
    onnx.load(onnx_model),
    auto_merge=True,
    guess_output_rank=False,
)
onnx.save(model, onnx_model_with_shapes)
onnx_model = onnx_model_with_shapes

options = ort.SessionOptions()
options.log_severity_level = 0
options.log_verbosity_level = 0
ort_sess = ort.InferenceSession(onnx_model, providers=[
    'TensorrtExecutionProvider',
    'CUDAExecutionProvider'
], sess_options=options)

If the TensorrtExecutionProvider is removed, session creation works properly, i.e. this works:

options = ort.SessionOptions()
options.log_severity_level = 0
options.log_verbosity_level = 0
ort_sess = ort.InferenceSession(onnx_model, providers=[
    'CUDAExecutionProvider'
], sess_options=options)
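
For completeness, inference on the CUDA-only session can then be exercised as well (a minimal sketch; the input name is looked up from the session rather than assumed):

import numpy as np

input_name = ort_sess.get_inputs()[0].name
dummy = np.random.rand(1, 52, 52, 10).astype(np.float32)
outputs = ort_sess.run(None, {input_name: dummy})
print([o.shape for o in outputs])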
vdel commented 1 year ago

The initial archive did not exactly reproduce the bug. I have updated the message above with the correct Dockerfile.

The stack trace is the following:

#0  __GI_raise (sig=sig@entry=6) at ../sysdeps/unix/sysv/linux/raise.c:50
#1  0x00007f24c7952859 in __GI_abort () at abort.c:79
#2  0x00007f24c7952729 in __assert_fail_base (fmt=0x7f24c7ae8588 "%s%s%s:%u: %s%sAssertion `%s' failed.\n%n",
    assertion=0x7f23e8619dd8 "count.size() == 1 && \"implementation assumes 1D size of known size\"",
    file=0x7f23e86198e8 "/code/onnxruntime/build/Linux/Debug/_deps/onnx_tensorrt-src/ShapeTensor.cpp", line=214, function=<optimized out>) at assert.c:92
#3  0x00007f24c7963fd6 in __GI___assert_fail (assertion=0x7f23e8619dd8 "count.size() == 1 && \"implementation assumes 1D size of known size\"",
    file=0x7f23e86198e8 "/code/onnxruntime/build/Linux/Debug/_deps/onnx_tensorrt-src/ShapeTensor.cpp", line=214,
    function=0x7f23e8619d28 "onnx2trt::ShapeTensor onnx2trt::fillShapeVector(onnx2trt::IImporterContext*, int64_t, const onnx2trt::ShapeTensor&)") at assert.c:101
#4  0x00007f23e8489b7a in onnx2trt::fillShapeVector (ctx=0x50f6a560, value=1, count=...) at /code/onnxruntime/build/Linux/Debug/_deps/onnx_tensorrt-src/ShapeTensor.cpp:214
#5  0x00007f23e8489a8b in onnx2trt::similar (ctx=0x50f6a560, exemplar=..., value=1) at /code/onnxruntime/build/Linux/Debug/_deps/onnx_tensorrt-src/ShapeTensor.cpp:208
#6  0x00007f23e8421d6e in onnx2trt::(anonymous namespace)::importSlice (ctx=0x50f6a560, node=..., inputs=std::vector of length 4, capacity 4 = {...})
    at /code/onnxruntime/build/Linux/Debug/_deps/onnx_tensorrt-src/builtin_op_importers.cpp:4497
#7  0x00007f23e84452ee in std::_Function_handler<onnx2trt::ValueOrStatus<std::vector<onnx2trt::TensorOrWeights, std::allocator<onnx2trt::TensorOrWeights> > > (onnx2trt::IImporterContext*, onnx::NodeProto const&, std::vector<onnx2trt::TensorOrWeights, std::allocator<onnx2trt::TensorOrWeights> >&), onnx2trt::ValueOrStatus<std::vector<onnx2trt::TensorOrWeights, std::allocator<onnx2trt::TensorOrWeights> > > (*)(onnx2trt::IImporterContext*, onnx::NodeProto const&, std::vector<onnx2trt::TensorOrWeights, std::allocator<onnx2trt::TensorOrWeights> >&)>::_M_invoke(std::_Any_data const&, onnx2trt::IImporterContext*&&, onnx::NodeProto const&, std::vector<onnx2trt::TensorOrWeights, std::allocator<onnx2trt::TensorOrWeights> >&) (
    __functor=..., __args#0=@0x7ffdbe17c118: 0x50f6a560, __args#1=..., __args#2=std::vector of length 4, capacity 4 = {...}) at /usr/include/c++/9/bits/std_function.h:286
#8  0x00007f23e83c04d9 in std::function<onnx2trt::ValueOrStatus<std::vector<onnx2trt::TensorOrWeights, std::allocator<onnx2trt::TensorOrWeights> > > (onnx2trt::IImporterContext*, onnx::NodeProto const&, std::vector<onnx2trt::TensorOrWeights, std::allocator<onnx2trt::TensorOrWeights> >&)>::operator()(onnx2trt::IImporterContext*, onnx::NodeProto const&, std::vector<onnx2trt::TensorOrWeights, std::allocator<onnx2trt::TensorOrWeights> >&) const (this=0x75be238, __args#0=0x50f6a560, __args#1=..., __args#2=std::vector of length 4, capacity 4 = {...})
    at /usr/include/c++/9/bits/std_function.h:688
#9  0x00007f23e83b7231 in onnx2trt::parseGraph (ctx=0x50f6a560, graph=..., deserializingINetwork=false, currentNode=0x50f6a920)
    at /code/onnxruntime/build/Linux/Debug/_deps/onnx_tensorrt-src/ModelImporter.cpp:174
#10 0x00007f23e83bb93e in onnx2trt::ModelImporter::importModel (this=0x50f6a520, model=...) at /code/onnxruntime/build/Linux/Debug/_deps/onnx_tensorrt-src/ModelImporter.cpp:581
#11 0x00007f23e83bae45 in onnx2trt::ModelImporter::parseWithWeightDescriptors (this=0x50f6a520, serialized_onnx_model=0xf0332e0, serialized_onnx_model_size=4631)
    at /code/onnxruntime/build/Linux/Debug/_deps/onnx_tensorrt-src/ModelImporter.cpp:519
#12 0x00007f23e83bb119 in onnx2trt::ModelImporter::parse (this=0x50f6a520, serialized_onnx_model=0xf0332e0, serialized_onnx_model_size=4631, model_path=0x0)
    at /code/onnxruntime/build/Linux/Debug/_deps/onnx_tensorrt-src/ModelImporter.cpp:542
#13 0x00007f23e83ba545 in onnx2trt::ModelImporter::supportsModel (this=0x50f6a520, serialized_onnx_model=0xf0332e0, serialized_onnx_model_size=4631,
    sub_graph_collection=std::vector of length 0, capacity 0, model_path=0x8233640 "repro/model_with_shapes.onnx")
    at /code/onnxruntime/build/Linux/Debug/_deps/onnx_tensorrt-src/ModelImporter.cpp:398
#14 0x00007f23e83470d6 in onnxruntime::TensorrtExecutionProvider::GetSupportedList (this=0x8233400, nodes_vector_input=std::vector of length 1, capacity 1 = {...}, iterations=1,
    max_iterations=1000, graph=..., early_termination=0x7ffdbe17da5c) at /code/onnxruntime/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:929
#15 0x00007f23e83494f1 in onnxruntime::TensorrtExecutionProvider::GetCapability (this=0x8233400, graph=...)
    at /code/onnxruntime/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:1117
#16 0x00007f23f93b223e in onnxruntime::<lambda(const onnxruntime::IExecutionProvider&, const onnxruntime::GraphViewer&, const onnxruntime::IExecutionProvider::IKernelLookup&)>::operator()(const onnxruntime::IExecutionProvider &, const onnxruntime::GraphViewer &, const onnxruntime::IExecutionProvider::IKernelLookup &) const (__closure=0x7ffdbe17ddde, ep=...,
    graph_viewer=..., kernel_lookup=...) at /code/onnxruntime/onnxruntime/core/framework/graph_partitioner.cc:147
#17 0x00007f23f93b258c in onnxruntime::GetCapabilityForEP (params=...) at /code/onnxruntime/onnxruntime/core/framework/graph_partitioner.cc:170
#18 0x00007f23f93b3681 in onnxruntime::PartitionOnnxFormatModelImpl (graph=..., func_mgr=..., kernel_registry_mgr=..., fused_kernel_registry=..., current_ep=...,
    mode=onnxruntime::GraphPartitioner::Mode::kNormal, fused_node_unique_id=@0x7ffdbe17e86c: 0, transform_layout_function=...)
    at /code/onnxruntime/onnxruntime/core/framework/graph_partitioner.cc:373
#19 0x00007f23f93b353d in onnxruntime::PartitionOnnxFormatModelImpl (graph=..., func_mgr=..., kernel_registry_mgr=..., fused_kernel_registry=..., current_ep=...,
    mode=onnxruntime::GraphPartitioner::Mode::kNormal, fused_node_unique_id=@0x7ffdbe17e86c: 0, transform_layout_function=...)
    at /code/onnxruntime/onnxruntime/core/framework/graph_partitioner.cc:347
#20 0x00007f23f93b49d5 in onnxruntime::PartitionOnnxFormatModel (partition_params=..., mode=onnxruntime::GraphPartitioner::Mode::kNormal, execution_providers=...,
    kernel_registry_manager=...) at /code/onnxruntime/onnxruntime/core/framework/graph_partitioner.cc:545
#21 0x00007f23f93b608a in onnxruntime::GraphPartitioner::Partition(onnxruntime::Graph&, onnxruntime::FuncManager&, std::function<onnxruntime::common::Status (onnxruntime::Graph&, bool&, onnxruntime::IExecutionProvider&)>, onnxruntime::GraphPartitioner::Mode) const (this=0x7ffdbe17e980, graph=..., func_mgr=..., transform_layout_function=...,
    mode=onnxruntime::GraphPartitioner::Mode::kNormal) at /code/onnxruntime/onnxruntime/core/framework/graph_partitioner.cc:730
#22 0x00007f23f88453f3 in onnxruntime::InferenceSession::TransformGraph (this=0x72664f0, graph=..., graph_transformer_mgr=..., providers=..., kernel_registry_manager=...,
    insert_cast_transformer=..., session_state=..., saving_model_in_ort_format=false) at /code/onnxruntime/onnxruntime/core/session/inference_session.cc:898
#23 0x00007f23f884963f in onnxruntime::InferenceSession::Initialize (this=0x72664f0) at /code/onnxruntime/onnxruntime/core/session/inference_session.cc:1409
vdel commented 1 year ago

I have simplified the reproduction and got rid of onnxruntime; I now use only onnx-tensorrt and onnx 1.13.1. The stack trace is the same.

To reproduce:
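
A rough sketch of one way to drive the parser directly through the TensorRT Python bindings (assuming TensorRT 8.6 and the model produced by the script above; the exact reproduction steps may differ):

import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
builder = trt.Builder(TRT_LOGGER)
# ONNX parsing requires an explicit-batch network.
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, TRT_LOGGER)

with open("repro/model_with_shapes.onnx", "rb") as f:
    ok = parser.parse(f.read())
if not ok:
    for i in range(parser.num_errors):
        print(parser.get_error(i))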