Hi, what about the result under dynamic config?
[3.2238083, 357.37311, 151.36804, 497.31665, 0.078171365]
And the dynamic config is something like this:
codebase_config = dict(
type='mmdet',
task='ObjectDetection',
model_type='end2end',
post_processing=dict(
score_threshold=0.05,
confidence_threshold=0.005, # for YOLOv3
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=5000,
keep_top_k=100,
background_label_id=-1,
)
)
onnx_config = dict(
type='onnx',
export_params=True,
keep_initializers_as_inputs=False,
opset_version=11,
save_file='dyhead_swin_dynamic_1024.onnx',
input_names=['input'],
output_names=['dets', 'labels'],
input_shape=None,
optimize=True,
dynamic_axes={
'input': {
0: 'batch',
2: 'height',
3: 'width'
},
'dets': {
0: 'batch',
1: 'num_dets',
},
'labels': {
0: 'batch',
1: 'num_dets',
},
}
)
backend_config = dict(
type='tensorrt',
common_config=dict(fp16_mode=False, max_workspace_size=8 << 30),
model_inputs=[
dict(
input_shapes=dict(
input=dict(min_shape=[1, 3, 1024, 1024],
opt_shape=[1, 3, 1024, 1024],
max_shape=[1, 3, 1024, 1024]
)
)
)
]
)
Here is the code that I use to predict using tensorrt:
import pycuda.autoinit
import tensorrt as trt
from mmdeploy.backend.tensorrt.init_plugins import load_tensorrt_plugin
load_tensorrt_plugin()
import cv2
import numpy as np
# All are from the tensorrt examples
from inference_server.engine import build_engine, save_engine, load_engine, get_engine_in_out_shape
from inference_server.common import allocate_buffers, do_inference_v2
engine = load_engine('/workdir/dyhead_swin_dynamic_1024.engine')
ec = engine.create_execution_context()
inputs, outputs, bindings, stream = allocate_buffers(engine)
engine_out_shape = get_engine_in_out_shape(engine)
with open('/workdir/test-960.jpg', 'rb') as f:
im = f.read()
batch_data = np.zeros((1, 3, 1024, 1024), dtype=np.float32)
from mmcv.image.io import imfrombytes
cv_im = imfrombytes(im)
cv_im = np.array(cv_im, dtype=np.uint8, order='C')
from mmdet.datasets.pipelines.transforms import Pad, Resize, Normalize
# same as the testing pipeline in the model config
im_res = Resize(img_scale=(960,960), keep_ratio=True, backend='pillow')({'img':cv_im})
im_res = Normalize([123.675, 116.28, 103.53], [58.395, 57.12, 57.375], to_rgb=True)(im_res)
im_res = Pad(size_divisor=128)(im_res)
nor_cv_im = im_res['img'].transpose((2, 0, 1))
batch_data[0] = nor_cv_im
inputs[0].host = np.ascontiguousarray(batch_data)
res = do_inference_v2(ec, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
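One caveat worth flagging (an assumption about the sample helpers, since allocate_buffers comes from the TensorRT examples): with an engine built from a dynamic-shape ONNX model, the input binding shape normally has to be fixed on the execution context before the buffers are sized, roughly like this:
# For dynamic-shape engines, pin the input binding to a concrete shape
# before sizing host/device buffers; otherwise the profile's -1 dims
# cannot be allocated (hypothetical placement for this script).
ec.set_binding_shape(0, (1, 3, 1024, 1024))
assert ec.all_binding_shapes_specified
inputs, outputs, bindings, stream = allocate_buffers(engine)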
Could you please provide me with the files you have used?
/workdir/detection_onnx_static_1024x1024.py
/workdir/atss_swin-l-p4-w12_fpn_dyhead_mstrain_2x.py
/workdir/latest.pth
Are they different from our official config and checkpoints?
/workdir/detection_onnx_static_1024x1024.py is the one I posted above, let me post it again
codebase_config = dict(
type='mmdet',
task='ObjectDetection',
model_type='end2end',
post_processing=dict(
score_threshold=0.05,
confidence_threshold=0.005, # for YOLOv3
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=5000,
keep_top_k=100,
background_label_id=-1,
)
)
onnx_config = dict(
type='onnx',
export_params=True,
keep_initializers_as_inputs=False,
opset_version=11,
save_file='dyhead_swin_1024.onnx',
input_names=['input'],
output_names=['dets', 'labels'],
input_shape=[1024, 1024],
optimize=True
)
backend_config = dict(
type='tensorrt',
common_config=dict(fp16_mode=False, max_workspace_size=8 << 30),
model_inputs=[
dict(
input_shapes=dict(
input=dict(min_shape=[1, 3, 1024, 1024],
opt_shape=[1, 3, 1024, 1024],
max_shape=[1, 3, 1024, 1024]
)
)
)
]
)
/workdir/atss_swin-l-p4-w12_fpn_dyhead_mstrain_2x.py:
checkpoint_config = dict(interval=1)
log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')])
custom_hooks = [dict(type='NumClassCheckHook')]
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
opencv_num_threads = 0
mp_start_method = 'fork'
auto_scale_lr = dict(enable=False, base_batch_size=16)
pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth'
model = dict(
type='ATSS',
backbone=dict(
type='SwinTransformer',
pretrain_img_size=384,
embed_dims=192,
depths=[2, 2, 18, 2],
num_heads=[6, 12, 24, 48],
window_size=12,
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.2,
patch_norm=True,
out_indices=(1, 2, 3),
with_cp=False,
convert_weights=True,
init_cfg=dict(
type='Pretrained',
checkpoint=
'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth'
)),
neck=[
dict(
type='FPN',
in_channels=[384, 768, 1536],
out_channels=256,
start_level=0,
add_extra_convs='on_output',
num_outs=5),
dict(
type='DyHead',
in_channels=256,
out_channels=256,
num_blocks=6,
zero_init_offset=False)
],
bbox_head=dict(
type='ATSSHead',
num_classes=1,
in_channels=256,
pred_kernel_size=1,
stacked_convs=0,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
octave_base_scale=8,
scales_per_octave=1,
strides=[8, 16, 32, 64, 128],
center_offset=0.5),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0.0, 0.0, 0.0, 0.0],
target_stds=[0.1, 0.1, 0.2, 0.2]),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
train_cfg=dict(
assigner=dict(type='ATSSAssigner', topk=9),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.6),
max_per_img=100))
dataset_type = 'CocoDataset'
classes = ('abc', )
data_root = '/mnt/data/dataset/abc-detection/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data = dict(
samples_per_gpu=1,
workers_per_gpu=1,
train=dict(
type='RepeatDataset',
times=2,
dataset=dict(
type='CocoDataset',
classes=('abc', ),
filter_empty_gt=False,
ann_file=
'/mnt/data/dataset/abc-detection/labels/coco.json',
img_prefix=
'/mnt/data/dataset/abc-detection/images/',
pipeline=[
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='Resize',
img_scale=[(960, 960)],
multiscale_mode='range',
keep_ratio=True,
backend='pillow'),
dict(type='RandomFlip', flip_ratio=0.5),
dict(
type='RandomAffine',
max_rotate_degree=10,
max_translate_ratio=0.1,
max_shear_degree=2,
scaling_ratio_range=(0.9, 1.1)),
dict(
type='Normalize',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True),
dict(type='Pad', size_divisor=128),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
])),
val=dict(
type='CocoDataset',
classes=('abc', ),
filter_empty_gt=False,
ann_file=
'/mnt/data/dataset/abc-detection/labels/coco.json',
img_prefix=
'/mnt/data/dataset/abc-detection/images/',
pipeline=[
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(960, 960),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True, backend='pillow'),
dict(type='RandomFlip'),
dict(
type='Normalize',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True),
dict(type='Pad', size_divisor=128),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img'])
])
]),
test=dict(
type='CocoDataset',
classes=('abc', ),
filter_empty_gt=False,
ann_file=
'/mnt/data/dataset/abc-detection/labels/coco.json',
img_prefix=
'/mnt/data/dataset/abc-detection/images/',
pipeline=[
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(960, 960),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True, backend='pillow'),
dict(type='RandomFlip'),
dict(
type='Normalize',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
to_rgb=True),
dict(type='Pad', size_divisor=128),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img'])
])
]))
evaluation = dict(interval=1, metric='bbox')
optimizer_config = dict(grad_clip=None)
optimizer = dict(
type='AdamW',
lr=5e-05,
betas=(0.9, 0.999),
weight_decay=0.05,
paramwise_cfg=dict(
custom_keys=dict(
absolute_pos_embed=dict(decay_mult=0.0),
relative_position_bias_table=dict(decay_mult=0.0),
norm=dict(decay_mult=0.0))))
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.001,
step=[8, 11])
runner = dict(type='EpochBasedRunner', max_epochs=12)
work_dir = '/mnt/data/ai-exp/abc-detection'
auto_resume = False
gpu_ids = [0]
/workdir/latest.pth is too large to post here (around 2.4 GB). The main differences from the official config are that, instead of the original COCO classes, I train it to detect the single class abc, and the img_scale has changed. I used the official checkpoint as the pre-trained weights and trained for single-class detection.
please kindly help!
I am using the mmdeploy master branch (21775ce584cf9589ebade133e46d008ee60acf9d); I don't know if this information helps. I suspect the score would drop a lot too when using the officially provided checkpoint and config. Does mmdeploy not support DyHead at this stage?
Thanks
The original DyHead config and the pretrained model are on this page: https://github.com/open-mmlab/mmdetection/tree/master/configs/dyhead The style is caffe; I don't know whether it will lead to any error.
Do I need to provide any more information @hanrui1sensetime ? Thanks!!
Did you solve it? I converted another model and have the same problem.
I am still waiting for a reply; haven't solved it yet...
@hanrui1sensetime I have tested again using the official config and checkpoint; they also give different results.
deploy config: https://workdrive.zohoexternal.com/external/0397370f2292821f4995ea6051201401c2538be4fb6aea2f19c796c07be399f7
model config: https://github.com/open-mmlab/mmdetection/blob/master/configs/dyhead/atss_swin-l-p4-w12_fpn_dyhead_mstrain_2x_coco.py
model checkpoint/weight: https://download.openmmlab.com/mmdetection/v2.0/dyhead/atss_swin-l-p4-w12_fpn_dyhead_mstrain_2x_coco/atss_swin-l-p4-w12_fpn_dyhead_mstrain_2x_coco_20220509_100315-bc5b6516.pth
testing image: https://workdrive.zohoexternal.com/external/74105a80a4d3bfb2edebffc3f2cded21b8d348a07ef3b6e103ef00d7b755637f
The results are different: the highest score among detections from the tensorrt engine is 0.22415251, but the highest score among detections from mmdet is 0.71841.
Could you please kindly have a look? Thanks
@lvhan028 @hanrui1sensetime I think this comment (https://github.com/open-mmlab/mmdeploy/issues/972#issuecomment-1250529149) can already reproduce the issue. Please let me know if I need to provide any extra information for the investigation. Thanks!
This model is too big to debug (2 h per conversion), so it takes me a lot of time.
The reason is that DyDCNv2 gets different shapes for x, offset, and mask when using self.spatial_conv_high and self.spatial_conv_low, which leads to wrong results in TensorRT.
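For reference, here is a minimal sketch of that shape situation (sizes match the repro script below; the DyHeadBlock logic is paraphrased from mmdet, so treat the details as illustrative):
import torch

# Two adjacent pyramid levels; the coarser one has half the resolution.
mid_feat = torch.rand(1, 256, 100, 152)   # level i: offset/mask are computed here
high_feat = torch.rand(1, 256, 50, 76)    # level i + 1

# In DyHeadBlock, spatial_conv_offset(mid_feat) yields 3 * 3 * 3 = 27 channels:
# 18 offset channels plus 9 mask channels, all at 100x152.
offset_and_mask = torch.rand(1, 27, 100, 152)
offset, mask = offset_and_mask[:, :18], offset_and_mask[:, 18:].sigmoid()

# spatial_conv_high then runs DyDCNv2 on the coarser feature, so x (50x76)
# and offset/mask (100x152) disagree in spatial size.
print(high_feat.shape, offset.shape, mask.shape)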
Thanks for the investigation! Then should I just wait for the fix, or is there some workaround I could do on the config side?
The official config of mmdeploy still has this problem, so there is no need to change the deploy config for now.
The code below is a minimal reproduction of your bug, showing why onnxruntime works fine while trt drops the score. If you want to run the ort backend, please swap the commented-out deploy config in the example code.
from mmdet.models.necks.dyhead import DyHeadBlock
import torch
import torch.nn.functional as F
from mmdeploy.codebase import import_codebase
from mmdeploy.core.rewriters.rewriter_manager import RewriterContext
from mmdeploy.utils import Backend, Codebase
from mmdeploy.utils.test import (WrapFunction, WrapModel, backend_checker,
check_backend, get_model_outputs,
get_onnx_model, get_rewrite_outputs)
from mmengine import Config
class TestModule(torch.nn.Module):
def __init__(self):
super(TestModule, self).__init__()
self.spatial_conv_high = DyHeadBlock(256, 256).spatial_conv_high
def forward(self, x, offset, mask):
return self.spatial_conv_high(x, offset, mask)
model = TestModule()
model.requires_grad_(False)
# x = torch.load('x1.pt')
x = torch.rand(1, 256, 50, 76)
# offset = torch.load('offset.pt')
offset = torch.rand(1, 256, 100, 152)
# mask = torch.load('mask.pt')
mask = torch.rand(1, 256, 100, 152)
print(f'debugging x: {x}, offset: {offset}, mask: {mask}')
model_outputs = model.forward(x, offset, mask)
print(f'model output: {model_outputs}, shape: {model_outputs.shape}')
wrapped_model = WrapModel(model, 'forward')
rewrite_inputs = {
'x': x,
'offset': offset,
'mask': mask
}
''' tensorrt config, please change it manually.
deploy_cfg = Config(
dict(
onnx_config=dict(output_names=None, input_shape=None),
backend_config=dict(
type='tensorrt',
common_config=dict(
fp16_mode=False, max_workspace_size=1 << 30),
model_inputs=[
dict(
input_shapes=dict(
input=dict(min_shape=[1, 3, 32, 32],
opt_shape=[1, 3, 1024, 1024],
max_shape=[1, 256, 1024, 1024]
)
)
)
]),
codebase_config=dict(
type='mmdet',
task='ObjectDetection')))
'''
deploy_cfg = Config(
dict(
onnx_config=dict(output_names=None, input_shape=None),
backend_config=dict(
type='onnxruntime'),
codebase_config=dict(
type='mmdet',
task='ObjectDetection')))
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
print(f'rewrite_outputs: {rewrite_outputs}, rewrite_outputs[0].shape: {rewrite_outputs[0].shape}')
Hi there, I'm facing the same issue. Is there any solution?
@hanrui1sensetime Thanks for the onnx workaround, but I think I will need to wait for the fix for the tensorrt version.
The solution is to fix the implementation of this.
I've compared the modulated_deform_conv implementations for the onnxruntime and tensorrt backends and found a discrepancy: in the onnxruntime backend implementation (https://github.com/open-mmlab/mmdeploy/blob/master/csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.cpp#:~:text=void%20deformable_conv2d_ref_fp32) the offset and mask steps are calculated using the output tensor shape:
deformable_im2col_2d<float>(
src + b * src_c * src_h * src_w + g * ic_per_gp * src_h * src_w,
offset + b * offset_group * 2 * kernel_h * kernel_w * dst_h * dst_w,
mask + b * offset_group * kernel_h * kernel_w * dst_h * dst_w,
src_h, src_w, kernel_h,
kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, ic_per_gp,
offset_group, dst_h, dst_w, mask != nullptr, columns);
while the implementation for tensorrt (https://github.com/open-mmlab/mmdeploy/blob/master/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv_kernel.cu#:~:text=typename%20scalar_t%3E-,void%20ModulatedDeformConvForwardCUDAKernelLauncher,-() uses the input tensor shape:
const size_t offset_step = deformable_group * kernel_h * kernel_w * 2 * height * width;
const size_t mask_step = deformable_group * kernel_h * kernel_w * height * width;
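To make the contrast concrete, here is a small sketch (my own illustration, not code from either backend) of the per-batch step arithmetic quoted above; both kernels use the same formula and differ only in which spatial size they plug in:
def per_batch_offset_step(groups, kernel_h, kernel_w, h, w):
    # elements per batch in the offset tensor: groups * 2 * kh * kw * H * W
    return groups * 2 * kernel_h * kernel_w * h * w

# In DyHead's spatial_conv_high, x is 50x76 while offset/mask are 100x152,
# so deriving (h, w) from different tensors gives different strides:
print(per_batch_offset_step(1, 3, 3, 50, 76))    # 68400
print(per_batch_offset_step(1, 3, 3, 100, 152))  # 273600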
I've tried to use the output tensor shape in the tensorrt implementation (height_out and width_out) and rebuilt the library and the model engine, but this didn't solve the issue.
@vedrusss I did the same things as you but no luck...
I found that before going inside ModulatedDeformConvForwardCUDAKernelLauncher (trt_modulated_deform_conv_kernel.cu), in mmdeploy/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv.cpp, the outputDesc (outputDesc[0].dims.d[2], outputDesc[0].dims.d[3]) already indicates the output height/width is [100, 152] instead of the correct [50, 76]. Is that normal?
Is there a mathematical description anywhere of what algorithm should be implemented there? @hanrui1sensetime @tak-ho-raspect
Looks like outputDesc[0] is not passed to ModulatedDeformConvForwardCUDAKernelLauncher as-is; only one of its elements is used: int channels_out = outputDesc[0].dims.d[1].
BTW, where do you get the correct output height and width values from? @tak-ho-raspect
@vedrusss I am using the sample code provided, and x is 50x76:
x = torch.rand(1, 256, 50, 76)
The output size from the onnx result is the same as the input size, so I just guess the "correct" height/width should be the same as the input.
@tak-ho-raspect, can you provide your sample code to run the onnx model?
I've found a mention of the onnx implementation of MMCVDeformConv2d here. The author writes that the implementation of MMCVModulatedDeformConv2d differs from MMCVDeformConv2d. Maybe this is the reason for the discrepancy in detections?
@vedrusss I use the code provided by @hanrui1sensetime above.
Are there any updates?
We will fix the TRTModulatedDeformConv2d kernel later.
@tak-ho-raspect, @hanrui1sensetime, I've found that the discrepancy between the rewrite_outputs (in the test code above) disappears if here one replaces the index from 1 (taking the offset's shape as the output shape) with 0 (taking the input's shape as the output shape):
ret.d[2] = inputs[0].d[2];
ret.d[3] = inputs[0].d[3];
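In Python terms, that one-line change amounts to the following (a hypothetical mirror of the plugin's C++ getOutputDimensions, for illustration only):
# Derive the plugin's output H/W from the feature map (inputs[0])
# rather than from the offset tensor (inputs[1]).
def output_hw(input_dims, offset_dims, from_feature=True):
    dims = input_dims if from_feature else offset_dims
    return dims[2], dims[3]

print(output_hw((1, 256, 50, 76), (1, 256, 100, 152)))  # (50, 76)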
But after such a fix, the onnx2tensorrt conversion crashes with an error about wrong dimensions in a graph node:
[12/08/2022-01:44:30] [TRT] [I] MatMul_1915: broadcasting input1 to make tensors conform, dims(input0)=[20,144,1536][NONE] dims(input1)=[1,1536,1536][NONE].
[12/08/2022-01:44:30] [TRT] [E] [graphShapeAnalyzer.cpp::analyzeShapes::1285] Error Code 4: Miscellaneous (IElementWiseLayer Add_2075: broadcast dimensions must be conformable)
Traceback (most recent call last):
File "tools/onnx2tensorrt.py", line 73, in
@vedrusss I remember I changed that (ret.d[2]/ret.d[3] from inputs[1] to inputs[0]) before, but it didn't work. Will the output shape and values be the same as the onnx version after this change?
Yes, after that the output shape and values become the same as onnx.
But the conversion from onnx to tensorrt doesn't work after that, and the model engine converted previously (before the fix) produces the same incorrect detections (low scores and wrong boxes).
So such a "fix" only solves the "different outputs from onnx vs TRT" issue, not the subject of the current issue.
As @hanrui1sensetime wrote, "https://github.com/open-mmlab/mmdeploy/pull/1493 can fix it. Thanks to @grimoire". I tested it and found:
I'll move further with bounding boxes. We are pretty close to the solution!
Great!! Thanks to @hanrui1sensetime @grimoire @vedrusss. Previously I tried other models, and converting to trt always made the bounding boxes off a bit. When you say the "bounding boxes are not same", do you mean they differ a lot? Thanks!
The boxes are completely different, while the scores are pretty close or even the same. I believe the issue is on my side, because I wrote my sample code without the torch pre/post-processing (I need that to implement the detector in C++ later). I mean, I did the image pre-processing (scaling and padding) myself, without the torch engine. As a result I now need to do the detection post-processing (scale and shift). In the original detector implementation all of that is done under the hood by torch.
@vedrusss thanks, I will try on my side too probably next week.
I've done that. I just needed to rescale the detected boxes to the original image shape. Now I can confirm: the DyHead model converted to TRT works fine; boxes and scores are the same as those obtained from the original model. There are some small discrepancies for some detections; I believe they are due to small numerical differences introduced during model conversion, which is normal. Thanks to @hanrui1sensetime @grimoire
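For anyone doing the same, here is a minimal sketch of that post-processing under my assumptions (keep_ratio resize plus bottom/right padding, so a single scale factor and no coordinate shift; the numbers are hypothetical):
import numpy as np

def rescale_dets(dets, scale_factor):
    # dets: (N, 5) array of [x1, y1, x2, y2, score] in network-input
    # coordinates; dividing the box coordinates by the resize scale maps
    # them back to the original image (bottom/right padding adds no shift).
    out = dets.copy()
    out[:, :4] /= scale_factor
    return out

dets = np.array([[8.85, 358.42, 149.80, 495.56, 0.53]])
print(rescale_dets(dets, scale_factor=960 / 1920))  # assume a 1920-wide original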
BTW, currently I've tested only FP32 mode. Gonna do same test for FP16.
Just tried with the master branch code; the bounding boxes seem to be off just a little bit, and the scores are much better now. Tested with 2 cases:
0.6751 (mmdetection inference directly) -> 0.6026 (tensorrt)
0.5015 (mmdetection inference directly) -> 0.4768 (tensorrt)
Thanks!!
@vedrusss Hello! How did you solve the problem "Error Code 4: Miscellaneous (IElementWiseLayer Add_2075: broadcast dimensions must be conformable)"? I have the same issue.
I am using DyHead to train an image detection model: https://github.com/open-mmlab/mmdetection/blob/master/configs/dyhead/atss_swin-l-p4-w12_fpn_dyhead_mstrain_2x_coco.py
Using the GPU docker image, the conversion to tensorrt with tools/deploy.py succeeds:
python3 tools/deploy.py /workdir/detection_onnx_static_1024x1024.py /workdir/atss_swin-l-p4-w12_fpn_dyhead_mstrain_2x.py /workdir/latest.pth /workdir/test-deploy-img-1024.jpg --device cuda --dump-info
Although the conversion produces a lot of warnings like the ones below:
I know the result won't be exactly the same after conversion, so I am okay with some difference in the bounding box values (although they are off quite a bit too), but the score drops way too much! Below is the trt engine result (x1, y1, x2, y2, score):
[8.8506012, 358.41714, 149.80162, 495.56137, 0.081301391]
And below is the original prediction with mmdet; the bbox values have been rounded down: [0, 328, 165, 526, 0.53286]
Here is my env with python3 tools/check_env.py
And here is my deploy config (/workdir/detection_onnx_static_1024x1024.py):
I also tried with the default tensorrt version 8.2.x but had no success.
Could someone please help? Thanks a lot!!