triton-inference-server / server

The Triton Inference Server provides an optimized cloud and edge inferencing solution.
https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/index.html
BSD 3-Clause "New" or "Revised" License

Triton inference is 2 times slower than non triton inference for me #5229

Closed ArgoHA closed 1 year ago

ArgoHA commented 1 year ago

Description
Hi! I get around 2 times lower FPS using Triton server compared to non-Triton inference. Here is what I did:
1) I took yolov5s pretrained weights from https://github.com/ultralytics/yolov5
2) I exported them to .engine (.plan)
3) I ran detect.py from the yolov5 repo and got 10 FPS on a test video (on a Jetson Nano)
4) I created a yolov5_GRPC pipeline, put model.plan in place, and started the Triton server. Then I ran my pipeline and got around 5 FPS on the same test video.

Can anyone give me a hint, do I have a bug?

In both inferences I used FP32. I also inspected jtop, and I am pretty sure that with detect.py the GPU stays at 100% utilization for longer stretches, whereas with Triton I more often see 0% utilization in between the bursts of 100%.

Triton Information tritonserver2.19.0-jetpack4.6.1

Are you using the Triton container or did you build it yourself? Built it myself.

To Reproduce

main.py

import time
import cv2
from pathlib import Path

from src.yolov5_grpc import Yolov5_grpc
from src.utils import fps_counter

class Video_stream:
    def __init__(self, src):
        self.cap = cv2.VideoCapture(src)

    def read(self):
        ret, frame = self.cap.read()
        if ret:
            return frame

class Pipeline:
    def __init__(self, src: str, detector_thres: float = 0.5):
        self.detector_thres = detector_thres
        self.root_path = Path(__file__).parent.absolute()
        self.images_path_save = self.root_path / 'images'

        self.camera = Video_stream(src)
        self.detector = Yolov5_grpc(conf_thresh=detector_thres)
        self.create_images_folder()
        self.idx = 0
        self.running = True

    def create_images_folder(self):
        Path(self.images_path_save).mkdir(parents=True, exist_ok=True)

    def save_output(self, pred_frame):
        output_path = (self.images_path_save / f'image_{self.idx}').with_suffix('.jpeg')
        cv2.imwrite(str(output_path), pred_frame)

    @fps_counter
    def _runner(self):
        frame = self.camera.read()
        if frame is None:
            self.running = False
            return

        boxes, pred_frame, _ = self.detector.get_boxes_debug(frame)
        # if boxes:
        #     self.save_output(pred_frame)

        self.idx += 1

    def run(self):
        while self.running:
            self._runner()

def main():
    src = '/home/argo/test_vid.mp4'
    detector_thres = 0.5

    Pipeline(src, detector_thres).run()

if __name__ == '__main__':
    main()

config.pbtxt

name: "yolov5"
platform: "tensorrt_plan"
max_batch_size: 1
input [
  {
    name: "images"
    data_type: TYPE_FP32
    dims: [ 3, 640, 640 ]
  }
]
output [
  {
    name: "output0"
    data_type: TYPE_FP32
    dims: [ 25200, 85 ]
  }
]

utils.py

import time
import cv2
import torch
import numpy as np

def fps_counter(func):
    def wrapper_function(*args, **kwargs):
        start_timer = time.perf_counter()
        func_res = func(*args,  **kwargs)
        end_timer = time.perf_counter()

        fps = round(1 / (end_timer - start_timer), 1)
        print(f'FPS: {fps}')
        return func_res
    return wrapper_function

def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
    y[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom right x
    y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom right y
    return y

def clip_boxes(boxes, shape):
    # Clip boxes (xyxy) to image shape (height, width)
    if isinstance(boxes, torch.Tensor):  # faster individually
        boxes[..., 0].clamp_(0, shape[1])  # x1
        boxes[..., 1].clamp_(0, shape[0])  # y1
        boxes[..., 2].clamp_(0, shape[1])  # x2
        boxes[..., 3].clamp_(0, shape[0])  # y2
    else:  # np.array (faster grouped)
        boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1])  # x1, x2
        boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0])  # y1, y2

def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
    # Rescale boxes (xyxy) from img1_shape to img0_shape
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain  = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    boxes[..., [0, 2]] -= pad[0]  # x padding
    boxes[..., [1, 3]] -= pad[1]  # y padding
    boxes[..., :4] /= gain
    clip_boxes(boxes, img0_shape)
    return boxes

def bbox_iou(box1, box2, x1y1x2y2=True):
    """
    description: compute the IoU of two bounding boxes
    param:
        box1: A box coordinate (can be (x1, y1, x2, y2) or (x, y, w, h))
        box2: A box coordinate (can be (x1, y1, x2, y2) or (x, y, w, h))
        x1y1x2y2: select the coordinate format
    return:
        iou: computed iou
    """
    if not x1y1x2y2:
        # Transform from center and width to exact coordinates
        b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
        b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
        b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
        b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
    else:
        # Get the coordinates of bounding boxes
        b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
        b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]

    # Get the coordinates of the intersection rectangle
    inter_rect_x1 = np.maximum(b1_x1, b2_x1)
    inter_rect_y1 = np.maximum(b1_y1, b2_y1)
    inter_rect_x2 = np.minimum(b1_x2, b2_x2)
    inter_rect_y2 = np.minimum(b1_y2, b2_y2)
    # Intersection area
    inter_area = np.clip(inter_rect_x2 - inter_rect_x1 + 1, 0, None) * \
                    np.clip(inter_rect_y2 - inter_rect_y1 + 1, 0, None)
    # Union Area
    b1_area = (b1_x2 - b1_x1 + 1) * (b1_y2 - b1_y1 + 1)
    b2_area = (b2_x2 - b2_x1 + 1) * (b2_y2 - b2_y1 + 1)

    iou = inter_area / (b1_area + b2_area - inter_area + 1e-16)

    return iou

def draw_boxes(image, coords, scores):
    box_color = (51, 51, 255)
    font_color = (255, 255, 255)

    line_width = max(round(sum(image.shape) / 2 * 0.0025), 2)
    font_thickness = max(line_width - 1, 1)
    draw_image = image.copy()

    if coords and len(coords):
        for idx, tb in enumerate(coords):
            if tb[0] >= tb[2] or tb[1] >= tb[3]:
                continue
            obj_coords = list(map(int, tb[:4]))

            # bbox
            p1, p2 = (int(obj_coords[0]), int(obj_coords[1])), (int(obj_coords[2]), int(obj_coords[3]))
            cv2.rectangle(draw_image, p1, p2, box_color, thickness=line_width, lineType=cv2.LINE_AA)

            # Conf level label, e.g. "87%"
            label = str(int(round(scores[idx], 2) * 100)) + '%'
            w, h = cv2.getTextSize(label, 0, fontScale=line_width / 3, thickness=font_thickness)[0]  # text width, height
            outside = p1[1] - h >= 3  # label fits above the box
            p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3

            cv2.rectangle(draw_image, p1, p2, box_color, -1, cv2.LINE_AA)  # filled
            cv2.putText(draw_image,
                        label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
                        0,
                        line_width / 3,
                        font_color,
                        thickness=font_thickness,
                        lineType=cv2.LINE_AA)
    return draw_image

def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
    # Resize and pad image while meeting stride-multiple constraints
    shape = im.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better val mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return im, ratio, (dw, dh)

yolov5_grpc.py

# This script is based on different grpc examples for triton server
import tritonclient.grpc as grpcclient
from typing import List
import numpy as np
import sys

from src.utils import draw_boxes, letterbox, xywh2xyxy, bbox_iou, scale_boxes, fps_counter

IOU_THRESHOLD = 0.45

class Yolov5_grpc():
    def __init__(self,
                 url="localhost:8001",
                 model_name="yolov5",
                 input_width=640,
                 input_height=640,
                 model_version="",
                 verbose=False, conf_thresh=0.8) -> None:
        super().__init__()
        self.model_name = model_name

        self.input_width = input_width
        self.input_height = input_height
        self.batch_size = 1
        self.conf_thresh = conf_thresh
        self.input_shape = [self.batch_size, 3, self.input_height, self.input_width]
        self.input_name = 'images'
        self.output_name = 'output0'
        self.output_size = 25200
        self.triton_client = None

        self.fp = 'FP32'

        if '16' in self.fp:
            self.np_dtype = np.float16
        else:
            self.np_dtype = np.float32

        self.init_triton_client(url)
        self.test_predict()

    def init_triton_client(self, url):
        try:
            triton_client = grpcclient.InferenceServerClient(
                url=url,
                verbose=False,
                ssl=False,
            )
        except Exception as e:
            print("channel creation failed: " + str(e))
            sys.exit()
        self.triton_client = triton_client

    def test_predict(self):
        input_images = np.zeros((*self.input_shape,), dtype=self.np_dtype)
        _ = self.predict(input_images)

    def predict(self, input_images):
        inputs = []
        outputs = []

        inputs.append(grpcclient.InferInput(self.input_name, [*input_images.shape], self.fp))
        # Initialize the data
        inputs[-1].set_data_from_numpy(input_images)
        outputs.append(grpcclient.InferRequestedOutput(self.output_name))

        # Test with outputs
        results = self.triton_client.infer(
            model_name=self.model_name,
            inputs=inputs,
            outputs=outputs)

        # Get the output arrays from the results
        return results.as_numpy(self.output_name)

    def non_max_suppression(self, prediction, origin_h, origin_w, conf_thres=0.5, nms_thres=0.4):
        """
        description: Removes detections with lower object confidence score than 'conf_thres' and performs
        Non-Maximum Suppression to further filter detections.
        param:
            prediction: detections, (x1, y1, x2, y2, conf, cls_id)
            origin_h: original image height
            origin_w: original image width
            conf_thres: a confidence threshold to filter detections
            nms_thres: an IoU threshold to filter detections
        return:
            boxes: output after nms with the shape (x1, y1, x2, y2, conf, cls_id)
        """
        # Get the boxes that score > conf_thresh
        boxes = prediction[prediction[:, 4] >= conf_thres]

        # Transform bbox from [center_x, center_y, w, h] to [x1, y1, x2, y2]
        boxes[:, :4] = xywh2xyxy(boxes[:, :4])
        # clip the coordinates
        boxes[:, 0] = np.clip(boxes[:, 0], 0, origin_w - 1)
        boxes[:, 2] = np.clip(boxes[:, 2], 0, origin_w - 1)
        boxes[:, 1] = np.clip(boxes[:, 1], 0, origin_h - 1)
        boxes[:, 3] = np.clip(boxes[:, 3], 0, origin_h - 1)
        # Object confidence
        confs = boxes[:, 4]

        # Sort by the confs
        boxes = boxes[np.argsort(-confs)]
        # Perform non-maximum suppression
        keep_boxes = []
        while boxes.shape[0]:
            large_overlap = bbox_iou(np.expand_dims(boxes[0, :4], 0), boxes[:, :4]) > nms_thres
            label_match = np.round(boxes[0, -1]) == np.round(boxes[:, -1])
            # Indices of boxes with lower confidence scores, large IOUs and matching labels
            invalid = large_overlap & label_match
            keep_boxes += [boxes[0]]
            boxes = boxes[~invalid]
        boxes = np.stack(keep_boxes, 0) if len(keep_boxes) else np.array([])
        return boxes

    def post_process(self, output, origin_h, origin_w):
        """
        description: postprocess the prediction
        param:
            output:     A numpy array like [num_boxes, cx, cy, w, h, conf, cls_id, cx, cy, w, h, conf, cls_id, ...]
            origin_h:   height of original image
            origin_w:   width of original image
        return:
            result_boxes: final boxes, a numpy array where each row is a box [x1, y1, x2, y2]
            result_scores: final scores, a numpy array where each element is the score corresponding to a box
            result_classid: final class ids, a numpy array where each element is the class id corresponding to a box
        """
        # Do nms
        boxes = self.non_max_suppression(output, origin_h, origin_w, conf_thres=self.conf_thresh, nms_thres=IOU_THRESHOLD)
        result_boxes = boxes[:, :4] if len(boxes) else np.array([])
        result_scores = boxes[:, 4] if len(boxes) else np.array([])
        result_classid = boxes[:, 5] if len(boxes) else np.array([])

        # rescale boxes to original image size from processing size (640x640 -> 1920x1080)
        result_boxes = scale_boxes((self.input_height, self.input_width), result_boxes, (origin_h, origin_w))
        return result_boxes, result_scores, result_classid

    def preprocess(self, img, stride):
        img = letterbox(img, max(self.input_width, self.input_height), stride=stride, auto=False)[0]
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
        img = np.ascontiguousarray(img)
        img = img.astype(self.np_dtype)
        img = img / 255.0  # 0 - 255 to 0.0 - 1.0
        img = img.reshape([1, *img.shape])
        return img

    def postprocess(self, host_outputs, batch_origin_h, batch_origin_w, min_accuracy=0.5):
        output = host_outputs[0]
        # Do postprocess
        answer = []
        valid_scores = []
        for i in range(self.batch_size):
            result_boxes, result_scores, result_classid = self.post_process(
                output[i * self.output_size: (i + 1) * self.output_size], batch_origin_h, batch_origin_w
            )
            for box, score in zip(result_boxes, result_scores):
                if score > min_accuracy:
                    answer.append(box)
                    valid_scores.append(score)
        return answer, valid_scores

    def grpc_detect(self, image: np.ndarray, stride: int = 32, min_accuracy: float = 0.5) -> List:
        processed_image = self.preprocess(image, stride)
        pred = self.predict(processed_image)
        boxes, scores = self.postprocess(pred, image.shape[0], image.shape[1])
        return boxes, scores

    def get_boxes_debug(self, image):
        boxes, scores = self.grpc_detect(image)
        debug_image = draw_boxes(image, boxes, scores)
        return boxes, debug_image, scores
tanmayv25 commented 1 year ago

@ArgoHA Looks like you are creating the grpcclient.InferInput with every predict call. This means a new protobuf object is created with every inference run. Can you create a single grpcclient.InferInput object once and then call set_data_from_numpy on it for each inference run?
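A minimal sketch of that refactor against the Yolov5_grpc class posted above (attribute names are taken from that class; this is an assumption about how it could be wired up, not the author's final code):

# in Yolov5_grpc.__init__, create the protobuf objects once
self.infer_input = grpcclient.InferInput(self.input_name, self.input_shape, self.fp)
self.infer_output = grpcclient.InferRequestedOutput(self.output_name)

# predict() then only updates the tensor contents on each call
def predict(self, input_images):
    self.infer_input.set_data_from_numpy(input_images)
    results = self.triton_client.infer(
        model_name=self.model_name,
        inputs=[self.infer_input],
        outputs=[self.infer_output])
    return results.as_numpy(self.output_name)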

That being said, the requests still go through the gRPC endpoint, which includes message marshalling/unmarshalling and some communication cost. So the single-stream performance will not match that of the standalone application. You can feed multiple streams, and Triton will effectively scale the inferences across the available model instances for better throughput.
You can also query inference statistics from the server using this API: https://github.com/triton-inference-server/client/blob/main/src/python/library/tritonclient/grpc/__init__.py#L712

More information about inference statistics here: https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_statistics.md
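For reference, a minimal way to pull those statistics from the client object already created in yolov5_grpc.py could look like this (model name taken from the config.pbtxt above; the exact response fields are described in the statistics extension doc):

stats = self.triton_client.get_inference_statistics(model_name="yolov5", as_json=True)
print(stats)  # cumulative counts and durations for queue, compute_input, compute_infer, compute_output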

ArgoHA commented 1 year ago

@tanmayv25 Thanks for the answer. I used time.perf_counter() to measure how long inputs.append(grpcclient.InferInput(self.input_name, [*input_images.shape], self.fp)) takes and got 0.00035297730937600136 seconds, so it doesn't really matter.

But I also tried creating the object in __init__ and then using set_data_from_numpy, and my FPS didn't change.

jbkyang-nvi commented 1 year ago

Hi @ArgoHA, can you share your inference statistics here? If you have a reproducer for your current setup, it would help us reproduce the problem if it is a bug.

ArgoHA commented 1 year ago

@jbkyang-nvi Do you mean something specific by a reproducer? I can try to create an ISO backup of the whole system, or I can give you the entire code of the project, assuming that you have the same Triton server version installed.

innerNULL commented 1 year ago

Any updates? I got the same issue T T, also 2 times slower

tanmayv25 commented 1 year ago

@ArgoHA I think @jbkyang-nvi is asking for the output from the get_inference_statistics API call. As explained above, Triton clients spend some time sending tensor bytes across, and the inference statistics will help us understand which parts of the request the time is being spent in. Based on that we can suggest ways to avoid the extra data copies that the Triton pipeline currently incurs.

One suggestion that might help here is to use shared memory to send data from the client process to the server. More on this here: https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_shared_memory.md You should see a performance improvement using shared memory. We have an example using system shared memory with the gRPC client here: https://github.com/triton-inference-server/client/blob/main/src/python/examples/simple_grpc_shm_client.py
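Very loosely adapted from that example, registering a system shared-memory region for the yolov5 input could look something like the sketch below (the region name, fixed input size, and helper function are assumptions for illustration, not code from this issue):

# Sketch: pass the input tensor via system shared memory instead of over gRPC.
import numpy as np
import tritonclient.grpc as grpcclient
import tritonclient.utils.shared_memory as shm

client = grpcclient.InferenceServerClient("localhost:8001")
input_byte_size = 1 * 3 * 640 * 640 * np.dtype(np.float32).itemsize

# create and register the region once, then reuse it for every frame
client.unregister_system_shared_memory()  # clean up leftovers from previous runs
shm_handle = shm.create_shared_memory_region("input_data", "/yolov5_input", input_byte_size)
client.register_system_shared_memory("input_data", "/yolov5_input", input_byte_size)

def infer_shm(preprocessed):  # preprocessed: float32 array of shape [1, 3, 640, 640]
    shm.set_shared_memory_region(shm_handle, [preprocessed])
    infer_input = grpcclient.InferInput("images", [1, 3, 640, 640], "FP32")
    infer_input.set_shared_memory("input_data", input_byte_size)
    results = client.infer("yolov5", inputs=[infer_input],
                           outputs=[grpcclient.InferRequestedOutput("output0")])
    return results.as_numpy("output0")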

Tabrizian commented 1 year ago

Closing due to inactivity.