facebookresearch / detectron2

Detectron2 is a platform for object detection, segmentation and other visual recognition tasks.
https://detectron2.readthedocs.io/en/latest/
Apache License 2.0
30.1k stars 7.42k forks source link

Export to onnx of a standard Detectron2 zoo faster-rcnn model generates a ReduceMax op not supported by ONNXRT TensorRT EP #4896

Open datinje opened 1 year ago

datinje commented 1 year ago

Instructions To Reproduce the 🐛 Bug:

  1. Full runnable code :
    
    `#!/usr/bin/env python
    # Copyright (c) Facebook, Inc. and its affiliates.
    # this is an adaptation of detectron2/tools/deploy/export_model.py
    # it does export of a faster-rcnn model to onnx and test it vs the original detectron2 model
    # requires any RGB input image (jpg or png)
    import argparse
    import os
    from typing import Dict, List, Tuple
    import torch
    from torch import Tensor, nn

import detectron2.data.transforms as T from detectron2.checkpoint import DetectionCheckpointer from detectron2.config import get_cfg from detectron2.data import build_detection_test_loader, detection_utils from detectron2.evaluation import COCOEvaluator, inference_on_dataset, print_csv_format from detectron2 import model_zoo """

cannot use detectron2 export lib since it depends on Caffe2 which is not provided anymore with pytorch dist

from detectron2.export import ( STABLE_ONNX_OPSET_VERSION, TracingAdapter, dump_torchscript_IR, scripting_with_instances, ) """

use export lib stripped out from caffe2 (/detectron2/export/init.py)

from lib.export import ( TracingAdapter, dump_torchscript_IR, scripting_with_instances, ) from detectron2.modeling import GeneralizedRCNN, RetinaNet, build_model from detectron2.modeling.postprocessing import detector_postprocess from detectron2.projects.point_rend import add_pointrend_config from detectron2.structures import Boxes from detectron2.utils.env import TORCH_VERSION from detectron2.utils.file_io import PathManager from detectron2.utils.logger import setup_logger

import onnx import onnxruntime as ort import numpy as np import cv2 as cv2

def setup_cfg(args):
cfg = get_cfg()

#use detectron2 satndard faster rcnn

cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml") 
cfg.MODEL.DEVICE = 'cuda'   

# cuda context is initialized before creating dataloader, so we don't fork anymore
cfg.DATALOADER.NUM_WORKERS = 0
add_pointrend_config(cfg)
cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml"))
cfg.merge_from_list(args.opts)
cfg.freeze()

return cfg

experimental. API not yet final

def export_tracing(torch_model, inputs): assert TORCH_VERSION >= (1, 8) image = inputs[0]["image"] inputs = [{"image": image}] # remove other unused keys

inference=None
"""
if isinstance(torch_model, GeneralizedRCNN):

    def inference(model, inputs):
        # use do_postprocess=False so it returns ROI mask
        inst = model.inference(inputs, do_postprocess=False)[0]
        return [{"instances": inst}]

else:
    inference = None  # assume that we just call the model directly
"""

traceable_model = TracingAdapter(torch_model, inputs, inference)

with PathManager.open(os.path.join(args.output, "faster_rcnn_fpn.onnx"), "wb") as f:
    torch.onnx.export(
          traceable_model, 
          (image,), 
          f, 
          do_constant_folding=True,
          export_params=True,
          input_names=["image"], # the model's input names
          output_names=["boxes", "labels", "scores", "image_dims"], # the model's output names
          dynamic_axes={
            "image"      : {1: "height", 2: "width"},
            "boxes"      : {0: "findings"}, # boxes is a tensor of shape [number of findings, 4] 
            "labels"     : {0: "findings"},
            "scores"     : {0: "findings"}
            },
          verbose=True, 
          opset_version=17) #issue is same with opset 16 and opset 18 is not validated for pytorch 2.0

logger.info("Inputs schema: " + str(traceable_model.inputs_schema))
logger.info("Outputs schema: " + str(traceable_model.outputs_schema))

onnx_model_path = os.path.join(args.output, "faster_rcnn_fpn.onnx")
onnx_model = onnx.load(onnx_model_path)

return onnx_model

def get_sample_inputs(args):

if args.sample_image is None:
    # get a first batch from dataset
    data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
    first_batch = next(iter(data_loader))
    return first_batch
else:
    # get a sample data
    original_image = cv2.imread("./input.jpg")
    print ("original_image input shape :", original_image.shape)

    # Do same preprocessing as DefaultPredictor
    aug = T.ResizeShortestEdge(
        [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
    )
    image_with_different_size = aug.get_transform(original_image).apply_image(original_image)
    cv2.imwrite("./inputExpanded.jpg", image_with_different_size)

    image = original_image
    height, width = original_image.shape[:2]
    image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1)) # need chanel first for onnx
    print ("image input shape :", image.shape)

    inputs = {"image": image, "height": height, "width": width}

    # Sample ready
    sample_inputs = [inputs]
    return sample_inputs

def check_onnx_model (onnx_model):

Check the model

try: onnx.checker.check_model(onnx_model, full_check=True) except onnx.checker.ValidationError as e: print("The model is invalid: %s" % e) else: print("The model is valid!")

check the onnx graph

try: graph = onnx_model.graph onnx.checker.check_graph(graph) except onnx.checker.ValidationError as e: print("The graph is invalid: %s" % e) else: print("The graph is valid!")

input_shapes = [[d.dim_value for d in _input.type.tensor_type.shape.dim] for _input in onnx_model.graph.input] print ('onnx model input shapes', input_shapes)

return None

def to_numpy(tensor): return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

def eval_onnx_model (torch_model, onnx_model, sample_inputs, args):

get D2 results

torch_model.eval() torch_outputs = torch_model(sample_inputs) print ('torch_outputs: ', torch_outputs) print ('torch size of outputs: ', len(torch_outputs))

t_outputs_scores = to_numpy(torch_outputs[0]['instances'].scores) print('d2_torch_scores: ', t_outputs_scores) t_outputs_boxes = to_numpy(torch_outputs[0]['instances'].pred_boxes.tensor) print('d2_torch_boxes: ', t_outputs_boxes) t_outputs_classes = to_numpy(torch_outputs[0]['instances'].pred_classes) print('d2_torch_classes: ', t_outputs_classes) print('')

get ONNXRT results

onnx_model_path = os.path.join(args.output, "faster_rcnn_fpn.onnx") providers = [('TensorrtExecutionProvider')]

providers = [('CUDAExecutionProvider')] # works !

sess_opt = ort.SessionOptions() sess = ort.InferenceSession(onnx_model_path, sess_options=sess_opt, providers=providers)

input_name = sess.get_inputs()[0].name print("input name", input_name) input_shape = sess.get_inputs()[0].shape print("input shape", input_shape) input_type = sess.get_inputs()[0].type print("input type", input_type)

output_name = sess.get_outputs()[0].name print("output name", output_name) output_shape = sess.get_outputs()[0].shape print("output shape", output_shape) output_type = sess.get_outputs()[0].type print("output type", output_type)

image = sample_inputs[0]['image'] np_image = image.cpu().numpy()

compute ONNX Runtime output prediction

ort_inputs = {sess.get_inputs()[0].name: np_image} ort_outputs = sess.run(None, ort_inputs)

print ('ort_outputs: ', ort_outputs) print('ort_outputs number: ', len(ort_outputs)) print('')

boxes = ort_outputs[0] classes = ort_outputs[1] scores = ort_outputs[2]

print ('ort_boxes : ', boxes) print ('ort scores : ', scores) print ('ort classes : ', classes) print('')

eval torch and onnxrt outputs

np.testing.assert_allclose(t_outputs_boxes, boxes, rtol=1e-03, atol=1e-05) np.testing.assert_allclose(t_outputs_scores, scores, rtol=1e-03, atol=1e-05) np.testing.assert_allclose(t_outputs_classes, classes, rtol=1e-03, atol=1e-05) print('detectron2 torch and onnx models results match!') print('')

return None

if name == "main": parser = argparse.ArgumentParser(description="Export a model for deployment.") parser.add_argument("--sample-image", default=None, type=str, help="sample image for input") parser.add_argument("--output", help="output directory for the converted model") parser.add_argument( "opts", help="Modify config options using the command-line", default=None, nargs=argparse.REMAINDER, ) args = parser.parse_args() logger = setup_logger() logger.info("Command line arguments: " + str(args))

PathManager.mkdirs(args.output)

cfg = setup_cfg(args)

# create a torch model
torch_model = build_model(cfg)
DetectionCheckpointer(torch_model).resume_or_load(cfg.MODEL.WEIGHTS)
torch_model.eval()

# convert and save model
sample_inputs = get_sample_inputs(args)
onnx_model = export_tracing(torch_model, sample_inputs)

check_onnx_model (onnx_model)

eval_onnx_model(torch_model, onnx_model, sample_inputs, args)

logger.info("Success.")`
2. What exact command you run:
`python3 export_model.py --output onnx_output --sample-image input.jpg`

