triton-inference-server / server

The Triton Inference Server provides an optimized cloud and edge inferencing solution.
https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/index.html
BSD 3-Clause "New" or "Revised" License

Issues regarding the performance of cudashm.set_shared_memory_region(shm_input_handle, [img]) #5887

Closed: muqishan closed this issue 1 year ago

muqishan commented 1 year ago

Triton version: 23.04
Model type: Python backend
tritonclient[all] version: 2.33.0
Python version: 3.8.16

Problem: I understand cudashm.set_shared_memory_region(shm_input_handle, [img]) as copying img into CUDA shared memory. Before this call, I have already allocated the CUDA shared-memory regions for the inputs and outputs. My img is a uint8 ndarray of shape 32 × 360 × 640 × 3, which by my calculation occupies about 22 MB. My PCI-E bus actually runs at 5.5 GB/s, so in theory the copy should take under 5 ms. In practice it takes about 20 ms and occasionally fluctuates up to 40 ms. As a check, PyTorch's tensor.to('cuda') takes under 5 ms for the same data. I have already created the handles and registered them, so I don't understand where this time goes. As far as I know, it depends on CPU usage, GPU usage, and the PCI-E bus. If I want the copy to take under 5 ms, what should I do: upgrade the hardware or change the way I send the data? I look forward to your reply and will be very grateful. Here is some of my code:

    def create_shared_memory_pool(self):
        for i in range(1, self.batch_size+1):
            for j in range(1,self.num_batch_size):
                num = str(i)+'_'+str(j)
                request_id = num
                shm_input1 = f'shm_input1_{num}'
                shm_output0 = f'shm_output0_{num}'
                shm_output1 = f'shm_output1_{num}'
                shm_output2 = f'shm_output2_{num}'
                input_byte_size = i  * 3 * 360 * 640
                output_byte_size_output0 = i  * 13300 * 4 * 4
                output_byte_size_output1 = i  * 13300 * 2 * 4
                output_byte_size_output2 = i  * 13300 * 10 * 4
                shm_input_handle = cudashm.create_shared_memory_region(shm_input1, input_byte_size, 0)
                shm_op0_handle = cudashm.create_shared_memory_region(shm_output0, output_byte_size_output0, 0)
                shm_op1_handle = cudashm.create_shared_memory_region(shm_output1, output_byte_size_output1, 0)
                shm_op2_handle = cudashm.create_shared_memory_region(shm_output2, output_byte_size_output2, 0)               
                self.triton_client.register_cuda_shared_memory(shm_input1, cudashm.get_raw_handle(shm_input_handle), 0, input_byte_size)
                self.triton_client.register_cuda_shared_memory(shm_output0, cudashm.get_raw_handle(shm_op0_handle), 0, output_byte_size_output0)
                self.triton_client.register_cuda_shared_memory(shm_output1, cudashm.get_raw_handle(shm_op1_handle), 0, output_byte_size_output1)
                self.triton_client.register_cuda_shared_memory(shm_output2, cudashm.get_raw_handle(shm_op2_handle), 0, output_byte_size_output2)
                self.shared_memory_pool[request_id] = {
                    'request_id': request_id,
                    'shm_input1': shm_input1,
                    'shm_output0': shm_output0,
                    'shm_output1': shm_output1,
                    'shm_output2': shm_output2,
                    'shm_input_handle': shm_input_handle,
                    'shm_op0_handle': shm_op0_handle,
                    'shm_op1_handle': shm_op1_handle,
                    'shm_op2_handle': shm_op2_handle,
                    'input_byte_size': input_byte_size,
                    'output_byte_size_output0': output_byte_size_output0,
                    'output_byte_size_output1': output_byte_size_output1,
                    'output_byte_size_output2': output_byte_size_output2,
                    'available': True
                }
    def detec_triton_infer(self, img, request_id):
        memory_block = self.shared_memory_pool[request_id]
        shm_input_handle = memory_block['shm_input_handle']
        input_byte_size = memory_block['input_byte_size']
        output_byte_size_output0 = memory_block['output_byte_size_output0']
        output_byte_size_output1 = memory_block['output_byte_size_output1']
        output_byte_size_output2 = memory_block['output_byte_size_output2']
        shm_input1 = memory_block['shm_input1']
        shm_output0 = memory_block['shm_output0']
        shm_output1 = memory_block['shm_output1']
        shm_output2 = memory_block['shm_output2']
        now = time.time()
        print(img.dtype)
        cudashm.set_shared_memory_region(shm_input_handle, [img])
        print(int(round(time.time() * 1000)),' load date: cpu->gpu:',time.time()-now)
        inputs = []
        outputs = []
        inputs.append(grpcclient.InferInput('input', [img.shape[0], 3, 360, 640], "UINT8"))
        inputs[-1].set_shared_memory(shm_input1, input_byte_size)
        outputs.append(grpcclient.InferRequestedOutput('output0'))
        outputs[-1].set_shared_memory(shm_output0, output_byte_size_output0)
        outputs.append(grpcclient.InferRequestedOutput('461'))
        outputs[-1].set_shared_memory(shm_output1, output_byte_size_output1)
        outputs.append(grpcclient.InferRequestedOutput('460'))
        outputs[-1].set_shared_memory(shm_output2, output_byte_size_output2)
        wrapped_callback = partial(self.detec_triton_callback, request_id=request_id, user_data=None)
        self.events[request_id] = threading.Event()
        self.triton_client.async_infer(model_name='my_model',
                                        inputs=inputs,
                                        outputs=outputs,
                                        callback=wrapped_callback)   
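The under-5 ms expectation follows from simple arithmetic. Below is a minimal sketch. It assumes the 5.5 GB/s figure is sustained throughput, uses 1 GB = 10^9 bytes, and ignores any fixed per-call overhead. The helper name expected_copy_ms is hypothetical.

```python
def expected_copy_ms(num_bytes: int, gb_per_s: float) -> float:
    """Ideal transfer time in milliseconds at a sustained bandwidth (1 GB = 1e9 bytes)."""
    return num_bytes / (gb_per_s * 1e9) * 1e3


# uint8 batch of shape [32, 3, 360, 640]: one byte per element.
batch_bytes = 32 * 3 * 360 * 640
print(batch_bytes)                                   # 22118400
print(round(expected_copy_ms(batch_bytes, 5.5), 2))  # 4.02
```

By the same arithmetic, a 20 ms copy of this batch corresponds to only about 1.1 GB/s.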
nv-kmcgill53 commented 1 year ago

Hi @muqishan, one thing I notice in your code is the following:

        now = time.time()
        print(img.dtype)
        cudashm.set_shared_memory_region(shm_input_handle, [img])
        print(int(round(time.time() * 1000)),' load date: cpu->gpu:',time.time()-now)

The print statement between now and set_shared_memory_region(shm_input_handle, [img]) adds unwanted latency to your measurement. Could you please provide updated numbers with this print statement moved outside the timed section?
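Below is a minimal sketch of timing only the call under test. It uses time.perf_counter() as the clock, which is monotonic and intended for measuring short durations. A plain bytes copy stands in for cudashm.set_shared_memory_region so the sketch runs without a GPU. The helper name time_call is hypothetical.

```python
import time


def time_call(fn, *args, repeats=30):
    """Call fn(*args) `repeats` times and return each call's duration in seconds.

    Nothing else (printing, logging) happens between the two clock reads.
    """
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        durations.append(time.perf_counter() - start)
    return durations


# Stand-in workload: copy a batch-8-sized host buffer (8 x 1080 x 1920 x 3 bytes).
src = bytearray(8 * 1080 * 1920 * 3)
durations = time_call(bytes, src, repeats=5)  # bytes(bytearray) makes a full copy
print(f"min {min(durations) * 1e3:.2f} ms, max {max(durations) * 1e3:.2f} ms")
```

With the real client, the equivalent call would be time_call(cudashm.set_shared_memory_region, shm_input_handle, [img]), printing the results only after the loop finishes.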

@Tabrizian Do you have any suggestions based on the given information?

muqishan commented 1 year ago

I am writing to you with a request for assistance and guidance, and I have included a detailed test report describing the issues I am facing. Any help you can provide will be deeply appreciated.

Graphics card driver version: 530.30.02
CUDA version: 12.1
Triton Inference Server version: R23.04
Python Triton client version: 2.33.0

First, I ran bandwidthTest from the cuda-samples repository on my Ubuntu 20.04 machine. Under ideal conditions (CPU and GPU otherwise idle), it reported the following results, which show a CPU-GPU (host-to-device) bandwidth of 13.2 GB/s.

Running on...
 Device 0: NVIDIA GeForce RTX 3080
 Quick Mode
 Host to Device Bandwidth, 1 Device(s)
 PINNED Memory Transfers
   Transfer Size (Bytes)    Bandwidth(GB/s)
   32000000         13.2
 Device to Host Bandwidth, 1 Device(s)
 PINNED Memory Transfers
   Transfer Size (Bytes)    Bandwidth(GB/s)
   32000000         13.2
 Device to Device Bandwidth, 1 Device(s)
 PINNED Memory Transfers
   Transfer Size (Bytes)    Bandwidth(GB/s)
   32000000         650.8
Result = PASS

1. I eliminated all other interference and timed only the copy into GPU shared memory. The modifications are as follows:

    def detec_triton_infer(self, img, request_id):
        memory_block = self.shared_memory_pool[request_id]
        shm_input_handle = memory_block['shm_input_handle']
        input_byte_size = memory_block['input_byte_size']
        output_byte_size_output0 = memory_block['output_byte_size_output0']
        shm_input1 = memory_block['shm_input1']
        shm_output0 = memory_block['shm_output0']
        now = time.time()
        cudashm.set_shared_memory_region(shm_input_handle, [img])
        print(int(round(time.time() * 1000)),' load date: cpu->gpu:',time.time()-now)

        inputs = []
        outputs = []
        inputs.append(grpcclient.InferInput('input', [img.shape[0], 1080, 1920, 3], "UINT8"))
        inputs[-1].set_shared_memory(shm_input1, input_byte_size)
        outputs.append(grpcclient.InferRequestedOutput("fc1"))
        outputs[-1].set_shared_memory(shm_output0, output_byte_size_output0)
        wrapped_callback = partial(self.detec_triton_callback, request_id=request_id, user_data=None)
        self.events[request_id] = threading.Event()
        self.triton_client.async_infer(model_name='my_model',
                                        inputs=inputs,
                                        outputs=outputs,
                                        callback=wrapped_callback)

2. I ran the tests at two batch sizes. Each test prints 30 log lines reporting the transfer time.

2.1. With batch size 8 (shape [8,1080,1920,3], uint8), the input theoretically occupies 6.22 MB × 8, approximately 50 MB. The logs show that each transfer takes between about 21 ms and 41 ms (roughly 29 ms on average). That means at most about 2.4 GB/s of bandwidth is used.

1686633832789  load date: cpu->gpu: 0.020900249481201172
complete inference time-consuming: 0.05977988243103027
1686633832868  load date: cpu->gpu: 0.04026508331298828
complete inference time-consuming: 0.0878148078918457
1686633832940  load date: cpu->gpu: 0.02418994903564453
complete inference time-consuming: 0.06948161125183105
1686633833019  load date: cpu->gpu: 0.034282684326171875
complete inference time-consuming: 0.0691215991973877
1686633833075  load date: cpu->gpu: 0.02093815803527832
complete inference time-consuming: 0.07282352447509766
1686633833168  load date: cpu->gpu: 0.04059743881225586
complete inference time-consuming: 0.07389640808105469
1686633833223  load date: cpu->gpu: 0.022196531295776367
complete inference time-consuming: 0.053911685943603516
1686633833277  load date: cpu->gpu: 0.02158379554748535
complete inference time-consuming: 0.051847219467163086
1686633833329  load date: cpu->gpu: 0.021717548370361328
complete inference time-consuming: 0.05392646789550781
1686633833386  load date: cpu->gpu: 0.024930715560913086
complete inference time-consuming: 0.06096243858337402
1686633833449  load date: cpu->gpu: 0.027566909790039062
complete inference time-consuming: 0.06314778327941895
1686633833524  load date: cpu->gpu: 0.03928852081298828
complete inference time-consuming: 0.07467532157897949
1686633833599  load date: cpu->gpu: 0.03951621055603027
complete inference time-consuming: 0.07245731353759766
1686633833654  load date: cpu->gpu: 0.0218813419342041
complete inference time-consuming: 0.05287790298461914
1686633833706  load date: cpu->gpu: 0.02106332778930664
complete inference time-consuming: 0.0732564926147461
1686633833783  load date: cpu->gpu: 0.02480006217956543
complete inference time-consuming: 0.07753276824951172
1686633833877  load date: cpu->gpu: 0.04100537300109863
complete inference time-consuming: 0.07381749153137207
1686633833932  load date: cpu->gpu: 0.021692276000976562
complete inference time-consuming: 0.08040213584899902
1686633834015  load date: cpu->gpu: 0.02469038963317871
complete inference time-consuming: 0.06208515167236328
1686633834089  load date: cpu->gpu: 0.036888837814331055
complete inference time-consuming: 0.06889867782592773
1686633834142  load date: cpu->gpu: 0.02089214324951172
complete inference time-consuming: 0.05150890350341797
1686633834195  load date: cpu->gpu: 0.022374629974365234
complete inference time-consuming: 0.07512307167053223
1686633834288  load date: cpu->gpu: 0.04026436805725098
complete inference time-consuming: 0.07222890853881836
1686633834341  load date: cpu->gpu: 0.020918846130371094
complete inference time-consuming: 0.05935311317443848
1686633834419  load date: cpu->gpu: 0.03951716423034668
complete inference time-consuming: 0.08470702171325684
1686633834499  load date: cpu->gpu: 0.03448677062988281
complete inference time-consuming: 0.07020950317382812
1686633834574  load date: cpu->gpu: 0.0395817756652832
complete inference time-consuming: 0.09308838844299316
1686633834668  load date: cpu->gpu: 0.04032182693481445
complete inference time-consuming: 0.07536983489990234
1686633834724  load date: cpu->gpu: 0.021258831024169922
complete inference time-consuming: 0.056568145751953125
1686633834781  load date: cpu->gpu: 0.021051406860351562
complete inference time-consuming: 0.056879281997680664
30 complete inferences time-consuming: 2.055421829223633
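To turn these log lines into an effective bandwidth, divide the batch size in bytes by each logged duration. Below is a minimal sketch. It assumes the exact log format printed above and 1 GB = 10^9 bytes. The function name effective_bandwidths is hypothetical.

```python
import re
from typing import List

# Matches the timing lines printed by detec_triton_infer, e.g.
# "1686633832789  load date: cpu->gpu: 0.020900249481201172"
LOG_PATTERN = re.compile(r"load date: cpu->gpu:\s*([0-9.]+)")


def effective_bandwidths(log_text: str, num_bytes: int) -> List[float]:
    """Return the effective host-to-device bandwidth (GB/s) for each timing line."""
    seconds = [float(m.group(1)) for m in LOG_PATTERN.finditer(log_text)]
    return [num_bytes / s / 1e9 for s in seconds]


log = """1686633832789  load date: cpu->gpu: 0.020900249481201172
complete inference time-consuming: 0.05977988243103027
1686633832868  load date: cpu->gpu: 0.04026508331298828
complete inference time-consuming: 0.0878148078918457"""

batch8_bytes = 8 * 1080 * 1920 * 3  # uint8, shape [8, 1080, 1920, 3]
print([round(bw, 2) for bw in effective_bandwidths(log, batch8_bytes)])  # [2.38, 1.24]
```

The bandwidthTest figures above are labeled PINNED Memory Transfers. They may therefore not be directly comparable with copies from ordinary (pageable) NumPy buffers. Even so, the fastest batch-8 transfer here reaches less than a fifth of the 13.2 GB/s peak.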

2.2. With batch size 16 (shape [16,1080,1920,3], uint8), the input theoretically occupies 6.22 MB × 16, approximately 100 MB. The logs show that each transfer takes about 50 ms on average and fluctuates considerably, from about 41 ms to 67 ms. Again, only about 2 GB/s of bandwidth is used.

1686633908030  load date: cpu->gpu: 0.04231667518615723
complete inference time-consuming: 0.08690452575683594
1686633908115  load date: cpu->gpu: 0.04111480712890625
complete inference time-consuming: 0.07975101470947266
1686633908202  load date: cpu->gpu: 0.04772663116455078
complete inference time-consuming: 0.0887141227722168
1686633908310  load date: cpu->gpu: 0.0671234130859375
complete inference time-consuming: 0.11153912544250488
1686633908417  load date: cpu->gpu: 0.06273436546325684
complete inference time-consuming: 0.10349845886230469
1686633908514  load date: cpu->gpu: 0.05564546585083008
complete inference time-consuming: 0.09279060363769531
1686633908592  load date: cpu->gpu: 0.04136252403259277
complete inference time-consuming: 0.08142733573913574
1686633908700  load date: cpu->gpu: 0.06740617752075195
complete inference time-consuming: 0.10413837432861328
1686633908779  load date: cpu->gpu: 0.04277610778808594
complete inference time-consuming: 0.08878469467163086
1686633908866  load date: cpu->gpu: 0.04125070571899414
complete inference time-consuming: 0.09403347969055176
1686633908981  load date: cpu->gpu: 0.0615999698638916
complete inference time-consuming: 0.10189104080200195
1686633909082  load date: cpu->gpu: 0.06094169616699219
complete inference time-consuming: 0.10097885131835938
1686633909182  load date: cpu->gpu: 0.0598607063293457
complete inference time-consuming: 0.09663796424865723
1686633909260  load date: cpu->gpu: 0.04122042655944824
complete inference time-consuming: 0.09121894836425781
1686633909368  load date: cpu->gpu: 0.05814337730407715
complete inference time-consuming: 0.0949544906616211
1686633909446  load date: cpu->gpu: 0.041066646575927734
complete inference time-consuming: 0.0808267593383789
1686633909531  load date: cpu->gpu: 0.04509425163269043
complete inference time-consuming: 0.08513021469116211
1686633909633  load date: cpu->gpu: 0.062006473541259766
complete inference time-consuming: 0.10178804397583008
1686633909733  load date: cpu->gpu: 0.06049346923828125
complete inference time-consuming: 0.09749317169189453
1686633909812  load date: cpu->gpu: 0.04132795333862305
complete inference time-consuming: 0.0785064697265625
1686633909890  load date: cpu->gpu: 0.04146862030029297
complete inference time-consuming: 0.07868599891662598
1686633909969  load date: cpu->gpu: 0.041143178939819336
complete inference time-consuming: 0.07804441452026367
1686633910047  load date: cpu->gpu: 0.04077005386352539
complete inference time-consuming: 0.07770729064941406
1686633910125  load date: cpu->gpu: 0.041449546813964844
complete inference time-consuming: 0.09257245063781738
1686633910221  load date: cpu->gpu: 0.044713735580444336
complete inference time-consuming: 0.08440661430358887
1686633910306  load date: cpu->gpu: 0.044945478439331055
complete inference time-consuming: 0.08175945281982422
1686633910384  load date: cpu->gpu: 0.04156684875488281
complete inference time-consuming: 0.07811427116394043
1686633910462  load date: cpu->gpu: 0.04137134552001953
complete inference time-consuming: 0.07818913459777832
1686633910540  load date: cpu->gpu: 0.04164743423461914
complete inference time-consuming: 0.07831072807312012
1686633910619  load date: cpu->gpu: 0.04211616516113281
complete inference time-consuming: 0.0791168212890625
30 complete inferences time-consuming: 2.675527811050415
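The spread of these timings can be summarized numerically rather than judged by eye. Below is a minimal sketch using Python's statistics module. The helper name summarize_ms is hypothetical, and the sample is the first five batch-16 transfer times from the log above.

```python
import statistics
from typing import Dict, Sequence


def summarize_ms(durations_s: Sequence[float]) -> Dict[str, float]:
    """Summarize per-call durations (seconds) as min/median/max/stdev in milliseconds."""
    ms = [d * 1e3 for d in durations_s]
    return {
        "min": min(ms),
        "median": statistics.median(ms),
        "max": max(ms),
        "stdev": statistics.stdev(ms),  # sample standard deviation; needs >= 2 values
    }


# First five batch-16 transfer times from the log above.
sample = [0.04231667518615723, 0.04111480712890625, 0.04772663116455078,
          0.0671234130859375, 0.06273436546325684]
summary = summarize_ms(sample)
print({k: round(v, 1) for k, v in summary.items()})
```

Comparing the stdev (or the max/median ratio) across batch sizes makes the stability of the two runs directly comparable.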

2.3. I also ran bandwidthTest while inference was executing:

Running on...
 Device 0: NVIDIA GeForce RTX 3080
 Quick Mode
 Host to Device Bandwidth, 1 Device(s)
 PINNED Memory Transfers
   Transfer Size (Bytes)    Bandwidth(GB/s)
   32000000         10.8
 Device to Host Bandwidth, 1 Device(s)
 PINNED Memory Transfers
   Transfer Size (Bytes)    Bandwidth(GB/s)
   32000000         10.7
 Device to Device Bandwidth, 1 Device(s)
 PINNED Memory Transfers
   Transfer Size (Bytes)    Bandwidth(GB/s)
   32000000         606.6
Result = PASS

Queries:

1. Are the tests I'm performing valid?
2. For the bandwidthTest, the PCI bandwidth is approximately 13 GB/s, but actual usage is around 2 GB/s. How can I increase the utilization rate?
3. If I adopt the NVIDIA Jetson Xavier NX series or other modules for deployment and inference tasks, can I bypass the dependency on PCI?

Online resources on this issue are very scarce, which makes it difficult for me to determine the cause of the problem. As a novice in deep learning deployment, I would like to express my profound gratitude once again to all the experienced people who have generously offered their assistance.

muqishan commented 1 year ago

In the source code of set_shared_memory_region, I noticed the line input_value = np.ascontiguousarray(input_value).flatten(). My images are read with OpenCV and stacked with np.stack, so 'imgs' is already contiguous in memory and in row-major order. I tried commenting out this line and got exactly the same results. I compared the results very carefully before and after, and I'm confident they are identical. I presume this is because the array produced by np.stack already occupies contiguous memory in row-major order. Passing its address directly therefore yields the same bytes, and ascontiguousarray and flatten are unnecessary in my project. However, this line takes up half of the total time of the CPU-to-GPU data loading. Perhaps a C++ client would perform much better than Python, and I shouldn't insist on using Python for deployment.

def set_shared_memory_region(cuda_shm_handle, input_values):
    """Copy the contents of the numpy array into the cuda shared memory region.

    Parameters
    ----------
    cuda_shm_handle : c_void_p
        The handle for the cuda shared memory region.
    input_values : list
        The list of numpy arrays to be copied into the shared memory region.

    Raises
    ------
    CudaSharedMemoryException
        If unable to set values in the cuda shared memory region.
    """

    if not isinstance(input_values, (list, tuple)):
        _raise_error("input_values must be specified as a numpy array")
    for input_value in input_values:
        if not isinstance(input_value, (np.ndarray,)):
            _raise_error(
                "input_values must be specified as a list/tuple of numpy arrays"
            )

    offset_current = 0
    for input_value in input_values:
        # print('input_values',len(input_values),input_value.shape)
        input_value = np.ascontiguousarray(input_value).flatten()
        if input_value.dtype == np.object_:
            input_value = input_value.item()
            byte_size = np.dtype(np.byte).itemsize * len(input_value)
            _raise_if_error(
                c_int(_ccudashm_shared_memory_region_set(cuda_shm_handle, c_uint64(offset_current), \
                    c_uint64(byte_size), cast(input_value, c_void_p))))
        else:
            byte_size = input_value.size * input_value.itemsize
            _raise_if_error(
                c_int(_ccudashm_shared_memory_region_set(cuda_shm_handle, c_uint64(offset_current), \
                    c_uint64(byte_size), input_value.ctypes.data_as(c_void_p))))
        offset_current += byte_size
    return
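The cost of that line is consistent with NumPy's documented semantics. np.ascontiguousarray does not copy an array that is already C-contiguous. ndarray.flatten(), however, always returns a copy, whereas ravel() returns a view whenever possible. The small sketch below uses a scaled-down stand-in for the stacked image batch, and the variable names are illustrative. It only illustrates NumPy behavior and is not a patch to tritonclient.

```python
import numpy as np

# Scaled-down stand-in for OpenCV frames (real frames are 1080 x 1920 x 3, uint8).
frames = [np.full((4, 6, 3), i, dtype=np.uint8) for i in range(2)]
batch = np.stack(frames)  # np.stack allocates a new C-contiguous array

contig = np.ascontiguousarray(batch)  # no copy: batch is already C-contiguous
flat_copy = contig.flatten()          # flatten() always returns a copy
flat_view = contig.ravel()            # ravel() returns a view when it can

print(batch.flags["C_CONTIGUOUS"])           # True
print(np.shares_memory(batch, contig))       # True
print(np.shares_memory(batch, flat_copy))    # False
print(np.shares_memory(batch, flat_view))    # True
print(np.array_equal(flat_copy, flat_view))  # True
```

For a 50 to 100 MB batch, that extra flatten() copy is a full pass over host memory before the device copy even starts. This would explain why removing the line changes the timing but not the transferred bytes.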
nv-kmcgill53 commented 1 year ago

1. Are the tests I'm performing valid?

I believe your methodology is correct; however, I overlooked set_shared_memory_region itself. As you have stated, your image is a numpy array, which will be of type np.object_. This will cause the first branch to execute. That branch contains the line input_value = input_value.item(), which is a copy operation. This will likely add to your latency as well, since you are copying 50 MB and 100 MB respectively in your tests. It might be worth setting the timer inside this function instead to get the true value of the device copy:

if input_value.dtype == np.object_:
    input_value = input_value.item()
    byte_size = np.dtype(np.byte).itemsize * len(input_value)

    now = time.time()
    _raise_if_error(
        c_int(_ccudashm_shared_memory_region_set(cuda_shm_handle, c_uint64(offset_current), \
            c_uint64(byte_size), cast(input_value, c_void_p))))
    print(int(round(time.time() * 1000)),' load date: cpu->gpu:',time.time()-now)

else:
    ...

2. For the bandwidthTest, the PCI bandwidth is approximately 13 GB/s, but actual usage is around 2 GB/s. How can I increase the utilization rate?

I am not a PCIe expert in this regard, perhaps @tanmayv25 you have some insights?

3. If I adopt the NVIDIA Jetson Xavier NX series or other modules for deployment and inference tasks, can I bypass the dependency on PCI?

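In theory perhaps, since the CPU and GPU on Jetson modules share the same physical memory instead of communicating over PCIe. However, Jetson devices are lower power and meant for edge applications, so I suspect whatever latency you save on the copy you will lose in inference time.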

Perhaps the performance of a C++ client would be much better than Python, and I shouldn't insist on using Python for deployment.

In general, yes, the C++ client will be faster, though I would benchmark it as well to make sure it is the right solution for you. @jbkyang-nvi might be able to provide specifics here.

muqishan commented 1 year ago
  1. First, if we comment out the line input_value = np.ascontiguousarray(input_value).flatten(), input_value.dtype == np.object_ is False, because input_value.dtype is uint8. Even if we don't comment out this line, input_value.dtype == np.object_ is still False. Therefore input_value = input_value.item() is never executed here, and the code enters the other branch.
  2. The input_values list has length 1, and its element has shape [batch,1080,1920,3]. In the other branch I didn't find any particularly time-consuming Python code, and the CUDA shared-memory copy itself should be calling C++ primitives.
  3. NVIDIA's bandwidthTest tool reports 13.2 GB/s. Even with input_value = np.ascontiguousarray(input_value).flatten() commented out, the actual bandwidth only rises to about 5 GB/s. That is not much different from PyTorch's tensor.to('cuda'), and even slightly lower. But this is still unreasonable, because the maximum bandwidth measured by bandwidthTest is 13.2 GB/s.
  4. During all the tests, there was almost no other CPU load. For the parallel test in 2.3 (running bandwidthTest during inference), I only reported the bandwidthTest results. The tests in 2.1 and 2.2 ran with almost no other load. I think these times are very abnormal: even allowing for CPU and other factors, the copy should not take this long. As a newcomer to deep learning deployment, I would like to once again express my deep gratitude to everyone who has generously helped.
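Which branch of set_shared_memory_region runs depends only on the array's dtype. The quick check below uses a scaled-down uint8 batch; the shape and the helper name shm_branch are illustrative. It shows that image data takes the numeric (else) branch. The np.object_ branch is presumably meant for serialized BYTES tensors, which the sketch imitates with an object array holding a bytes value.

```python
import numpy as np


def shm_branch(value: np.ndarray) -> str:
    """Mirror the dtype test in set_shared_memory_region (hypothetical helper)."""
    flat = np.ascontiguousarray(value).flatten()
    return "object" if flat.dtype == np.object_ else "numeric"


imgs = np.zeros((2, 4, 6, 3), dtype=np.uint8)  # stand-in for [batch, 1080, 1920, 3]
serialized = np.array([b"\x00\x01"], dtype=np.object_)

print(imgs.dtype, shm_branch(imgs))              # uint8 numeric
print(serialized.dtype, shm_branch(serialized))  # object object
```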
tanmayv25 commented 1 year ago

@muqishan I would recommend profiling the cudashm.set_shared_memory_region() and tensor.to('cuda') calls with Nsight Systems. It will give you finer visibility into the actual bottlenecks and overheads involved. Watch the HtoD (host-to-device) copies and any activity around them.

dyastremsky commented 1 year ago

Closing due to inactivity. Please let us know if you would like us to re-open this issue for follow-up.