triton-inference-server / dali_backend

The Triton backend that allows running GPU-accelerated data pre-processing pipelines implemented in DALI's python API.
https://docs.nvidia.com/deeplearning/dali/user-guide/docs/index.html
MIT License

Segfault when max_batch_size > 1 #106

Closed. MaxHuerlimann closed this issue 2 years ago

MaxHuerlimann commented 3 years ago

Hi everybody

I am facing issues when enabling the dynamic scheduler with a max_batch_size bigger than 1: I get a segfault when submitting requests. The main README says that DALI requires homogeneous batch sizes. How would I achieve that when using the Triton C API directly? In the tests introduced with the PR that enabled dynamic batching, I can't find anything enforcing homogeneous batch sizes. Am I missing something?

We are using the C API of the Triton r21.06 release with a DALI pipeline that is created with a batch size of 64, and we set max_batch_size in the Triton config.pbtxt to 32 for all elements of the ensemble model.

szalpal commented 3 years ago

Hello @MaxHuerlimann !

First of all let me clarify:

In the main readme it says, that dali requires homogenous batch sizes.

I believe you're referring to the Known limitations section. It's actually Triton that requires a homogeneous batch shape, as is written there. DALI is fine with different shapes for each sample in the batch, as long as the number of dimensions remains constant.
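
To illustrate, here's a minimal sketch (not from this issue; it assumes DALI is installed and a GPU is visible) showing that DALI accepts a batch whose samples have different shapes but the same number of dimensions:

import numpy as np
import nvidia.dali.fn as fn
from nvidia.dali import pipeline_def

@pipeline_def(batch_size=2, num_threads=1, device_id=0)
def ragged_pipe():
    # 1-D samples whose lengths differ within the batch; the number of dimensions stays constant.
    data = fn.external_source(device="cpu", name="DATA")
    return data * 2.0

pipe = ragged_pipe()
pipe.build()
pipe.feed_input("DATA", [np.zeros(5, dtype=np.float32), np.zeros(9, dtype=np.float32)])
out, = pipe.run()  # one batch with sample shapes [5] and [9]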

It's hard to guess what your problem could be without some insight. Could you provide the server log from when the segfault happens? If you pass the --log-verbose=1 option when running the server, you'll find some useful logging information there; it might come in handy. You may also refer to issue https://github.com/triton-inference-server/dali_backend/issues/104, which is about a similar topic. Since you're using Triton's C API, the problem might be in how the request is put together, but that's only my guess.
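
For example (assuming the server is launched directly and the model repository lives at /models; the path is a placeholder):

tritonserver --model-repository=/models --log-verbose=1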

Anyway, if you'd like to provide some more info about the error, I'd be happy to help. To be perfectly honest, we haven't thoroughly tested using DALI Backend via Triton's C API yet, so there might be a bug wandering around.

MaxHuerlimann commented 3 years ago

Thanks for the quick response and the clarifications!

The verbose log right before the segfault is:

I1025 10:07:18.299944 5861 model_repository_manager.cc:638] GetInferenceBackend() 'ensemble_model' version -1
I1025 10:07:18.300521 5861 infer_request.cc:524] prepared: [0x0x7ff88c483300] request id: , model: ensemble_model, requested version: -1, actual version: 1, flags: 0x0, correlation id: 0, batch size: 1, priority: 0, timeout (us): 0
original inputs:
[0x0x7ff88c63e088] input: ROTATION_ANGLE, type: FP32, original shape: [1,1], batch + shape: [1,1], shape: [1]
[0x0x7ff88c491da8] input: ROI, type: FP32, original shape: [1,4], batch + shape: [1,4], shape: [4]
[0x0x7ff88c650538] input: IMAGE, type: UINT8, original shape: [1,24084], batch + shape: [1,24084], shape: [24084]
override inputs:
inputs:
[0x0x7ff88c650538] input: IMAGE, type: UINT8, original shape: [1,24084], batch + shape: [1,24084], shape: [24084]
[0x0x7ff88c491da8] input: ROI, type: FP32, original shape: [1,4], batch + shape: [1,4], shape: [4]
[0x0x7ff88c63e088] input: ROTATION_ANGLE, type: FP32, original shape: [1,1], batch + shape: [1,1], shape: [1]
original requested outputs:
requested outputs:
output

I1025 10:07:18.300562 5861 model_repository_manager.cc:638] GetInferenceBackend() 'dali_pipeline' version -1
I1025 10:07:18.300569 5861 model_repository_manager.cc:638] GetInferenceBackend() 'model' version -1
I1025 10:07:18.300591 5861 infer_request.cc:524] prepared: [0x0x7ff88c485c80] request id: , model: dali_pipeline, requested version: -1, actual version: 1, flags: 0x0, correlation id: 0, batch size: 1, priority: 0, timeout (us): 0
original inputs:
[0x0x7ff88c56a9c8] input: IMAGE, type: UINT8, original shape: [1,24084], batch + shape: [1,24084], shape: [24084]
[0x0x7ff88c5774a8] input: ROI, type: FP32, original shape: [1,4], batch + shape: [1,4], shape: [4]
[0x0x7ff88c579038] input: ROTATION_ANGLE, type: FP32, original shape: [1,1], batch + shape: [1,1], shape: [1]
override inputs:
inputs:
[0x0x7ff88c579038] input: ROTATION_ANGLE, type: FP32, original shape: [1,1], batch + shape: [1,1], shape: [1]
[0x0x7ff88c5774a8] input: ROI, type: FP32, original shape: [1,4], batch + shape: [1,4], shape: [4]
[0x0x7ff88c56a9c8] input: IMAGE, type: UINT8, original shape: [1,24084], batch + shape: [1,24084], shape: [24084]
original requested outputs:
PREPROCESSED_IMAGE
requested outputs:
PREPROCESSED_IMAGE

Then the segfault message itself:

* thread #41, name = 'dotnet', stop reason = signal SIGSEGV: invalid address (fault address: 0x0)
    frame #0: 0x00007ff96e8e5a84 libdali_operators.so`float dali::OpSpec::GetArgumentImpl<float, float>(std::string const&, dali::ArgumentWorkspace const*, long) const + 244
libdali_operators.so`dali::OpSpec::GetArgumentImpl<float, float>:
->  0x7ff96e8e5a84 <+244>: movss  (%rax), %xmm0             ; xmm0 = mem[0],zero,zero,zero
    0x7ff96e8e5a88 <+248>: addq   $0x88, %rsp
    0x7ff96e8e5a8f <+255>: popq   %rbx
    0x7ff96e8e5a90 <+256>: popq   %rbp

Does that give you any relevant information?

JanuszL commented 3 years ago

Hi @MaxHuerlimann,

It looks like one of the operators gets an invalid argument (a nullptr instead of valid data). I would check how the ROI and ROTATION_ANGLE inputs are passed. If you could provide minimal, self-contained reproduction code that we can run on our side, that would be great.

MaxHuerlimann commented 3 years ago

We run Triton with a proprietary C# wrapper around the C API, so unfortunately I can't just share our code for this. But I'll try to set up something that reproduces the issue using only the C API.

MaxHuerlimann commented 2 years ago

I will close this for now, as I don't have the capacity to reproduce this with extra code (as can be seen from the long inactivity) and the inference latency does not seem to be drastically impacted. I will reopen this once I can tackle the issue again.

MaxHuerlimann commented 2 years ago

Hello again!

I have come back to this issue now, as we are experimenting with the Docker deployment of Triton (22.05) and are still facing it. I have managed to pinpoint it to the crop operator. If I try to feed it a batch of crop windows (we are detecting objects in an image and want to crop them on a per-image basis), the Triton process crashes with

Signal (11) received.
 0# 0x0000558BBBD771B9 in tritonserver
 1# 0x00007F886FFD80C0 in /usr/lib/x86_64-linux-gnu/libc.so.6
 2# float dali::OpSpec::GetArgumentImpl<float, float>(std::string const&, dali::ArgumentWorkspace const*, long) const in /opt/tritonserver/backends/dali/dali/libdali_operators.so
 3# 0x00007F86D2B4826E in /opt/tritonserver/backends/dali/dali/libdali_operators.so
 4# 0x00007F86D25D1F76 in /opt/tritonserver/backends/dali/dali/libdali_operators.so
 5# 0x00007F86D2597B12 in /opt/tritonserver/backends/dali/dali/libdali_operators.so
 6# void dali::Executor<dali::AOT_WS_Policy<dali::UniformQueuePolicy>, dali::UniformQueuePolicy>::RunHelper<dali::DeviceWorkspace>(dali::OpNode&, dali::DeviceWorkspace&) in /opt/tritonserver/backends/dali/dali/libdali.so
 7# dali::Executor<dali::AOT_WS_Policy<dali::UniformQueuePolicy>, dali::UniformQueuePolicy>::RunGPUImpl() in /opt/tritonserver/backends/dali/dali/libdali.so
 8# dali::Executor<dali::AOT_WS_Policy<dali::UniformQueuePolicy>, dali::UniformQueuePolicy>::RunGPU() in /opt/tritonserver/backends/dali/dali/libdali.so
 9# 0x00007F884537E228 in /opt/tritonserver/backends/dali/dali/libdali.so
10# 0x00007F88453F78BC in /opt/tritonserver/backends/dali/dali/libdali.so
11# 0x00007F88459DAB6F in /opt/tritonserver/backends/dali/dali/libdali.so
12# 0x00007F88715D7609 in /usr/lib/x86_64-linux-gnu/libpthread.so.0
13# clone in /usr/lib/x86_64-linux-gnu/libc.so.6

Is there a recommended way to feed a batch of crop windows with which to crop a batch of images?

A minimal example for reproduction should be:

import nvidia.dali.fn as fn
from nvidia.dali import pipeline_def

@pipeline_def(batch_size=32, num_threads=4, device_id=0)
def pipeline():
    images = fn.external_source(device="cpu", name="IMAGE")
    crop_x = fn.external_source(device="cpu", name="CROP_X")
    crop_y = fn.external_source(device="cpu", name="CROP_Y")
    crop_width = fn.external_source(device="cpu", name="CROP_WIDTH")
    crop_height = fn.external_source(device="cpu", name="CROP_HEIGHT")

    images = fn.decoders.image(images, device="mixed")
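    # Per-sample crop windows: the scalar external sources above are fed to fn.crop as argument inputs.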
    images = fn.crop(
        images,
        crop_pos_x=crop_x,
        crop_pos_y=crop_y,
        crop_w=crop_width,
        crop_h=crop_height
    )
    images = fn.resize(
        images,
        resize_x=288,
        resize_y=384,
        mode="not_larger",
    )
    images = fn.pad(images, fill_value=128, axes=(0, 1), shape=(384, 288))
    return images

def main():
    pipeline().serialize(filename='1/model.dali')

if __name__ == "__main__":
    main()

and with configuration

name: "dali_test"
backend: "dali"
max_batch_size: 32
dynamic_batching {
  preferred_batch_size: [ 32 ]
  max_queue_delay_microseconds: 500
}
instance_group [
        {
                count: 1
                kind: KIND_GPU
        }
]
input [
        {
                name: "IMAGE"
                data_type: TYPE_UINT8
                dims: [ -1 ]
                allow_ragged_batch: true
        },
        {
                name: "CROP_X"
                data_type: TYPE_FP32
                dims: [ 1 ]
        },
        {
                name: "CROP_Y"
                data_type: TYPE_FP32
                dims: [ 1 ]
        },
        {
                name: "CROP_WIDTH"
                data_type: TYPE_FP32
                dims: [ 1 ]
        },
        {
                name: "CROP_HEIGHT"
                data_type: TYPE_FP32
                dims: [ 1 ]
        }
]
output [
        {
                name: "PREPROCESSED_IMAGE"
                data_type: TYPE_FP32
                dims: [ 3, 384, 288 ]
        }
]
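
For completeness, the model repository layout this implies (my assumption, based on the serialize(filename='1/model.dali') call above; the issue doesn't show it explicitly) would be:

model_repository/
└── dali_test/
    ├── config.pbtxt
    └── 1/
        └── model.dali
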
szalpal commented 2 years ago

Hi @MaxHuerlimann !

Apologies for the late response. I've tried to reproduce your error, but I'm having a hard time with it. So far I've confirmed that the DALI pipeline you wrote is correct, and the config.pbtxt also seems to be correct. That would mean that either the problem lies in the way you're feeding the input data to the server, or there is in fact some bug in DALI/DALI Backend. Would you mind providing some more information about how you're passing the data to the server? Specifically, I'd be grateful if you could post some Python client code.

As a second verification, I've put together client code which should work with the DALI pipeline and model configuration you've provided. Would you mind checking this code out and possibly also running it on your side? If my repro_client.py shows the same error as your application, we'd have more luck narrowing down the issue. repro_client.py is a combination of two files that we already have in the DALI Backend repo: multi_input_client.py and dali_grpc_client.py. The former shows how to work with scalar inputs (which you pass as the CROP_... inputs) and the latter shows how to feed images to the Triton server.

To run this file you should call:

python repro_client.py --model_name dali_test --img_dir images --batch_size 2 --n_iter 1

Here, images is a directory containing two JPEGs.

And here's the repro_client.py file: https://gist.github.com/szalpal/63d427249faab0f1b9087059ae394d58
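
For reference, a minimal gRPC client along those lines might look roughly like this (a sketch only, not the actual repro_client.py from the gist; the server URL, image path, and crop values are placeholders):

import numpy as np
import tritonclient.grpc as t_client

client = t_client.InferenceServerClient(url="localhost:8001")

# One encoded image sent as a batch of 1; the encoded bytes form a 1-D UINT8 tensor.
encoded = np.fromfile("images/sample.jpg", dtype=np.uint8)

inputs = [
    t_client.InferInput("IMAGE", [1, encoded.size], "UINT8"),
    t_client.InferInput("CROP_X", [1, 1], "FP32"),
    t_client.InferInput("CROP_Y", [1, 1], "FP32"),
    t_client.InferInput("CROP_WIDTH", [1, 1], "FP32"),
    t_client.InferInput("CROP_HEIGHT", [1, 1], "FP32"),
]
inputs[0].set_data_from_numpy(encoded.reshape(1, -1))
inputs[1].set_data_from_numpy(np.array([[0.1]], dtype=np.float32))
inputs[2].set_data_from_numpy(np.array([[0.1]], dtype=np.float32))
inputs[3].set_data_from_numpy(np.array([[200.0]], dtype=np.float32))
inputs[4].set_data_from_numpy(np.array([[300.0]], dtype=np.float32))

outputs = [t_client.InferRequestedOutput("PREPROCESSED_IMAGE")]
result = client.infer(model_name="dali_test", inputs=inputs, outputs=outputs)
print(result.as_numpy("PREPROCESSED_IMAGE").shape)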

MaxHuerlimann commented 2 years ago

I have used the perf_analyzer tool with this data (repro_data.zip), a batch size of 1 for each request, and different concurrency values; which one doesn't really matter, as the crash happens every time.
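
The exact command line isn't important, but it was roughly along these lines (the input-data path stands for the unpacked contents of repro_data.zip, and the concurrency range is just an example):

perf_analyzer -m dali_test -b 1 --input-data repro_data/data.json --concurrency-range 1:4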

I'll check whether I can reproduce the issue with your repro client and will get back to you.

MaxHuerlimann commented 2 years ago

Yeah, it seems to be the same issue with the code you provided.

MaxHuerlimann commented 2 years ago

As a detail, this does not happen with, for example, the rotation operator: I can feed it different scalars that the dynamic batcher batches without issue.
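
Roughly, the working rotation case looks like this (a simplified sketch, not our exact pipeline):

import nvidia.dali.fn as fn
from nvidia.dali import pipeline_def

@pipeline_def(batch_size=32, num_threads=4, device_id=0)
def rotate_pipeline():
    images = fn.external_source(device="cpu", name="IMAGE")
    angles = fn.external_source(device="cpu", name="ROTATION_ANGLE")
    images = fn.decoders.image(images, device="mixed")
    # The per-sample angle is passed as an argument input, analogous to the crop window above.
    return fn.rotate(images, angle=angles)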

szalpal commented 2 years ago

@MaxHuerlimann ,

thank you for checking this out. It's possible that we have a bug of some sort there. Let me look into it and I'll get back to you as soon as I know more.

MaxHuerlimann commented 2 years ago

Hi @szalpal, any updates regarding this issue? DALI has become a bit of a bottleneck on our end, so being able to use dynamic batching would be a great benefit for us.

szalpal commented 2 years ago

@MaxHuerlimann ,

that's actually a challenging one to debug, but I'm working on it right now. Hopefully I'll have some conclusions in a day or two :)

szalpal commented 2 years ago

@MaxHuerlimann ,

we've narrowed down the issue and fixed it. Here's the PR: https://github.com/NVIDIA/DALI/pull/4043

The change will be released in Triton 22.08.

szalpal commented 2 years ago

Fixed in https://github.com/NVIDIA/DALI/pull/4045