4. __Full logs__ or other relevant observations:

[04/04 16:14:53 detectron2]: Command line arguments: Namespace(sample_image='input.jpg', output='onnx_output', opts=[]) original_image input shape : (480, 640, 3) image input shape : torch.Size([3, 480, 640])

%/model/ReduceMax_output_0 : Long(2, strides=[1], requires_grad=0, device=cpu) = onnx::ReduceMax[axes=[0], keepdims=0, onnx_name="/model/ReduceMax"](%/model/Concat_1_output_0), scope: lib.export.flatten.TracingAdapter::/detectron2.modeling.meta_arch.rcnn.GeneralizedRCNN::model # /usr/local/lib/python3.10/dist-packages/detectron2/structures/image_list.py:83:0

  %max_coordinate.3 : Float(device=cpu) = **onnx::ReduceMax[keepdims=0]**(%/model/roi_heads/Cast_9_output_0) # /usr/local/lib/python3.10/dist-packages/**torchvision**/ops/boxes.py:91:21

============= Diagnostic Run torch.onnx.export version 2.0.0+cu118 ============= verbose: False, log level: Level.ERROR ======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

[04/04 16:15:02 detectron2]: Inputs schema: TupleSchema(schemas=[ListSchema(schemas=[DictSchema(schemas=[IdentitySchema()], sizes=[1], keys=['image'])], sizes=[1])], sizes=[1]) [04/04 16:15:02 detectron2]: Outputs schema: ListSchema(schemas=[DictSchema(schemas=[InstancesSchema(schemas=[TensorWrapSchema(class_name='detectron2.structures.Boxes'), IdentitySchema(), IdentitySchema()], sizes=[1, 1, 1], keys=['pred_boxes', 'pred_classes', 'scores'])], sizes=[4], keys=['instances'])], sizes=[4]) The model is valid!

