triton-inference-server / server

The Triton Inference Server provides an optimized cloud and edge inferencing solution.
https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/index.html
BSD 3-Clause "New" or "Revised" License

How to control the pipeline of an ensemble model so that data stays on only one GPU when using multiple GPUs? #7001

Open lzcchl opened 5 months ago

lzcchl commented 5 months ago

As mentioned at the end of https://github.com/triton-inference-server/server/issues/6981

triton: nvcr.io/nvidia/tritonserver 23.12-py3

I have 4 GPUs and my model is an ensemble model. I don't set gpus in the instance_group of config.pbtxt, so by default an instance is created on each GPU.

My command to launch the Triton server is below; the available GPUs are specified explicitly:

sudo docker run --gpus '"device=0,1,2"' --rm -it --net host -v $PWD/model_repo:/models serving-triton tritonserver --model-repository /models

When I run the Triton server with a single GPU, i.e. --gpus '"device=0"', '"device=1"', '"device=2"', or '"device=3"', the results are always correct. But when I use multiple GPUs, e.g. --gpus '"device=0,1"', '"device=0,2"', ..., or '"device=0,1,2,3"', the results are not reliable: sometimes they are right, sometimes they are not.

In this case, I believe the data of the ensemble model is not kept on the same device by default. Are there any parameters that should be set in config.pbtxt? After some tests I still have no idea, so please help me. Thank you!
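
For reference, this is the kind of instance_group setting I could put in config.pbtxt to pin a model's instances to specific GPUs (the model name, backend name, and GPU index below are only placeholders); whether something like this keeps the ensemble's tensors on one device is what I am asking about:

# Hypothetical config.pbtxt fragment for one of the composing models;
# instance_group pins its instances to GPU 0 instead of all visible GPUs.
name: "rgb2bgr"
backend: "rgb2bgr"
instance_group [
  {
    count: 1
    kind: KIND_GPU
    gpus: [ 0 ]
  }
]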

MatthieuToulemont commented 5 months ago

This issue is very apparent when running BLS models on instances with multiple GPUs, where the BLS runs some PyTorch code and interacts with TensorRT models.

indrajit96 commented 5 months ago

CC @Tabrizian

lzcchl commented 5 months ago

@MatthieuToulemont Thanks for the reply, but that doesn't show me how to solve it.

In addition, my model is an ensemble model rather than a BLS model, and my backend is written in C++ and CUDA.

In my experiment, I removed the TensorRT models to keep the simplest possible ensemble model, containing only a custom C++ backend "rgb2bgr" and a custom C++ backend "nhwc2nchw": preprocessing only, no inference model at all.
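
A rough sketch of what that minimal two-step ensemble config.pbtxt looks like (tensor names, data types, and dims here are placeholders, not my exact config):

# Hypothetical top-level ensemble config; names, data types, and dims are placeholders.
name: "preprocess_ensemble"
platform: "ensemble"
max_batch_size: 0
input [
  { name: "IMAGE" data_type: TYPE_UINT8 dims: [ -1, -1, 3 ] }
]
output [
  { name: "PREPROCESSED" data_type: TYPE_UINT8 dims: [ 3, -1, -1 ] }
]
ensemble_scheduling {
  step [
    {
      model_name: "rgb2bgr"
      model_version: -1
      input_map { key: "INPUT" value: "IMAGE" }
      output_map { key: "OUTPUT" value: "bgr_image" }
    },
    {
      model_name: "nhwc2nchw"
      model_version: -1
      input_map { key: "INPUT" value: "bgr_image" }
      output_map { key: "OUTPUT" value: "PREPROCESSED" }
    }
  ]
}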

In my custom C++ backend, the code for receiving and sending data is as follows:

// Gather the input tensors for this batch of requests.
BackendInputCollector collector(
    requests, request_count, &responses, model_state->TritonMemoryManager(),
    model_state->EnablePinnedInput() /* pinned_enabled */,
    instance_state->CudaStream() /* stream */);

// Ask the collector for input buffers located in GPU memory on `deviceId`.
std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> allowed_input_types =
    {{TRITONSERVER_MEMORY_GPU, deviceId}};

===============================================================

// Scatter the output tensor back into the responses for this batch.
BackendOutputResponder responder(
    requests, request_count, &responses, model_state->TritonMemoryManager(),
    supports_first_dim_batching, model_state->EnablePinnedOutput() /* pinned_enabled */,
    instance_state->CudaStream() /* stream */);

// output_buffer_memory_type / output_buffer_memory_type_id describe where
// output_buffer lives (memory type and device id).
responder.ProcessTensor(
    model_state->OutputTensorName().c_str(), model_state->TensorDataType(),
    tensor_shape, output_buffer, output_buffer_memory_type,
    output_buffer_memory_type_id);

// Finalize may launch async copies on the CUDA stream; synchronize if needed.
const bool need_cuda_output_sync = responder.Finalize();
if (need_cuda_output_sync) {
  cudaStreamSynchronize(instance_state->CudaStream());
}

MatthieuToulemont commented 5 months ago

Sorry, @lzcchl my message was intended for Nvidia collaborators.

On my side, the fix I found amounts to getting the device id from the args in the initialize function of the Python model. I then ensure that, for a given GPU, the tensors always stay on that GPU by using .to(f"cuda:{device_id}").

lzcchl commented 5 months ago

@MatthieuToulemont Thank you. I think the C++ and Python cases are similar: the output_buffer_memory_type_id passed to responder.ProcessTensor in C++ plays the same role as .to(f"cuda:{device_id}") in Python.

I think the output data produced on GPU0 should be sent to GPU0 for the next step of the ensemble model, but in practice the output data is also sent to GPU1/2/3, which causes data movement and incorrect results in my test. That is my question.
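
For discussion, here is a rough sketch (not my exact code) of how a C++ backend can ask Triton for the device id assigned to the instance (TRITONBACKEND_ModelInstanceDeviceId, declared in tritonbackend.h) and report the output buffer against that device:

// Rough sketch. Assumes `instance` is the TRITONBACKEND_ModelInstance* received
// at instance creation, and that `responder`, `model_state`, `tensor_shape`, and
// `output_buffer` are the same objects as in the snippets above. Error handling omitted.
int32_t device_id = 0;
TRITONBACKEND_ModelInstanceDeviceId(instance, &device_id);  // device set by instance_group

// Assumption: output_buffer was allocated on `device_id`, so report it as GPU
// memory owned by that same device.
TRITONSERVER_MemoryType output_buffer_memory_type = TRITONSERVER_MEMORY_GPU;
int64_t output_buffer_memory_type_id = device_id;

responder.ProcessTensor(
    model_state->OutputTensorName().c_str(), model_state->TensorDataType(),
    tensor_shape, output_buffer, output_buffer_memory_type,
    output_buffer_memory_type_id);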

lzcchl commented 5 months ago

After some experiments, I find it works well when I explicitly set the output to CPU memory, that is:

TRITONSERVER_MemoryType output_buffer_memory_type = TRITONSERVER_MEMORY_CPU;

responder.ProcessTensor(
    model_state->OutputTensorName().c_str(), model_state->TensorDataType(),
    tensor_shape, output_buffer, output_buffer_memory_type,
    output_buffer_memory_type_id);

With that, the next part of the ensemble model gets correct data.

But in my case I want the data to stay on one GPU when I have multiple GPUs: when the data is on GPU0 it should not be sent to GPU1/2/3, and when it is on GPU1 it should not be sent to GPU0/2/3. What do you suggest?