Tau-J / rtmlib

RTMPose series (RTMPose, DWPose, RTMO, RTMW) without mmcv, mmpose, mmdet etc.
Apache License 2.0
206 stars 24 forks source link

tensorrt的支持 #12

Open Egrt opened 8 months ago

Egrt commented 8 months ago

请问能否增加对tensorrt的支持?

Egrt commented 8 months ago

[02/07/2024-21:51:33] [TRT] [E] ModelImporter.cpp:729: --- End node ---
[02/07/2024-21:51:33] [TRT] [E] ModelImporter.cpp:732: ERROR: ModelImporter.cpp:168 In function parseGraph:
[6] Invalid Node - TopK_587
This version of TensorRT only supports input K as an initializer. Try applying constant folding on the model using Polygraphy: https://github.com/NVIDIA/TensorRT/tree/master/tools/Polygraphy/examples/cli/surgeon/02_folding_constants
Traceback (most recent call last):
  File "d:/Notebook/rtmlib/rtm.py", line 41, in <module>
    build_model(rtmdet_onnx_model)
  File "d:/Notebook/rtmlib/rtm.py", line 16, in build_model
    build_engine(onnx_file_path, engine_file_path, True)
  File "d:/Notebook/rtmlib/rtm.py", line 28, in build_engine
    raise RuntimeError(f'failed to load ONNX file: {onnx_file_path}')
RuntimeError: failed to load ONNX file: rtmdet_nano_8xb32-300e_hand-267f9c8f.onnx

我的环境为:TensorRT 8.5.1.7。当我将 ONNX 模型转换为 engine 模型时出现了上述报错,该如何解决?

Tau-J commented 8 months ago

Hi @Egrt, RTMPose 的TensorRT转换流程请参考官方文档。关于TensorRT推理的支持,在计划之中,但由于版本对齐相对困难,暂时不会在短期内完成

chenscottus commented 8 months ago

Please update to TensorRT 8.6.x

Egrt commented 7 months ago

> Please update to TensorRT 8.6.x

Thanks, it worked.

Egrt commented 7 months ago

实现了 RTMDet 的 TensorRT 加速,仅供参考。注意:转换 ONNX 模型时必须导出为静态(static shape)模型。@Tau-J

from typing import List, Tuple
import os
import numpy as np
import tensorrt as trt
import cv2
import time

def build_model(onnx_file_path):
    """Ensure a TensorRT engine exists for *onnx_file_path*, building it if needed.

    The engine path is *onnx_file_path* with ``.onnx`` replaced by
    ``.engine`` -- the same rule ``BaseTool`` uses when it loads the engine --
    so an existing engine is reused and the slow build only happens once.
    The build requests FP16 precision.

    Args:
        onnx_file_path: Path to the ``.onnx`` model.

    Returns:
        str: Path of the engine file (pre-existing or newly built).
    """
    engine_file_path = onnx_file_path.replace('.onnx', '.engine')

    if not os.path.exists(engine_file_path):
        # "Building the model; the first run takes a while; a notice is
        # printed when it is done."
        print('模型制作中,第一次等待时间较长, 完成后会有文字提示')
        build_engine(onnx_file_path, engine_file_path, True)
        # The completion notice promised by the message above.
        print(f'模型制作完成: {engine_file_path}')
    return engine_file_path