The graph is invalid: Unrecognized attribute: axes for operator ReduceMax ==> Context: Bad node spec for node. Name: /model/ReduceMax OpType: ReduceMax onnx model input shapes [[3, 0, 0]]

2023-04-04 16:37:52.173723690 [E:onnxruntime:Default, tensorrt_execution_provider.h:61 log] [2023-04-04 16:37:52 ERROR] ReduceMax_1597: at least 1 dimensions are required for input. 2023-04-04 16:37:52.324418966 [E:onnxruntime:, inference_session.cc:1532 operator()] Exception during initialization: /onnxruntime_src/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:897 SubGraphCollection_t onnxruntime::TensorrtExecutionProvider::GetSupportedList(SubGraphCollection_t, int, int, const onnxruntime::GraphViewer&, bool*) const [ONNXRuntimeError] : 1 : FAIL : TensorRT input: /model/proposal_generator/GatherND_2_output_0 has no shape specified. Please run shape inference on the onnx model first. Details can be found in https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#shape-inference-for-tensorrt-subgraphs

Traceback (most recent call last): File "/cad-engine/export_model.py", line 264, in eval_onnx_model(torch_model, onnx_model, sample_inputs, args) File "/cad-engine/export_model.py", line 190, in eval_onnx_model sess = ort.InferenceSession(onnx_model_path, sess_options=sess_opt, providers=providers) File "/usr/local/lib/python3.10/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 360, in init self._create_inference_session(providers, provider_options, disabled_optimizers) File "/usr/local/lib/python3.10/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 408, in _create_inference_session sess.initialize_session(providers, provider_options, disabled_optimizers) onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Exception during initialization: /onnxruntime_src/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:897 SubGraphCollection_t onnxruntime::TensorrtExecutionProvider::GetSupportedList(SubGraphCollection_t, int, int, const onnxruntime::GraphViewer&, bool*) const [ONNXRuntimeError] : 1 : FAIL : TensorRT input: /model/proposal_generator/GatherND_2_output_0 has no shape specified. Please run shape inference on the onnx model first. Details can be found in https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#shape-inference-for-tensorrt-subgraphs

