triton-inference-server / server

The Triton Inference Server provides an optimized cloud and edge inferencing solution.
https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/index.html
BSD 3-Clause "New" or "Revised" License

Triton server crash when running a large model with an ONNX/CPU backend #7337

Open LucasAudebert opened 1 month ago

LucasAudebert commented 1 month ago

Description

I encounter a crash when using a large model with the ONNX backend on CPU. The problem seems to be related to this closed ticket: https://github.com/triton-inference-server/server/issues/5702

My models run fine on my computer's CPU when I use the ONNX runtime without Triton. But when I deploy them behind a Triton server, I see much higher RAM consumption than without Triton, which seems to cause the server to crash:

[attached screenshot: RAM usage during inference]

Moreover, when I reduce the batch size I can run my model several times, but the RAM does not seem to be fully released between runs and I eventually end up with a crash:

[attached screenshot: RAM usage across repeated runs]
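
For comparison, here is a minimal sketch of how I run the same model directly with onnxruntime, where I do not see the same memory growth (the model path is illustrative):

import numpy as np
import onnxruntime as ort

BATCH_SIZE = 384

# Load the ONNX model directly on CPU, without Triton.
session = ort.InferenceSession('my_model/1/model.onnx',
                               providers=['CPUExecutionProvider'])

array_input = np.random.rand(BATCH_SIZE, 384, 384, 1).astype(np.float32)

# Repeated inference; RAM usage stays roughly constant in this setup.
for _ in range(10):
    outputs = session.run(None, {'input_1': array_input})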

Triton Information

I am working with Triton server 24.04 (I also tested some earlier versions).

The problem is present both with the official Triton container and with builds I made myself.

To Reproduce

I am working on WSL, but I can also reproduce the problem directly on Windows.

I am sharing a model repository that should allow you to reproduce the crash. I used OneDrive because the model is too big to be attached here, but I can use another sharing solution if you prefer.

docker run -it --shm-size=1g --rm -p8000:8000 -p8001:8001 -p8002:8002 -v $(pwd)/model_repository:/models nvcr.io/nvidia/tritonserver:24.04-py3
tritonserver --model-repository=/models --log-verbose=2
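
For reference, the repository uses the standard layout (my_model/config.pbtxt plus my_model/1/model.onnx), and the config.pbtxt looks roughly like this (a simplified sketch: the output dimensions are placeholders, and the instance_group reflects my CPU-only setup). The client script below then sends the requests:

name: "my_model"
backend: "onnxruntime"
max_batch_size: 384
input [
  {
    name: "input_1"
    data_type: TYPE_FP32
    dims: [ 384, 384, 1 ]
  }
]
output [
  {
    name: "conv2d_29"
    data_type: TYPE_FP32
    dims: [ -1, -1, -1 ]  # placeholder; the real shape comes from the model
  }
]
instance_group [ { kind: KIND_CPU, count: 1 } ]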
import numpy as np
import tritonclient.http as httpclient

ITERATION_NB = 10
BATCH_SIZE = 384

if __name__ == '__main__':
    client = httpclient.InferenceServerClient(url='localhost:8000')

    for _ in range(ITERATION_NB):
        # Random FP32 input with shape (batch, 384, 384, 1), matching the model's input.
        array_input = np.random.rand(BATCH_SIZE, 384, 384, 1).astype(np.float32)
        inference_input = httpclient.InferInput('input_1', array_input.shape, datatype='FP32')
        inference_input.set_data_from_numpy(array_input, binary_data=True)

        response = client.infer(
            model_name='my_model', inputs=[inference_input]
        )

        # The model's output tensor is named 'conv2d_29'.
        assert response.as_numpy('conv2d_29') is not None
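
For context, the input tensor of a single request at this batch size is already around 216 MiB, so even a few extra copies buffered inside the server add up quickly:

# One request's input: BATCH_SIZE x 384 x 384 x 1 float32 values, 4 bytes each.
input_bytes = 384 * 384 * 384 * 1 * 4
print(f'{input_bytes / 2**20:.0f} MiB per request')  # ~216 MiB (~226 MB)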

Expected behavior

I have 7 models of a similar size to the one I have shared, a batch size of 384, and pre-processing and post-processing steps common to all of these models. I would like to perform all of these steps in a single ensemble model. Do you think this is possible with Triton server on CPU? If not, how close can I get? Today I can do all of this on my computer, just without Triton server.
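
To illustrate what I have in mind, here is a rough sketch of such an ensemble config (all names except my_model, input_1 and conv2d_29 are made up for illustration):

name: "my_pipeline"
platform: "ensemble"
max_batch_size: 384
input [ { name: "RAW_INPUT", data_type: TYPE_FP32, dims: [ 384, 384, 1 ] } ]
output [ { name: "FINAL_OUTPUT", data_type: TYPE_FP32, dims: [ -1, -1, -1 ] } ]
ensemble_scheduling {
  step [
    {
      model_name: "preprocess"   # hypothetical pre-processing model
      model_version: -1
      input_map { key: "raw" value: "RAW_INPUT" }
      output_map { key: "preprocessed" value: "preprocessed_tensor" }
    },
    {
      model_name: "my_model"
      model_version: -1
      input_map { key: "input_1" value: "preprocessed_tensor" }
      output_map { key: "conv2d_29" value: "model_output" }
    },
    {
      model_name: "postprocess"  # hypothetical post-processing model
      model_version: -1
      input_map { key: "scores" value: "model_output" }
      output_map { key: "result" value: "FINAL_OUTPUT" }
    }
  ]
}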

statiraju commented 1 month ago

[6859] created to track issue