def build_engine(onnx_file_path, engine_file_path, half=True):
    """Takes an ONNX file and creates a TensorRT engine to run inference with.

    The ONNX model is parsed into an explicit-batch network. The engine is
    built with a 4 GiB workspace limit, with FP16 enabled when requested and
    the GPU reports fast FP16. The serialized engine is written to
    ``engine_file_path``.

    Args:
        onnx_file_path: Path to the ``.onnx`` model to convert.
        engine_file_path: Where to write the serialized engine.
        half: Request FP16 kernels; ignored on platforms without fast FP16.

    Returns:
        str: ``engine_file_path``.

    Raises:
        RuntimeError: If the ONNX file cannot be parsed (the parser's own
            errors are reported through the TensorRT logger) or if the
            engine build fails.
    """
    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    config = builder.create_builder_config()
    workspace_bytes = 4 << 30  # 4 GiB
    if hasattr(config, 'set_memory_pool_limit'):
        # TensorRT >= 8.4: ``max_workspace_size`` is deprecated there and
        # removed in TensorRT 10.
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE,
                                     workspace_bytes)
    else:
        config.max_workspace_size = workspace_bytes
    flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(flag)
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(str(onnx_file_path)):
        raise RuntimeError(f'failed to load ONNX file: {onnx_file_path}')
    half = half and builder.platform_has_fast_fp16
    if half:
        config.set_flag(trt.BuilderFlag.FP16)
    if hasattr(builder, 'build_serialized_network'):
        serialized_engine = builder.build_serialized_network(network, config)
    else:  # TensorRT < 8.0
        engine = builder.build_engine(network, config)
        serialized_engine = None if engine is None else engine.serialize()
    if serialized_engine is None:
        # Both build APIs signal failure by returning None.
        raise RuntimeError(
            f'failed to build TensorRT engine from ONNX file: {onnx_file_path}')
    with open(engine_file_path, 'wb') as f:
        f.write(serialized_engine)
    return engine_file_path

def draw_bbox(img, bboxes, color=(0, 255, 0)):
    """Draw every box in *bboxes* on *img* as a rectangle of thickness 2.

    Each box is read as ``(x1, y1, x2, y2, ...)``. Only the first four values
    are used, and they are truncated to ``int``. Returns the annotated image.
    """
    for box in bboxes:
        top_left = (int(box[0]), int(box[1]))
        bottom_right = (int(box[2]), int(box[3]))
        img = cv2.rectangle(img, top_left, bottom_right, color, 2)
    return img

def nms(boxes, scores, nms_thr):
    """Greedy single-class non-maximum suppression in NumPy.

    Boxes are visited from highest to lowest score. Each visited box is kept,
    and every remaining box whose IoU with it exceeds *nms_thr* is discarded.
    Widths and heights use the inclusive ``x2 - x1 + 1`` pixel convention.

    Args:
        boxes: Array whose columns 0-3 are ``x1, y1, x2, y2``.
        scores: 1-D array of scores aligned with the rows of *boxes*.
        nms_thr: IoU above which a lower-scored box is suppressed.

    Returns:
        list: Row indices of the kept boxes, in descending score order.
    """
    x1, y1, x2, y2 = (boxes[:, col] for col in range(4))
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    remaining = scores.argsort()[::-1]

    kept = []
    while remaining.size:
        best, rest = remaining[0], remaining[1:]
        kept.append(best)

        inter_w = np.maximum(
            0.0, np.minimum(x2[best], x2[rest]) - np.maximum(x1[best], x1[rest]) + 1)
        inter_h = np.maximum(
            0.0, np.minimum(y2[best], y2[rest]) - np.maximum(y1[best], y1[rest]) + 1)
        inter = inter_w * inter_h
        iou = inter / (areas[best] + areas[rest] - inter)

        # Keep only candidates that do not overlap the chosen box too much.
        remaining = rest[iou <= nms_thr]

    return kept

def multiclass_nms(boxes, scores, nms_thr, score_thr):
    """Class-aware NMS implemented in NumPy.

    For every column (class) of *scores*, the boxes scoring strictly above
    *score_thr* for that class are passed through ``nms``. The survivors of
    all classes are then stacked together.

    Args:
        boxes: Box array whose rows align with *scores*; ``nms`` reads
            columns 0-3 as ``x1, y1, x2, y2``.
        scores: 2-D array with one column per class.
        nms_thr: IoU threshold forwarded to ``nms``.
        score_thr: Minimum (exclusive) per-class score.

    Returns:
        np.ndarray or None: Rows ``[*box, score, class_index]`` grouped by
        ascending class index, or ``None`` when no box survives.
    """
    kept_per_class = []
    for cls_ind in range(scores.shape[1]):
        cls_scores = scores[:, cls_ind]
        above = cls_scores > score_thr
        if not above.any():
            continue

        cand_boxes = boxes[above]
        cand_scores = cls_scores[above]
        keep = nms(cand_boxes, cand_scores, nms_thr)
        if not keep:
            continue

        labels = np.full((len(keep), 1), cls_ind, dtype=np.float64)
        kept_per_class.append(
            np.concatenate([cand_boxes[keep], cand_scores[keep, None], labels], 1))

    return np.concatenate(kept_per_class, 0) if kept_per_class else None