5. please simplify the steps as much as possible so they do not require additional resources to
   run, such as a private dataset.
Unfortunately , requires an input RGB png or jpeg image (unless can randomize the input in teh code above)

## Expected behavior:
with TensorRT Execution provider the code above should work as fine as with the CUDAExecutionProvider (or the CPUExecutionProvider)
That means that the detectron2 export to onnx should generate a onnx::ReduceMax call with no axes argument

## Environment:
PyTorch version: 2.0.0+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0
Clang version: Could not collect
CMake version: version 3.25.0
Libc version: glibc-2.35

Python version: 3.10.6 (main, Nov 14 2022, 16:10:14) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-5.3.18-150300.59.63-default-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Quadro RTX 8000
Nvidia driver version: 515.76
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   45 bits physical, 48 bits virtual
Byte Order:                      Little Endian
CPU(s):                          16
On-line CPU(s) list:             0-15
Vendor ID:                       GenuineIntel
Model name:                      Intel(R) Xeon(R) Gold 6258R CPU @ 2.70GHz
CPU family:                      6
Model:                           85
Thread(s) per core:              1
Core(s) per socket:              8
Socket(s):                       2
Stepping:                        7
BogoMIPS:                        5387.34
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon nopl xtopology tsc_reliable nonstop_tsc cpuid pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 invpcid avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat pku ospke avx512_vnni md_clear flush_l1d arch_capabilities
Hypervisor vendor:               VMware
Virtualization type:             full
L1d cache:                       512 KiB (16 instances)
L1i cache:                       512 KiB (16 instances)
L2 cache:                        16 MiB (16 instances)
L3 cache:                        77 MiB (2 instances)
NUMA node(s):                    1
NUMA node0 CPU(s):               0-15
Vulnerability Itlb multihit:     KVM: Mitigation: VMX unsupported
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.24.2
[pip3] torch==2.0.0+cu118
[pip3] torchaudio==2.0.1+cu118
[pip3] torchvision==0.15.1+cu118
[pip3] triton==2.0.0
[conda] Could not collect

My understanding detectron2 code have torch.export.onnx generate a wrong call to onnx::ReduceMax operator which does not support the axes parameter besides the keepdims parameter Unlike torchvison code (see logs)

Please note :

  1. onnx check_model passes
  2. onnx check_graph does not passes : it detects the same issue on ReduceMax than TensorRT
  3. the generated ONNX model works (with same results as D2 model) on CUDA EP and CPU EP
  4. the https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools)/symbolic_shape_infer.py tools suggested in onnrt and TRT EP error dump does not work : it crashes on the same generated ONNX model ReduceMax operator

This issue is unfortunate because TensorRT EP is needed for subsequent optimizations like FP16 or Automatic Mixed Precision to benefit from Nvidia tensor cores.

Byte247 commented 1 year ago

I could reproduce your problem under Ubuntu 20.04. I had to change the order loading the weights to cfg.merge_from_file(model_zoo.get_config_file("COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml")) and then: fg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-Detection/faster_rcnn_R_50_FPN_1x.yaml") to get it running, as well as the import of the TracingAdapter. Without the TensorrtExecutionProvider I now get these timings which not great: time detectron2: 0.0796477980002237 time onnx: 1.6388183159997425

datinje commented 1 year ago

