NVIDIA / TensorRT

NVIDIA® TensorRT™ is an SDK for high-performance deep learning inference on NVIDIA GPUs. This repository contains the open source components of TensorRT.
https://developer.nvidia.com/tensorrt
Apache License 2.0

if_conditional() is time-consuming. #4059

Open Parsifal133 opened 3 months ago

Parsifal133 commented 3 months ago

Hello everyone!

I am using TensorRT 8.2 and the Python API to build a YOLOv5 model with multiple branches.

Specifically, each convolutional layer has multiple branches (but only one branch is executed during each inference), so I am using nested network.add_if_conditional().

Fortunately, I achieved the functionality I wanted, but the exported engine file is quite large (though that is not the main issue). More importantly, the actual inference time increases as the number of branches increases.

This is the code for using nested if_conditional() for the YOLOv5 output heads. Since the number of branches is often more than two, nested if_conditional() is needed.

def get_yolo_head(bottleneck_csp17, bottleneck_csp20, bottleneck_csp23,
                  weight_map, network, task_id, head_num=3):

    head_out = []
    head_in = [bottleneck_csp17, bottleneck_csp20, bottleneck_csp23]
    max_out_ch = 255  # padding target so all branch outputs share one shape
    for head in range(head_num):
        det0_list = []      # per-task branch outputs
        det0_if_layer = []  # per-task if-conditional layers
        for task in range(TOTAL_TASK - 1):
            if_conditional_layer = network.add_if_conditional()
            # set input
            cur_input = if_conditional_layer.add_input(head_in[head]).get_output(0)
            # set condition
            if_conditional_layer.set_condition(task_id[task + 1])

            det0 = network.add_convolution_nd(
                cur_input,
                3 * (CLASS_NUM[task + 1] + 5),
                trt.DimsHW(1, 1),
                kernel=weight_map["model.24." + str(task + 1) + ".m." + str(head) + ".weight"],
                bias=weight_map["model.24." + str(task + 1) + ".m." + str(head) + ".bias"])

            # zero-pad so every branch output has max_out_ch channels
            env = reshape_det(network, det0.get_output(0), max_out_ch)

            det0_list.append(env)
            det0_if_layer.append(if_conditional_layer)

        # base (task 0) branch; note that cur_input here is the input taken
        # from the innermost conditional of the loop above
        det0_base = network.add_convolution_nd(
            cur_input,
            3 * (CLASS_NUM[0] + 5),
            trt.DimsHW(1, 1),
            kernel=weight_map["model.24.0.m." + str(head) + ".weight"],
            bias=weight_map["model.24.0.m." + str(head) + ".bias"])

        # chain the conditionals: each if-layer selects between its own
        # branch output and the result of the previously chained conditionals
        for task in range(TOTAL_TASK - 1):
            c_l = det0_if_layer[task]
            if task == 0:
                det0 = c_l.add_output(det0_list[task], det0_base.get_output(0)).get_output(0)
            else:
                det0 = c_l.add_output(det0_list[task], det0).get_output(0)
        head_out.append(det0)

    return head_out[0], head_out[1], head_out[2]

I would like to know if there is a better way to avoid the increase in inference time.

Any possible suggestions would be greatly appreciated!

lix19937 commented 3 months ago

It seems to be a multi-task/multi-head model. A slightly more complex approach is to split the model into a backbone plus num_heads head modules, and dynamically select the head module at runtime.
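The split-model idea above can be sketched as follows. This is a minimal pure-Python sketch of the dispatch pattern only; in a real deployment, `backbone` and each entry of `heads` would be separate TensorRT engines (or execution contexts), and the callables would wrap their execution calls. All names here are hypothetical.

```python
# Sketch: backbone + per-task heads, selected at runtime by task_id.
# Only the selected head is executed, so per-inference cost does not
# grow with the number of tasks. Plain callables stand in for engines.

def make_pipeline(backbone, heads):
    """backbone: callable; heads: dict mapping task_id -> callable."""
    def infer(x, task_id):
        features = backbone(x)     # shared computation, run once
        head = heads[task_id]      # O(1) selection on the host side
        return head(features)      # only this head runs
    return infer

# Toy stand-ins for the engines:
backbone = lambda x: x * 2
heads = {0: lambda f: f + 1, 1: lambda f: f + 100}

infer = make_pipeline(backbone, heads)
```

Compared with nested `add_if_conditional()`, this moves the branch decision to the host, so a single engine never has to serialize every branch into one graph.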

Parsifal133 commented 3 months ago

> It seems to be a multi-task/multi-head model. A slightly more complex approach is to split the model into a backbone plus num_heads head modules, and dynamically select the head module at runtime.

Yes, this is a continual-learning multi-task model. In fact, each convolutional layer of the model has multiple branches, not just the head; only the head code with nested if-conditional layers is shown here.

I will try not to introduce multiple branches in each convolutional layer, but that will reduce the effectiveness of continual learning.

lix19937 commented 3 months ago

Can you draw the data-flow of model ?

Parsifal133 commented 3 months ago

Hello, @lix19937! I have illustrated a conventional convolutional layer and my multi-task convolutional layer in the figure below. The conventional layer on the left consists of a convolution, a batch normalization (BN) layer, and an activation layer, whereas my layer incorporates N branches. During inference, only one branch is executed at a time: for example, when task_id == 1, the first branch is executed, and its output is added to the main branch's output before being passed into the activation function.

[figure: conventional conv layer vs. multi-branch multi-task conv layer]
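The per-layer dataflow described above (a main convolution plus exactly one task-specific branch, summed before activation) can be sketched numerically. This is a NumPy illustration of the dataflow only, not TensorRT code; all names and the scalar "weights" are hypothetical stand-ins.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def multi_task_conv(x, main_w, branch_ws, task_id):
    """Apply the main path and exactly one task branch, then activate.

    main_w: weight of the always-on main path (stands in for conv + BN).
    branch_ws: per-task branch weights; only branch_ws[task_id] is used
    for a given inference, matching the figure.
    """
    main_out = x * main_w                # stand-in for the main conv + BN
    branch_out = x * branch_ws[task_id]  # only the selected branch runs
    return relu(main_out + branch_out)   # sum, then activation
```

The key property is that for a fixed `task_id`, the cost is one main path plus one branch, regardless of N; the nested-conditional engine should ideally behave the same way.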

I implemented this logic using network.add_if_conditional() in the Python API, and the resulting engine produces correct inference results. The only issues are that inference is relatively slow and the engine occupies a considerable amount of space. Additionally, as the number of branches increases, the inference time increases further. The specific experimental results are presented in the table below. I suspect that the large number of if_conditional layers is causing this issue. Therefore, I humbly seek advice from the community and from you on whether there are better solutions to this problem.

[table: inference time and engine size vs. number of branches]

akhilg-nv commented 3 months ago

Hi @Parsifal133, could you try running your model on the latest version of TensorRT? We expect improved performance on newer versions.

lix19937 commented 3 months ago

> Additionally, as the number of branches increases, the inference time further increases.

Yes, it means that there are multiple branches that need to be computed, and they run serially.

> The only issue is that the engine's inference is relatively slow and occupies a considerable amount of space.

What precision is used? (FP32/FP16/INT8)
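If FP32 is currently used, enabling FP16 typically shrinks the engine and speeds up inference on GPUs with fast FP16 support. Below is a sketch of the relevant TensorRT builder-config flags; treat it as a fragment to adapt into an existing build script, not a complete program (it requires a TensorRT installation and a GPU to run).

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
config = builder.create_builder_config()

# Allow FP16 kernels where the hardware supports them.
if builder.platform_has_fast_fp16:
    config.set_flag(trt.BuilderFlag.FP16)

# INT8 would additionally require a calibrator (or Q/DQ layers), e.g.:
# config.set_flag(trt.BuilderFlag.INT8)
```

Note that lower precision reduces engine size and latency per layer, but it does not remove the serialization overhead of the nested conditionals themselves.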