# Backend name -> device name -> backend-specific handle used by BaseTool:
# a (DNN backend, DNN target) pair for OpenCV, an execution-provider name for
# ONNX Runtime. The TensorRT backend does not consult this table.
RTMLIB_SETTINGS = dict(
    opencv=dict(
        cpu=(cv2.dnn.DNN_BACKEND_OPENCV, cv2.dnn.DNN_TARGET_CPU),
        # Requires an OpenCV build compiled with CUDA support (built manually
        # through CMake).
        cuda=(cv2.dnn.DNN_BACKEND_CUDA, cv2.dnn.DNN_TARGET_CUDA),
    ),
    onnxruntime=dict(
        cpu='CPUExecutionProvider',
        cuda='CUDAExecutionProvider',
    ),
)

class BaseTool():
    """Load an ONNX model (or its prebuilt TensorRT engine) and run inference.

    Subclasses such as ``RTMDet`` add model-specific pre/post-processing.
    This class stores the configuration and implements ``inference`` for the
    ``'opencv'``, ``'onnxruntime'`` and ``'tensorrt'`` backends.

    The TensorRT path loads ``onnx_model`` with ``.onnx`` replaced by
    ``.engine``, so the engine must be built beforehand (e.g. with
    ``build_model``). It uses the TensorRT 8.x binding API
    (``get_binding_shape`` / ``execute_async_v2``) together with pycuda.
    """

    def __init__(self,
                 onnx_model: str = None,
                 model_input_size: tuple = None,
                 mean: tuple = None,
                 std: tuple = None,
                 nms_thr: float = 0.5,
                 score_thr: float = 0.3,
                 backend: str = 'tensorrt',
                 device: str = 'cuda'):
        """Create the backend session and store the configuration.

        Args:
            onnx_model: Path to the ``.onnx`` file.
            model_input_size: Model input size, stored for subclasses.
            mean: Normalization mean, stored for subclasses.
            std: Normalization std, stored for subclasses.
            nms_thr: NMS IoU threshold, stored for subclasses.
            score_thr: Score threshold, stored for subclasses.
            backend: ``'opencv'``, ``'onnxruntime'`` or ``'tensorrt'``.
            device: ``'cpu'`` or ``'cuda'``. It selects the OpenCV target or
                the ONNX Runtime provider and is unused by TensorRT.

        Raises:
            RuntimeError: OpenCV cannot load the model, or the TensorRT
                engine cannot be deserialized.
            OSError: The TensorRT engine file cannot be opened.
            NotImplementedError: Unknown ``backend``.
        """
        if backend == 'opencv':
            try:
                providers = RTMLIB_SETTINGS[backend][device]

                session = cv2.dnn.readNetFromONNX(onnx_model)
                session.setPreferableBackend(providers[0])
                session.setPreferableTarget(providers[1])
                self.session = session
            except Exception as err:
                # Chain the original error so the real cause stays visible.
                raise RuntimeError(
                    'This model is not supported by OpenCV'
                    ' backend, please use `pip install'
                    ' onnxruntime` or `pip install'
                    ' onnxruntime-gpu` to install onnxruntime'
                    ' backend. Then specify `backend=onnxruntime`.'
                ) from err  # noqa

        elif backend == 'onnxruntime':
            import onnxruntime as ort
            providers = RTMLIB_SETTINGS[backend][device]

            self.session = ort.InferenceSession(path_or_bytes=onnx_model,
                                                providers=[providers])
        elif backend == 'tensorrt':
            self._load_tensorrt(onnx_model)
        else:
            raise NotImplementedError(f'unsupported backend: {backend}')

        print(f'load {onnx_model} with {backend} backend')

        self.onnx_model = onnx_model
        self.model_input_size = model_input_size
        self.mean = mean
        self.std = std
        self.nms_thr = nms_thr
        self.score_thr = score_thr
        self.backend = backend
        self.device = device

    def _load_tensorrt(self, onnx_model):
        """Deserialize the ``.engine`` next to *onnx_model* and allocate I/O buffers."""
        import tensorrt as trt
        import pycuda.driver as cuda
        import pycuda.autoinit  # noqa: F401  (import creates the CUDA context)

        engine_path = onnx_model.replace('.onnx', '.engine')
        logger = trt.Logger(trt.Logger.WARNING)
        logger.min_severity = trt.Logger.Severity.ERROR
        runtime = trt.Runtime(logger)
        trt.init_libnvinfer_plugins(logger, '')  # initialize TensorRT plugins
        with open(engine_path, "rb") as f:
            serialized_engine = f.read()
        engine = runtime.deserialize_cuda_engine(serialized_engine)
        if engine is None:
            raise RuntimeError(
                f'failed to deserialize TensorRT engine: {engine_path}')
        # Hold references so the runtime and engine outlive the context.
        self.runtime = runtime
        self.engine = engine
        # get the real input shape of the model, in case user input it wrong
        self.imgsz = engine.get_binding_shape(0)[2:]
        self.context = engine.create_execution_context()
        self.inputs, self.outputs, self.bindings = [], [], []
        self.stream = cuda.Stream()
        for binding in engine:
            dims = engine.get_binding_shape(binding)
            size = trt.volume(dims)
            if dims[1] < 0:
                # NOTE(review): presumably a dynamic (-1) dim; trt.volume is
                # then negative, so flip its sign.
                size *= -1
            dtype = trt.nptype(engine.get_binding_dtype(binding))
            host_mem = cuda.pagelocked_empty(size, dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            self.bindings.append(int(device_mem))
            buf = {'host': host_mem, 'device': device_mem,
                   'shape': tuple(dims)}
            if engine.binding_is_input(binding):
                self.inputs.append(buf)
            else:
                self.outputs.append(buf)

    def inference(self, img: np.ndarray):
        """Run the model on one preprocessed image.

        Args:
            img (np.ndarray): (H, W, C) image, already resized and
                normalized by the subclass.

        Returns:
            list: One np.ndarray per model output, each including the batch
            dimension (``outputs[0]`` is the first output for every backend).
        """
        # build input to (1, 3, H, W)
        img = img.transpose(2, 0, 1)
        img = np.ascontiguousarray(img, dtype=np.float32)
        blob = img[None, :, :, :]

        # run model
        if self.backend == 'opencv':
            out_names = self.session.getUnconnectedOutLayersNames()
            self.session.setInput(blob)
            outputs = self.session.forward(out_names)
        elif self.backend == 'onnxruntime':
            sess_input = {self.session.get_inputs()[0].name: blob}
            sess_output = [out.name for out in self.session.get_outputs()]
            outputs = self.session.run(sess_output, sess_input)
        elif self.backend == 'tensorrt':
            outputs = self._infer_tensorrt(blob)
        else:
            raise NotImplementedError(f'unsupported backend: {self.backend}')
        return outputs

    def _infer_tensorrt(self, blob):
        """Copy *blob* to the GPU, execute the engine and fetch all outputs."""
        import pycuda.driver as cuda
        # Fill the existing page-locked buffer in place (pycuda's async
        # copies require page-locked host memory). np.copyto also rejects
        # a blob whose size does not match the engine's input binding.
        np.copyto(self.inputs[0]['host'], blob.ravel())
        # transfer data to the gpu
        for inp in self.inputs:
            cuda.memcpy_htod_async(inp['device'], inp['host'], self.stream)
        # run inference
        self.context.execute_async_v2(
            bindings=self.bindings,
            stream_handle=self.stream.handle)
        # fetch outputs from gpu
        for out in self.outputs:
            cuda.memcpy_dtoh_async(out['host'], out['device'], self.stream)
        # synchronize stream
        self.stream.synchronize()
        return [self._host_output(out) for out in self.outputs]

    @staticmethod
    def _host_output(out):
        """Return a copy of an output host buffer shaped like its binding.

        The copy keeps results valid after the reused buffer is overwritten
        by the next call. Negative (dynamic) dims become -1. If more than
        one dim is dynamic, the flat buffer is returned unchanged.
        """
        data = out['host'].copy()
        shape = tuple(d if d >= 0 else -1 for d in out['shape'])
        if shape.count(-1) <= 1:
            data = data.reshape(shape)
        return data

class RTMDet(BaseTool):
    """RTMDet detector built on ``BaseTool``.

    ``__call__`` letterboxes the image, runs the model and returns the
    detected boxes mapped back to the input image's pixel coordinates.
    """

    def __init__(self,
                 onnx_model: str,
                 model_input_size: tuple = (640, 640),
                 mean: tuple = (103.5300, 116.2800, 123.6750),
                 std: tuple = (57.3750, 57.1200, 58.3950),
                 nms_thr: float = 0.5,
                 score_thr: float = 0.3,
                 backend: str = 'tensorrt',
                 device: str = 'cpu'):
        super().__init__(onnx_model,
                         model_input_size,
                         mean,
                         std,
                         nms_thr=nms_thr,
                         score_thr=score_thr,
                         backend=backend,
                         device=device)

    def __call__(self, image: np.ndarray):
        """Detect objects in *image*.

        Args:
            image (np.ndarray): Image to detect on. This file's demo feeds
                frames read with cv2 (BGR channel order).

        Returns:
            np.ndarray: (N, 4) ``x1, y1, x2, y2`` boxes in *image* pixels;
            (0, 4) when nothing is detected.
        """
        image, ratio = self.preprocess(image)
        outputs = self.inference(image)[0]
        results = self.postprocess(outputs, ratio)
        return results

    def preprocess(self, img: np.ndarray):
        """Letterbox *img* into ``model_input_size`` and normalize it.

        The image is resized with its aspect ratio preserved and pasted into
        the top-left corner of a canvas filled with 114. If ``mean`` is set,
        the result is normalized as ``(x - mean) / std``.

        Args:
            img (np.ndarray): Input image, (H, W, 3) or (H, W).

        Returns:
            tuple:
            - padded_img (np.ndarray): Preprocessed image.
            - ratio (float): Resize factor applied to *img*. Divide
              model-space coordinates by it to map them back.
        """
        in_h, in_w = self.model_input_size[0], self.model_input_size[1]
        if len(img.shape) == 3:
            padded_img = np.full((in_h, in_w, 3), 114, dtype=np.uint8)
        else:
            padded_img = np.full(self.model_input_size, 114, dtype=np.uint8)

        ratio = min(in_h / img.shape[0], in_w / img.shape[1])
        new_h, new_w = int(img.shape[0] * ratio), int(img.shape[1] * ratio)
        resized_img = cv2.resize(
            img,
            (new_w, new_h),
            interpolation=cv2.INTER_LINEAR,
        ).astype(np.uint8)
        padded_img[:new_h, :new_w] = resized_img

        # normalize image (local arrays; the configured tuples stay as given)
        if self.mean is not None:
            mean = np.array(self.mean)
            std = np.array(self.std)
            padded_img = (padded_img - mean) / std

        return padded_img, ratio

    def postprocess(
        self,
        outputs: np.ndarray,
        ratio: float = 1.,
    ) -> np.ndarray:
        """Decode raw RTMDet outputs into boxes in original-image pixels.

        The layout is chosen from the size of the last axis:

        * 4 -- per-anchor predictions decoded on stride 8/16/32 grids, then
          class-aware NMS; only class 0 is kept.
        * 5 -- ``x1, y1, x2, y2, score`` from an export that already
          contains NMS; rows are filtered by ``score_thr``.
        * 6 -- ``x1, y1, x2, y2`` plus two per-class score columns (static
          export); class-aware NMS is applied here.

        Args:
            outputs (np.ndarray): First model output. It is indexed as
                ``outputs[0]`` for the single batch item, so presumably
                shaped ``(1, N, C)``.
            ratio (float): Resize factor returned by ``preprocess``.

        Returns:
            np.ndarray: (N, 4) ``x1, y1, x2, y2`` boxes; (0, 4) when no box
            survives the thresholds.

        Raises:
            ValueError: If the last axis size is not 4, 5 or 6.
        """
        num_cols = outputs.shape[-1]
        # Returned whenever no detection survives thresholds/NMS.
        final_boxes = np.zeros((0, 4), dtype=np.float32)

        if num_cols == 4:
            # onnx without nms module
            # NOTE(review): with only 4 columns, predictions[:, 4:5] and
            # predictions[:, 5:] are empty, so `scores` has no class columns
            # and this branch never yields a box. The decoding looks borrowed
            # from a (box, objectness, classes) layout -- confirm the format.
            grids = []
            expanded_strides = []
            strides = [8, 16, 32]

            hsizes = [self.model_input_size[0] // stride for stride in strides]
            wsizes = [self.model_input_size[1] // stride for stride in strides]

            for hsize, wsize, stride in zip(hsizes, wsizes, strides):
                xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
                grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
                grids.append(grid)
                shape = grid.shape[:2]
                expanded_strides.append(np.full((*shape, 1), stride))

            grids = np.concatenate(grids, 1)
            expanded_strides = np.concatenate(expanded_strides, 1)
            outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
            outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides

            predictions = outputs[0]
            boxes = predictions[:, :4]
            scores = predictions[:, 4:5] * predictions[:, 5:]

            # center/size -> corner coordinates, then back to image pixels
            boxes_xyxy = np.ones_like(boxes)
            boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.
            boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.
            boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.
            boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.
            boxes_xyxy /= ratio
            dets = multiclass_nms(boxes_xyxy,
                                  scores,
                                  nms_thr=self.nms_thr,
                                  score_thr=self.score_thr)
            if dets is not None:
                final_scores = dets[:, 4]
                final_cls_inds = dets[:, 5]
                keep = (final_scores > self.score_thr) & (final_cls_inds == 0)
                final_boxes = dets[keep, :4]

        elif num_cols == 5:
            # onnx contains nms module
            boxes = outputs[0, :, :4] / ratio
            keep = outputs[0, :, 4] > self.score_thr
            final_boxes = boxes[keep]

        elif num_cols == 6:
            # onnx static
            dets = multiclass_nms(outputs[0, :, :4], outputs[0, :, 4:6],
                                  nms_thr=self.nms_thr,
                                  score_thr=self.score_thr)
            # multiclass_nms returns None when nothing passes score_thr.
            if dets is not None:
                final_boxes = dets[:, :4] / ratio

        else:
            raise ValueError(
                f'unsupported RTMDet output shape: {outputs.shape}')

        return final_boxes

if __name__ == '__main__':
    mode = 'image'  # 'image' or 'video'
    image_path = 'test.jpg'
    rtmdet_onnx_model = 'rtmdet_tiny_uav.onnx'
    rtmpose_onnx_model = ''  # placeholder; no pose stage is run yet
    build_model(rtmdet_onnx_model)

    rtmdet = RTMDet(onnx_model=rtmdet_onnx_model,
                    model_input_size=(640, 640),
                    backend='tensorrt',
                    device='cuda')

    if mode == 'video':
        cap = cv2.VideoCapture(0)
        try:
            while cap.isOpened():
                success, frame = cap.read()
                if not success:
                    break

                s = time.time()
                bboxes = rtmdet(frame)
                # TODO: run an RTMPose stage on `bboxes` here once
                # rtmpose_onnx_model is available.
                det_time = time.time() - s
                print('det: ', det_time)

                img_show = draw_bbox(frame.copy(), bboxes, (0, 255, 0))
                img_show = cv2.resize(img_show, (960, 640))
                cv2.imshow('img', img_show)
                if cv2.waitKey(25) == ord('q'):
                    break
        finally:
            # Release the camera however the loop ends (end of stream or 'q').
            cap.release()
            cv2.destroyAllWindows()

    elif mode == 'image':
        from PIL import Image
        frame = cv2.imread(image_path)
        if frame is None:
            # cv2.imread returns None instead of raising on unreadable files.
            raise FileNotFoundError(f'cannot read image: {image_path}')
        s = time.time()
        bboxes = rtmdet(frame)
        det_time = time.time() - s  # detection only, as in video mode
        print('det: ', det_time)
        img_show = draw_bbox(frame.copy(), bboxes, (0, 255, 0))
        # cv2 images are BGR while PIL expects RGB.
        image = Image.fromarray(cv2.cvtColor(img_show, cv2.COLOR_BGR2RGB))
        image.show()
vieenrose commented 6 months ago

Alternatively, you can enable the TensorrtExecutionProvider in ONNX Runtime to use TensorRT as the inference backend. Note that you may first need to run shape inference on the ONNX model with symbolic_shape_infer.py. For TensorRT 8.2–8.4, you also need to build the custom TensorRT ops plugin from MMDeploy and load it into the TensorRT Execution Provider, following the provider's usage documentation.