I did not get any issue with either the CPUExecutionProvider nor the CUDAExecutionProvider. Onnxrt was even faster by 30% than D2 on cpu only . I did not check cuda perf time but was pretty fast. The main problem is that tensort does not accept the reducemax instruction on the onnx model and that I could not either split the model into subgraph to assign the subgraph with faulty instruction to cpu. So we are blocked with tensorrt. This D2 zoo model is just a way for me demo the PB , but my main goal is to fix my own app model also faster rcnn based which has the same issue and which blocks it's productization . Alternative is to rewrite the app with native pytorch or try another ep maybe TVM . Did you experience the same issue with tensorrt EP?

Byte247 commented 1 year ago

Yes with the tensorrt EP I get the exact same error

datinje commented 1 year ago

Good. We are on the same page. Is there a workaround if not real fix a available ?

Byte247 commented 1 year ago

My workaround is to not use the ONNX runtime, but instead use plain TensorRT. This tutorial is a good start: https://github.com/NVIDIA/TensorRT/tree/release/8.6/samples/python/detectron2. From here you can convert a Faster R-CNN to TensorRT aswell by removing the Mask head. Also the rectangular requirement in the tutorial is not necessary as long as your sides are divisible by 32.

datinje commented 1 year ago

My company wants than the implementation be portable to cpu/GPU and to any GPU vendor. Hence the choice of onnxrt. I cannot use native try as a workaround. The current workaround is to use the cuda EP. But I can't benefit from TRT EP model optimisation and Nvidia tesnor core acceleration via TRT EP use of fp16 or automatic mixes precision options. I really need this to be fixed from detectron2. I am going to check if apache TVM EP for onnxrt is supporting the generated onnx format like the cuda EP. The final alternay would be to rewrite everything in pure pytorch without detectron2 . Il would like to avoid that. Btw I saw no new D2 version since Nov 2021. Is Detectron2 still supported?

datinje commented 1 year ago

any clue of what is the problem due to and what workaround I could use (other than using plain TensorRT) - I am trying the TVM EP to see if the TVM accepts the ONNX faulty instruction generated by Dectectron2

jcdatin commented 1 year ago

any update on this ?

datinje commented 1 year ago

what is the next step ?

datinje commented 1 year ago

Apparently , the problems is similar to tensorRt not supporting Mask-rcnn model . see https://github.com/NVIDIA/TensorRT/tree/release/8.6/samples/python/detectron2 TensorRt does not support managing variable output size in the model (that results in a generating by pytorch/detectron2 some ONNX instruction with different arguments than TensorRT supports) and so Nvidia engineers had to produce a python transformer for creating another ONNX model where the variable size is managed with a graph of box proposals. But the script is specific to mask-rcnn and so we need to to create the same for faster-rcnn models. (names of the nodes could be different). An alternative would be to rerelease TensorRt with a fix for the new ONNX instructions . Starting discussion with Nvidia. Keep in touch (is there anyone interested ? else I can ask to close the case as the solution would not come via detectron2 modification ?)

datinje commented 1 year ago

Problem was narrowed down to ONNXRT TensorRT Execution Provider implementation which is under ONNXRuntime maintainers (Microsoft). This is because -no issue with ORT CPU execution provider -no issue with ORT Cuda Execution provider -no issue with tensorRT native runtime after convering ONNX to TensoRT model native format , see:

trtexec --onnx=output/faster_rcnn_fpn.onnx --verbose

