triton-inference-server / tensorrtllm_backend

The Triton TensorRT-LLM Backend
Apache License 2.0
670 stars 97 forks source link

[PROBLEM] Python thread does not release the GIL when using tensorrt_llm_bls in streaming mode #308

Open ZihanLiao opened 8 months ago

ZihanLiao commented 8 months ago

Problem: I added some code to make the server support early stopping. The following is my model.py:

import json
import traceback
from collections import OrderedDict
import numpy as np
import triton_python_backend_utils as pb_utils
class TritonPythonModel:
    """BLS model chaining ``preprocessing`` -> ``tensorrt_llm`` -> ``postprocessing``.

    On top of the plain BLS flow, when ``accumulate_tokens`` is enabled in
    decoupled mode this model watches the accumulated output ids for repeated
    token windows. After ``_MAX_REPETITION_COUNT`` repeated windows it sends a
    ``stop`` request to ``tensorrt_llm`` and finishes the client stream early.
    """

    # Largest token id accepted from tensorrt_llm; larger (or negative) ids in
    # the last accumulated position are replaced with _FALLBACK_TOKEN_ID.
    # NOTE(review): 31999 / 2 look like a 32000-entry Llama-style vocabulary
    # whose end-of-sequence id is 2 -- confirm against the deployed tokenizer.
    _MAX_VALID_TOKEN_ID = 31999
    _FALLBACK_TOKEN_ID = 2
    # Number of trailing tokens hashed on each repetition check.
    _REPETITION_WINDOW_SIZE = 10
    # Number of repeated windows after which generation is stopped.
    _MAX_REPETITION_COUNT = 3

    def initialize(self, args):
        """Parse the model config and build the tensor-name mapping tables.

        Args:
            args: Dict supplied by Triton; only ``args['model_config']`` (a
                JSON string) is read here.
        """
        model_config = json.loads(args['model_config'])

        params = model_config['parameters']

        accumulate_tokens_str = ''
        if 'accumulate_tokens' in params:
            accumulate_tokens_str = params['accumulate_tokens']['string_value']

        self.accumulate_tokens = accumulate_tokens_str.lower() in [
            'true', 'yes', '1', 't'
        ]

        self.decoupled = pb_utils.using_decoupled_model_transaction_policy(
            model_config)

        self.logger = pb_utils.Logger

        # Default window-hash memory used by _check_repetition when no
        # explicit memory is passed (execute() passes a per-request one).
        self.repetition_memory = OrderedDict()

        self.bls_input_tensor_names = [
            "text_input", "max_tokens", "bad_words", "stop_words", "end_id",
            "pad_id", "top_k", "top_p", "temperature", "length_penalty",
            "repetition_penalty", "min_length", "presence_penalty",
            "random_seed", "return_log_probs", "beam_width", "stream",
            "prompt_embedding_table", "prompt_vocab_size",
            "embedding_bias_words", "embedding_bias_weights"
        ]

        self.preproc_input_to_bls_input_map = {
            "QUERY": "text_input",
            "REQUEST_OUTPUT_LEN": "max_tokens",
            "BAD_WORDS_DICT": "bad_words",
            "STOP_WORDS_DICT": "stop_words",
            "EMBEDDING_BIAS_WORDS": "embedding_bias_words",
            "EMBEDDING_BIAS_WEIGHTS": "embedding_bias_weights"
        }

        self.preproc_output_to_trtllm_input_map = {
            "INPUT_ID": "input_ids",
            "REQUEST_INPUT_LEN": "input_lengths",
            "REQUEST_OUTPUT_LEN": "request_output_len",
            "BAD_WORDS_IDS": "bad_words_list",
            "STOP_WORDS_IDS": "stop_words_list",
            "EMBEDDING_BIAS": "embedding_bias",
        }

        self.trtllm_input_to_bls_input_map = {
            "end_id": "end_id",
            "pad_id": "pad_id",
            "beam_width": "beam_width",
            "runtime_top_k": "top_k",
            "runtime_top_p": "top_p",
            "len_penalty": "length_penalty",
            "repetition_penalty": "repetition_penalty",
            "min_length": "min_length",
            "presence_penalty": "presence_penalty",
            "random_seed": "random_seed",
            "return_log_probs": "return_log_probs",
            "streaming": "stream",
            "prompt_embedding_table": "prompt_embedding_table",
            "prompt_vocab_size": "prompt_vocab_size",
        }

        self.trtllm_output_to_postproc_input_map = {
            "output_ids": "TOKENS_BATCH",
            "sequence_length": "SEQUENCE_LENGTH",
            "cum_log_probs": "CUM_LOG_PROBS",
            "output_log_probs": "OUTPUT_LOG_PROBS",
        }

        self.postproc_output_to_bls_output_map = {
            "OUTPUT": "text_output",
            "OUT_CUM_LOG_PROBS": "cum_log_probs",
            "OUT_OUTPUT_LOG_PROBS": "output_log_probs",
        }

    def _get_bls_input_tensors_map(self, request):
        """Return ``{name: tensor}`` for every BLS input present in *request*."""
        bls_input_tensors_map = {}
        for input_tensor_name in self.bls_input_tensor_names:
            tensor = pb_utils.get_input_tensor_by_name(request,
                                                       input_tensor_name)
            if tensor is not None:
                bls_input_tensors_map[input_tensor_name] = tensor

        return bls_input_tensors_map

    def _get_preproc_input_tensors(self, bls_input_tensors_map):
        """Build the preprocessing-model inputs from the BLS inputs.

        Each BLS input listed in ``preproc_input_to_bls_input_map`` that is
        present in *bls_input_tensors_map* is re-wrapped under the name the
        preprocessing model expects. Missing inputs are skipped.
        """
        return [
            pb_utils.Tensor(preproc_name,
                            bls_input_tensors_map[bls_name].as_numpy())
            for preproc_name, bls_name in
            self.preproc_input_to_bls_input_map.items()
            if bls_name in bls_input_tensors_map
        ]

    def _get_trtllm_input_tensors(self, bls_input_tensors_map,
                                  preproc_output_tensors):
        """Build the tensorrt_llm inputs from preprocessor outputs and BLS inputs.

        Args:
            bls_input_tensors_map: ``{bls_name: pb_utils.Tensor}`` from the
                incoming request.
            preproc_output_tensors: Output tensors of the preprocessing model.

        Returns:
            List of ``pb_utils.Tensor`` renamed for the tensorrt_llm model.
        """
        trtllm_input_tensors = []

        # Set input tensors from preprocessor outputs
        for preproc_output_tensor in preproc_output_tensors:

            trtllm_tensor_name = self.preproc_output_to_trtllm_input_map[
                preproc_output_tensor.name()]
            trtllm_input_tensors.append(
                pb_utils.Tensor(trtllm_tensor_name,
                                preproc_output_tensor.as_numpy()))

            # Full input arrays are debug detail; keep them out of info logs.
            self.logger.log_verbose(f"Interface name: {trtllm_tensor_name}, np.array {preproc_output_tensor.as_numpy()}")

        # Set input tensors from bls inputs
        for trtllm_name, bls_name in self.trtllm_input_to_bls_input_map.items(
        ):

            if bls_name in bls_input_tensors_map:
                tensor = bls_input_tensors_map[bls_name]
                # Change the name to what tensorrt_llm expects
                trtllm_input_tensors.append(
                    pb_utils.Tensor(trtllm_name, tensor.as_numpy()))

        return trtllm_input_tensors

    def _get_postproc_input_tensors(self, tokens, trtllm_output_tensors):
        """Build the postprocessing inputs from one tensorrt_llm response.

        Args:
            tokens: Output ids accumulated so far (or ``None``). Only used
                and updated when ``accumulate_tokens`` is enabled in
                decoupled mode.
            trtllm_output_tensors: Output tensors of one tensorrt_llm
                response.

        Returns:
            ``(tokens, postproc_input_tensors)``. ``tokens`` is the updated
            accumulation, returned unchanged when accumulation is off.

        Raises:
            pb_utils.TritonModelException: In accumulation mode, if
                ``output_ids`` is not 3-D with batch size 1 and beam width 1.
        """
        postproc_input_tensors = []

        for trtllm_output_tensor in trtllm_output_tensors:

            # If in decoupled mode, option to append new tokens to existing tokens before calling postprocessor
            # This might be needed for some tokenizers
            # Note that in that case, the client must overwrite previously received output text
            if (self.accumulate_tokens and self.decoupled
                    and trtllm_output_tensor.name() == "output_ids"):

                new_tokens = trtllm_output_tensor.as_numpy()
                self.logger.log_verbose(f"New tokens: {new_tokens}")
                if new_tokens.ndim != 3:
                    raise pb_utils.TritonModelException(
                        "Expected output_ids tensor to have 3 dims.")
                if new_tokens.shape[0] != 1:
                    raise pb_utils.TritonModelException(
                        "Expected output_ids tensor to have batch size of 1")
                if new_tokens.shape[1] != 1:
                    raise pb_utils.TritonModelException(
                        "Accumulation of tokens is only implemented for beam width = 1"
                    )

                tokens = new_tokens if (tokens is None) else np.concatenate(
                    (tokens, new_tokens), axis=2)

                # output ids
                postproc_output_ids_name = self.trtllm_output_to_postproc_input_map[
                    "output_ids"]
                postproc_input_tensors.append(
                    pb_utils.Tensor(postproc_output_ids_name, tokens))

                # sequence length
                np_seq_len_tensor = np.array([[tokens.shape[2]]],
                                             dtype=np.int32)
                postproc_seq_len_name = self.trtllm_output_to_postproc_input_map[
                    "sequence_length"]
                postproc_input_tensors.append(
                    pb_utils.Tensor(postproc_seq_len_name, np_seq_len_tensor))

        # Set input tensors from trtllm outputs
        for trtllm_output_tensor in trtllm_output_tensors:

            # output_ids and sequence_length were handled earlier
            if (self.accumulate_tokens and self.decoupled
                    and (trtllm_output_tensor.name() == "output_ids"
                         or trtllm_output_tensor.name() == "sequence_length")):
                continue

            postproc_tensor_name = self.trtllm_output_to_postproc_input_map[
                trtllm_output_tensor.name()]

            postproc_input_tensors.append(
                pb_utils.Tensor(postproc_tensor_name,
                                trtllm_output_tensor.as_numpy()))

        return tokens, postproc_input_tensors

    def _get_bls_output_tensors(self, postproc_output_tensors):
        """Rename postprocessing outputs to the BLS output names."""
        bls_output_tensors = []

        # Set BLS output tensors from postprocessor outputs
        for postproc_output_tensor in postproc_output_tensors:

            bls_tensor_name = self.postproc_output_to_bls_output_map[
                postproc_output_tensor.name()]
            bls_output_tensors.append(
                pb_utils.Tensor(bls_tensor_name,
                                postproc_output_tensor.as_numpy()))

        return bls_output_tensors

    def _check_repetition(self, input_ids, window_size=10, memory=None,
                          max_memory_size=100):
        """Report whether the trailing token window has been seen before.

        The last *window_size* ids of ``input_ids[0, 0, :]`` are hashed. An
        unseen hash is recorded in *memory* and ``False`` is returned. A hash
        already present means a repetition, and ``True`` is returned.
        Sequences of length ``<= window_size`` never count as repetitions.

        Args:
            input_ids: np.array indexed as ``[0, 0, position]`` (accumulated
                output ids).
            window_size: Number of trailing ids hashed; must be >= 1.
            memory: OrderedDict of seen window hashes. Defaults to
                ``self.repetition_memory``.
            max_memory_size: When *memory* holds more than this many entries,
                its oldest entry is evicted before checking.

        Returns:
            bool: ``True`` if the trailing window repeats a recorded one.

        Raises:
            ValueError: If *window_size* < 1 (a zero slice would hash the
                whole sequence instead of a window).
        """
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        if memory is None:
            memory = self.repetition_memory

        if input_ids.shape[2] <= window_size:
            return False

        if len(memory) > max_memory_size:
            memory.popitem(last=False)

        # Hash the window once; it is used for the lookup, insert and log.
        window_hash = hash(input_ids[0, 0, -window_size:].data.tobytes())
        if window_hash not in memory:
            memory[window_hash] = True
            return False

        self.logger.log_info(
            f"Found repetition, current window hash {window_hash}"
        )
        return True

    def _send_trtllm_stop(self):
        """Send a ``stop=True`` request to the tensorrt_llm model.

        The result of the stop request is discarded, as in the original flow.
        """
        # NOTE(review): this request carries no request id, so how
        # tensorrt_llm associates it with the in-flight generation it should
        # cancel is not visible here -- confirm against the backend's
        # cancellation contract.
        stop_trtllm_tensor = [pb_utils.Tensor("stop", np.array([[True]], dtype=bool))]
        trtllm_request = pb_utils.InferenceRequest(
            model_name="tensorrt_llm",
            inputs=stop_trtllm_tensor,
            requested_output_names=["stop"],
        )
        trtllm_request.exec()

    def execute(self, requests):
        """Run preprocessing -> tensorrt_llm -> postprocessing for each request.

        Args:
            requests: List of ``pb_utils.InferenceRequest``.

        Returns:
            ``None`` in decoupled mode, where responses are streamed through
            each request's response sender. Otherwise, a list with one
            ``pb_utils.InferenceResponse`` per request.
        """
        responses = []
        bls_response_sender = None

        for request in requests:

            # Get the response sender for the BLS
            if self.decoupled:
                bls_response_sender = request.get_response_sender()

            try:
                # Get the bls input tensors
                bls_input_tensors_map = self._get_bls_input_tensors_map(
                    request)

                # Check the batch dimension
                for tensor in bls_input_tensors_map.values():
                    if tensor.as_numpy().shape[0] != 1:
                        err_str = "Inflight batching backend expects requests with batch size of 1."
                        self.logger.log_error(err_str)
                        raise pb_utils.TritonModelException(err_str)

                # Create the preprocessor input tensors
                preproc_input_tensors = self._get_preproc_input_tensors(
                    bls_input_tensors_map)

                preproc_request = pb_utils.InferenceRequest(
                    model_name="preprocessing",
                    inputs=preproc_input_tensors,
                    requested_output_names=list(
                        self.preproc_output_to_trtllm_input_map.keys()))

                # Execute preprocessor
                preproc_response = preproc_request.exec()

                if preproc_response.has_error():
                    raise pb_utils.TritonModelException(
                        preproc_response.error().message())

                # Create the trtllm input tensors
                trtllm_input_tensors = self._get_trtllm_input_tensors(
                    bls_input_tensors_map, preproc_response.output_tensors())

                trtllm_request = pb_utils.InferenceRequest(
                    model_name="tensorrt_llm",
                    inputs=trtllm_input_tensors,
                    requested_output_names=list(
                        self.trtllm_output_to_postproc_input_map.keys()))

                # Execute trtllm
                trtllm_responses = trtllm_request.exec(
                    decoupled=self.decoupled)

                if not self.decoupled:
                    trtllm_responses = [trtllm_responses]

                tokens = None
                # Per-request memory, so windows seen in one request of the
                # batch are never reported as repetitions in another.
                repetition_memory = OrderedDict()
                repetition_count = 0

                # Loop over the trtllm responses
                for trtllm_response in trtllm_responses:

                    if trtllm_response.has_error():
                        raise pb_utils.TritonModelException(
                            trtllm_response.error().message())

                    trtllm_output_tensors = trtllm_response.output_tensors()

                    tokens, postproc_input_tensors = self._get_postproc_input_tensors(
                        tokens, trtllm_output_tensors)

                    # tokens is only populated when accumulate_tokens is on in
                    # decoupled mode. Without an accumulated sequence there is
                    # nothing to sanitize or check, and indexing None would
                    # raise TypeError.
                    repeated = False
                    if tokens is not None:
                        last_token = tokens[0, 0, -1]
                        if (last_token > self._MAX_VALID_TOKEN_ID
                                or last_token < 0):
                            # NOTE(review): postproc_input_tensors was built
                            # from `tokens` above. Whether the postprocessor
                            # sees this in-place fix depends on pb_utils.Tensor
                            # aliasing the array rather than copying it --
                            # confirm.
                            tokens[0, 0, -1] = self._FALLBACK_TOKEN_ID

                        repeated = self._check_repetition(
                            tokens,
                            window_size=self._REPETITION_WINDOW_SIZE,
                            memory=repetition_memory)

                    postproc_request = pb_utils.InferenceRequest(
                        model_name="postprocessing",
                        inputs=postproc_input_tensors,
                        requested_output_names=list(
                            self.postproc_output_to_bls_output_map.keys()))

                    # Execute postprocessor
                    postproc_response = postproc_request.exec()

                    if postproc_response.has_error():
                        raise pb_utils.TritonModelException(
                            postproc_response.error().message())

                    if repeated:
                        repetition_count += 1
                        if repetition_count == self._MAX_REPETITION_COUNT:
                            self.logger.log_info(
                                f"Found repetition, current tokens are {tokens}"
                            )
                            self._send_trtllm_stop()
                            # Stop streaming this request; the FINAL flag is
                            # sent below. Breaking (rather than returning
                            # from execute) keeps the remaining requests of
                            # the batch from being silently dropped.
                            break

                    # Create the BLS response
                    bls_output_tensors = self._get_bls_output_tensors(
                        postproc_response.output_tensors())

                    bls_response = pb_utils.InferenceResponse(
                        output_tensors=bls_output_tensors)

                    if self.decoupled:
                        bls_response_sender.send(bls_response)
                        self.logger.log_verbose("Sent a message to client")
                    else:
                        responses.append(bls_response)

                # All responses have been sent, set final flag
                if self.decoupled:
                    bls_response_sender.send(
                        flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)

            except Exception:

                error_message = traceback.format_exc()
                self.logger.log_error(error_message)
                # If encountering an error, send a response with err msg
                error_response = pb_utils.InferenceResponse(
                    output_tensors=[],
                    error=pb_utils.TritonError(error_message))

                if self.decoupled:
                    bls_response_sender.send(error_response)
                    bls_response_sender.send(
                        flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
                else:
                    responses.append(error_response)

        self.repetition_memory.clear()
        if self.decoupled:
            return None
        else:
            assert len(responses) == len(requests)
            return responses

Error: Fatal Python error: PyEval_SaveThread: the function must be called with the GIL held, but the GIL is released (the current Python thread state is NULL)

Version: v0.6.1

There might be a place in execute() that doesn't release the Python thread.

pcastonguay commented 8 months ago

Hi @ZihanLiao, I haven't been able to reproduce the issue you are reporting. I used a slightly simplified version for debugging where I send a cancellation response after 10 tokens:

                #Loop over the trtllm responses
                count = 0
                for trtllm_response in trtllm_responses:

                    count = count + 1
                   ...

                    #Send a cancellation request
                    if count == 10:
                        stop_trtllm_tensor = [pb_utils.Tensor("stop", np.array([[True]], dtype=bool))]
                        request_output_names = ["stop"]
                        trtllm_request = pb_utils.InferenceRequest(
                            model_name="tensorrt_llm",
                            inputs=stop_trtllm_tensor,
                            requested_output_names=request_output_names,
                        )
                        _ = trtllm_request.exec()
                        bls_response_sender.send(
                            flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
                        return None

and don't run into any errors. Here's the client output when I print the text_output tensor.

root@aiap-dt1:/app# python3 inflight_batcher_llm/client/end_to_end_grpc_client.py -p "This is a test" --model-name "tensorrt_llm_bls" --streaming
[b' of']
[b' the']
[b' power']
[b' of']
[b' the']
[b' Internet']
[b'.']
[b' It']
[b"'s"]

What version of Triton are you using? Could you try with 23.12 or 24.01?

ZihanLiao commented 8 months ago

Hi @ZihanLiao, I haven't been able to reproduce the issue you are reporting. I used a slightly simplified version for debugging where I send a cancellation response after 10 tokens:

                #Loop over the trtllm responses
                count = 0
                for trtllm_response in trtllm_responses:

                    count = count + 1
                   ...

                    #Send a cancellation request
                    if count == 10:
                        stop_trtllm_tensor = [pb_utils.Tensor("stop", np.array([[True]], dtype=bool))]
                        request_output_names = ["stop"]
                        trtllm_request = pb_utils.InferenceRequest(
                            model_name="tensorrt_llm",
                            inputs=stop_trtllm_tensor,
                            requested_output_names=request_output_names,
                        )
                        _ = trtllm_request.exec()
                        bls_response_sender.send(
                            flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
                        return None

and don't run into any errors. Here's the client output when I print the text_output tensor.

root@aiap-dt1:/app# python3 inflight_batcher_llm/client/end_to_end_grpc_client.py -p "This is a test" --model-name "tensorrt_llm_bls" --streaming
[b' of']
[b' the']
[b' power']
[b' of']
[b' the']
[b' Internet']
[b'.']
[b' It']
[b"'s"]

What version of Triton are you using? Could you try with 23.12 or 24.01?

Thanks for your reply! Indeed, the problem is hard to reproduce and might be related to this, though I'm not sure. This GIL error occurred from time to time when I sent requests to the server sequentially. I'm using Triton version 23.10. I will try the latest version. [image]

pcastonguay commented 8 months ago

Ok, I know that there were a few issues with the Triton python backend code fixed in Triton 23.12. See https://github.com/triton-inference-server/python_backend/commit/8b0fa4cc5daa4b1891cdc5b0b42079dbe2a60eae and https://github.com/triton-inference-server/python_backend/commit/c5f304decda609ab21a004c525436e58dd527190

Can you try with Triton 23.12 or Triton 24.01? If you can still reproduce after upgrading the Triton version, I can spend more time on this and work with the Triton team to root cause.

Thank you.