triton-inference-server / tensorrtllm_backend

The Triton TensorRT-LLM Backend

Unable to launch triton server for 8 gpu mistral model #275

Open nikhilshandilya opened 8 months ago

nikhilshandilya commented 8 months ago

I'm trying to run inference with the Mistral 7B model on Triton, but I run into issues when I try to launch the server from my image. I suspect it's an issue with some MPI and Triton shared libraries, because I can load the model in the same container but not when I launch the Triton server.

Context, to be specific:

I first used the NGC container image with tag 23.12-trtllm-python-py3, then built TensorRT-LLM from the repo that is initialized as the submodule:

(cd tensorrt_llm &&
    bash docker/common/install_cmake.sh &&
    export PATH=/usr/local/cmake/bin:$PATH &&
    python3 ./scripts/build_wheel.py --trt_root="/usr/local/tensorrt" &&
    pip3 install ./build/tensorrt_llm*.whl)
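
As a quick sanity check that the wheel installed cleanly (my own ad-hoc check, not part of the build steps above), the bindings should import without errors:

python3 -c "import tensorrt_llm; print(tensorrt_llm.__version__)"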

Then I serialized the Mistral model for 1 GPU and for 8 GPUs with the provided build script. The 1-GPU invocation was:

python build.py --model_dir <> \
                --dtype float16 \
                --remove_input_padding \
                --use_gpt_attention_plugin float16 \
                --enable_context_fmha \
                --use_gemm_plugin float16 \
                --output_dir <> \
                --paged_kv_cache \
                --world_size 1 \
                --tp_size 1 \
                --max_beam_width 1 \
                --max_batch_size 1 \
                --max_input_len 256 \
                --max_output_len 25 \
                --parallel_build
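
The 8-GPU engine was built with the same flags except for the parallelism settings, i.e. something like:

python build.py --model_dir <> \
                --dtype float16 \
                --remove_input_padding \
                --use_gpt_attention_plugin float16 \
                --enable_context_fmha \
                --use_gemm_plugin float16 \
                --output_dir <> \
                --paged_kv_cache \
                --world_size 8 \
                --tp_size 8 \
                --max_beam_width 1 \
                --max_batch_size 1 \
                --max_input_len 256 \
                --max_output_len 25 \
                --parallel_build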

I then wrote Triton model.py code that loads the engine with the TensorRT-LLM runtime Python bindings. Something like:

# imports as I have them: ModelConfig / GenerationSession come from the
# TensorRT-LLM runtime bindings, Mapping from the top-level package
import os

import torch
from tensorrt_llm import Mapping
from tensorrt_llm.runtime import GenerationSession, ModelConfig

model_config = ModelConfig(
    vocab_size=vocab_size,
    num_layers=num_layers,
    num_heads=num_heads,
    num_kv_heads=num_kv_heads,
    hidden_size=hidden_size,
    paged_kv_cache=paged_kv_cache,
    tokens_per_block=tokens_per_block,
    gpt_attention_plugin=use_gpt_attention_plugin,
    remove_input_padding=remove_input_padding,
    use_custom_all_reduce=use_custom_all_reduce,
    dtype=dtype,
    quant_mode=quant_mode,
)

runtime_mapping = Mapping(
    world_size, runtime_rank, tp_size=tp_size, pp_size=pp_size
)
torch.cuda.set_device(self.rank % runtime_mapping.gpus_per_node)
engine_name = self._get_engine_name(
    model_name, dtype, tp_size, pp_size, self.rank
)
serialize_path = os.path.join(model_dir, engine_name)
# read the serialized engine for this rank and create a generation session
with open(serialize_path, "rb") as f:
    engine_buffer = f.read()
self.decoder = GenerationSession(model_config, engine_buffer, runtime_mapping)

I can load the model and run inference on my host, so I know the engine is functional; however, I am not able to load the 8-GPU model in the Triton server. I get the following error:

Error:

It looks like orte_init failed for some reason; your parallel process is
likely to abort.  There are many reasons that a parallel process can
fail during orte_init; some of which are due to configuration or
environment problems.  This failure appears to be an internal failure;
here's some additional information (which may only be relevant to an
Open MPI developer):

  getting local rank failed
  --> Returned value No permission (-17) instead of ORTE_SUCCESS

An error occurred in MPI_Init_thread
*** on a NULL communicator
*** MPI_ERRORS_ARE_FATAL (processes in this communicator will now abort,
***    and potentially your MPI job)

I suspect there is some issue with MPI and Triton, because I can load the single-GPU model and run inference on Triton using

tritonserver --model-repository <> --exit-on-error true --grpc-port 5001 --http-port 5000 --strict-readiness True --load-model <> --model-control-mode explicit --exit-timeout-secs=3

but I cannot load the same model with

python /code/saint/tensorrtllm_backend/scripts/launch_triton_server.py --world_size=1 --model_repo=/code/tensorrt_llm/triton/ --grpc_port=5001 --http_port=5000

which is very strange.

This is not expected behavior, since the Python script should just be invoking the tritonserver command under the hood, right?
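
For context, my mental model of what launch_triton_server.py does (an assumption on my part based on the script's arguments and the MPI errors, not a copy of the actual script) is that it builds a single mpirun command that starts one tritonserver process per rank, something like:

import subprocess

def launch(world_size, model_repo, grpc_port=5001, http_port=5000):
    # one MPMD "app context" per rank, separated by ":", each running tritonserver
    cmd = ["mpirun", "--allow-run-as-root"]
    for rank in range(world_size):
        if rank != 0:
            cmd += [":"]
        cmd += [
            "-n", "1", "tritonserver",
            f"--model-repository={model_repo}",
            f"--grpc-port={grpc_port}",
            f"--http-port={http_port}",
        ]
    return subprocess.Popen(cmd)

launch(world_size=1, model_repo="/code/tensorrt_llm/triton/")

So with --world_size=1 I would expect essentially a single tritonserver under mpirun, which makes the difference in behavior even more confusing.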

nikhilshandilya commented 8 months ago

I resolved the above issue by setting the instance kind to KIND_CPU instead of KIND_GPU and moving the decoder-loading logic into the main Triton model class's initialize() method. I was previously using a separate class to load the engine inside the Triton class, which I guess was causing issues?
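
For reference, the instance-kind change was in the model's config.pbtxt, i.e. something like:

instance_group [
  {
    count: 1
    kind: KIND_CPU
  }
]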

But now I run into a sporadic issue where the server complains that the gRPC port is already in use:

E0109 19:34:35.165873190   11729 chttp2_server.cc:1080]      UNKNOWN:No address added out of total 1 resolved for '0.0.0.0:5001' {created_time:"2024-01-09T19:34:35.165818739+00:00", children:[UNKNOWN:Failed to add any wildcard listeners {created_time:"2024-01-09T19:34:35.165810938+00:00", children:[UNKNOWN:Unable to configure socket {created_time:"2024-01-09T19:34:35.165779878+00:00", fd:203, children:[UNKNOWN:Address already in use {syscall:"bind", os_error:"Address already in use", errno:98, created_time:"2024-01-09T19:34:35.165753947+00:00"}]}, UNKNOWN:Unable to configure socket {fd:203, created_time:"2024-01-09T19:34:35.165806998+00:00", children:[UNKNOWN:Address already in use {created_time:"2024-01-09T19:34:35.165803368+00:00", errno:98, os_error:"Address already in use", syscall:"bind"}]}]}]}
E0109 19:34:35.166155 11729 main.cc:245] failed to start GRPC service: Unavailable - Socket '0.0.0.0:5001' already in use 
--------------------------------------------------------------------------
Primary job  terminated normally, but 1 process returned
a non-zero exit code. Per user-direction, the job has been aborted.
--------------------------------------------------------------------------
I0109 19:34:37.349191 11876 pb_stub.cc:1963]  Non-graceful termination detected. 
I0109 19:34:37.371437 11730 server.cc:307] Waiting for in-flight requests to complete.
I0109 19:34:37.371464 11730 server.cc:323] Timeout 30: Found 0 model versions that have in-flight inferences
I0109 19:34:37.371609 11730 server.cc:338] All models are stopped, unloading models
I0109 19:34:37.371615 11730 server.cc:345] Timeout 30: Found 1 live models and 0 in-flight non-inference requests
Signal (15) received.

I have tried changing the port and made sure it's not being used by another process (e.g. with a plain bind test like the sketch below). It seems like the p0 process with rank 0 is able to host the server, but the other rank processes are not.
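
A minimal version of that check, just trying to bind the port right before launching (nothing Triton-specific):

import socket

def port_is_free(port, host="0.0.0.0"):
    # if bind succeeds, nothing else is currently listening on the port
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind((host, port))
            return True
        except OSError:
            return False

print(port_is_free(5001))  # reports True right before launch, yet the launch still fails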

Can someone help explain how to bypass this? Related question: https://github.com/triton-inference-server/tensorrtllm_backend/issues/243