[06/30/2023-15:15:35] [I] TensorRT version: 8.6.1 [06/30/2023-15:15:52] [I] Start parsing network model. [06/30/2023-15:15:52] [I] [TRT] ---------------------------------------------------------------- [06/30/2023-15:15:52] [I] [TRT] Input filename: output/faster_rcnn_fpn.onnx [06/30/2023-15:15:52] [I] [TRT] ONNX IR version: 0.0.8 [06/30/2023-15:15:52] [I] [TRT] Opset version: 16 [06/30/2023-15:15:52] [I] [TRT] Producer name: pytorch [06/30/2023-15:15:52] [I] [TRT] Producer version: 2.0.1 [06/30/2023-15:15:52] [I] [TRT] Domain:
[06/30/2023-15:15:52] [I] [TRT] Model version: 0 [06/30/2023-15:15:52] [I] [TRT] Doc string:
[06/30/2023-15:15:52] [I] [TRT] ---------------------------------------------------------------- [06/30/2023-15:15:34] [I] === Model Options === [06/30/2023-15:15:34] [I] Format: ONNX [06/30/2023-15:15:34] [I] Model: output/faster_rcnn_fpn.onnx [06/30/2023-15:15:34] [I] Output: .... [06/30/2023-15:15:53] [V] [TRT] Parsing node: /model/ReduceMax [ReduceMax] [06/30/2023-15:15:53] [V] [TRT] Searching for input: /model/Concat_1_output_0 [06/30/2023-15:15:53] [V] [TRT] /model/ReduceMax [ReduceMax] inputs: [/model/Concat_1_output_0 -> (1, 2)[INT32]], [06/30/2023-15:15:53] [V] [TRT] Registering layer: /model/ReduceMax for ONNX node: /model/ReduceMax [06/30/2023-15:15:53] [V] [TRT] Registering tensor: /model/ReduceMax_output_0 for ONNX tensor: /model/ReduceMax_output_0 [06/30/2023-15:15:53] [V] [TRT] /model/ReduceMax [ReduceMax] outputs: [/model/ReduceMax_output_0 -> (2)[INT32]], .....

So ORT TRT EP implementation must be fixed first.

Second, there seems to be other issues in theD2 faster-rcnn model ONNX graph to be fixed after. : [06/30/2023-15:15:53] [E] [TRT] ModelImporter.cpp:777: ERROR: ModelImporter.cpp:195 In function parseGraph: [6] Invalid Node - /model/roi_heads/pooler/level_poolers.0/If /model/roi_heads/pooler/level_poolers.0/If_OutputLayer: IIfConditionalOutputLayer inputs must have the same shape. Shapes are [0] and [0,1]. [06/30/2023-15:15:53] [E] Failed to parse onnx file [06/30/2023-15:15:53] [I] Finished parsing network model. Parse time: 0.654328 [06/30/2023-15:15:53] [E] Parsing model failed [06/30/2023-15:15:53] [E] Failed to create engine from model or file. [06/30/2023-15:15:53] [E] Engine set up failed &&&& FAILED TensorRT.trtexec [TensorRT v8601] # trtexec --onnx=output/faster_rcnn_fpn.onnx --verbose

according some nice Nvidia engineer I talked to : this is a normal phase to do in running an ONNX model into tensoRT . This is because to be able to nicely optimize the graph , tensort has stricter requirements not supported by ONNX . So we have to use some nvida tool like graphSurgean to transform the ONNX model into another ONNX model that is supported by TensoRT. This was done for example with the mask-rcnn ONNX model , and have yet to be done for the faster-rcnn model.

jywu-msft commented 1 year ago

Hi, our team works on ORT TRT EP. Would it be possible for you to create a new issue in https://github.com/microsoft/onnxruntime/issues with the latest details? We will have someone take a closer look. I'm not sure I fully understand your last comment here. It sounds like you are saying the updated onnx model works in native TensorRT? but you also show some output showing trtexec failing? Let's continue the discussion and follow up in onnxruntime github issue.

datinje commented 1 year ago

We have been working for a while with the onnxruntime team to try to fix this issue. see https://github.com/microsoft/onnxruntime/issues/16886

The problem comes from Some pytorch modelule creating a ReduceMax instruction with onn opset 13 and detectron2 creating another ReduceMax instruction with latest API (as od opset18) - despite D2 model is converted using opset 17 (opset16 does no better).

Eventually we could make the faster-rcnn model from D2 model zoo work using TWO converter : onnxsim and symbolic_shape_infer . But that did not solve my own D2 model . The problem only hid another problem ( Squeez(13) ort trt EP kernel missing. The only way to fix this was to split the graph into subgraphs where (most) of the subgraph not supported by TRT EP wer fall back to CUDA EP yielding a terrible inference performance.

I have to resolve to either go with CUDA EP only or completely rewrite the D2 modules in pure pytorch.

Nobody answerd as to whter D2 was a good option to use. Due to all my onnx deployment problems I would not recommand.

xinlin-xiao commented 9 months ago

I also use the configuration file of the project EVA-02/det/ built on Detectron2: 'https://github.com/baaivision/EVA/blob/master/EVA-02/det/projects/ViTDet/configs/eva2_o365_to_coco/eva2_o365_to_coco_cascade_mask_rcnn_vitdet_l_8attn_1536_lrd0p8 .py' try to export it to onnx, after my modification to export_model.py using lazy-config, I can use the command "/export_model.py --config-file /mnt/data1/download_new/EVA/EVA-master- Project-lazy/Eva-02/DET/Projects/Vitdet/Configs/EVA2_O365_COCO/EVA2_TO_COCADE_MASK_RCNN_L_L_8ATTN_1536_LRRD0P8 .py-OUTPUT OUTPUT/ TRT_12.18/ --sexport-Method Tracing-Format Onnx "successfully exported onnx, then use onnxslm:" https://github.com/WeLoveAI/OnnxSlim" Simplify the model, and then use tensor's onnx parser to try to convert onnx to trt,

it returned an error: "[12/20/2023-11:00:01] [E] [TRT] 4: [shapeCompiler.cpp::evaluateShapeChecks::1180] Error Code 4: Internal Error (kOPT values for profile 0 violate shape constraints: If_9464_OutputLayer: dimensions not compatible for if-conditional outputs Condition '==' violated: 0 != 1250.) [12/20/2023-11:00:01] [E] [TRT] 2: [builder.cpp::buildSerializedNetwork::751] Error Code 2: Internal Error (Assertion engine != nullptr failed. )",

I found that the part where the error was reported contained the operator nms in the onnx in netron diagram,

before sim OOBQA5$M}J1 V `S81M$AV GMG5V0} G_$MK2XSS WT1V5 K65J@Z30PC3XO_VLR25L)_P afeter sim K65J@Z30PC3XO_VLR25L)_P 1703055812972

and then part of my model structure was: "(proposal_generator): RPN( (rpn_head): StandardRPNHead( (conv): Sequential( (conv0): Conv2d( 256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) (activation): ReLU() ) (conv1): Conv2d( 256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) (activation): ReLU() ) ) (objectness_logits): Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1)) (anchor_deltas): Conv2d(256, 12, kernel_size=(1, 1), stride=(1, 1)) ) (anchor_generator): DefaultAnchorGenerator( (cell_anchors): BufferList() ) ) (roi_heads): CascadeROIHeads( (box_pooler): ROIPooler( (level_poolers): ModuleList( (0): ROIAlign(output_size=(7, 7), spatial_scale=0.25, sampling_ratio=0, aligned=True) (1): ROIAlign(output_size=(7, 7), spatial_scale=0.125, sampling_ratio=0, aligned=True) (2): ROIAlign(output_size=(7, 7), spatial_scale=0.0625, sampling_ratio=0, aligned=True) (3): ROIAlign(output_size=(7, 7), spatial_scale=0.03125, sampling_ratio=0, aligned=True) ) ) (box_head): ModuleList( (0-2): 3 x FastRCNNConvFCHead( (conv1): Conv2d( 256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False (norm): LayerNorm() (activation): ReLU() ) (conv2): Conv2d( 256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False (norm): LayerNorm() (activation): ReLU() ) (conv3): Conv2d( 256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False (norm): LayerNorm() (activation): ReLU() ) (conv4): Conv2d( 256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False (norm): LayerNorm() (activation): ReLU() ) (flatten): Flatten(start_dim=1, end_dim=-1) (fc1): Linear(in_features=12544, out_features=1024, bias=True) (fc_relu1): ReLU() ) ) (box_predictor): ModuleList( (0-2): 3 x FastRCNNOutputLayers( (cls_score): Linear(in_features=1024, out_features=175, bias=True) (bbox_pred): Linear(in_features=1024, out_features=4, bias=True) ) ) ", which uses some "FastRCNNConvFCHead ” and some structures of “RPN and CascadeROIHeads”, so if you can give me some suggestions and ideas to fix this error, I will be very grateful to you!