triton-inference-server / tensorrtllm_backend

The Triton TensorRT-LLM Backend

Failed to call a service after launching the server. #228

Open · THU-mjx opened this issue 10 months ago

THU-mjx commented 10 months ago
curl -X POST localhost:8000/v2/models/ensemble/generate -d '{"text_input": "What is machine learning?", "max_tokens": 20, "bad_words": "", "stop_words": ""}'
{"error":"in ensemble 'ensemble', [request id: <id_unknown>] input 'input_ids' batch size does not match other inputs for 'tensorrt_llm'"}

We are trying to implement end-to-end deployment of a recently released model. The engine was built successfully with the TensorRT-LLM 0.5.0 release, and the server also launched successfully, but calling the service with curl produces the error above.

I just want to know where the error "in ensemble 'ensemble' ..." comes from, and where the localhost:8000/v2/models/ensemble/generate endpoint comes from.

kaiyux commented 10 months ago

An ensemble is a Triton Server concept for linking several models together into a pipeline. From the log, the input_ids shapes of ensemble and tensorrt_llm do not match. You could check their configs.
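
For example, a quick script can dump max_batch_size and the input dims of each step for a side-by-side comparison. This is only a sketch: it assumes the tritonclient Python package is installed (it bundles the Triton model_config protobuf), and model_repo/ is a placeholder for your model repository path:

from google.protobuf import text_format
from tritonclient.grpc import model_config_pb2

def load_config(path):
    # Parse a config.pbtxt file into a ModelConfig protobuf message.
    with open(path) as f:
        return text_format.Parse(f.read(), model_config_pb2.ModelConfig())

# "model_repo" is a placeholder; point it at your Triton model repository.
for name in ("ensemble", "preprocessing", "tensorrt_llm"):
    config = load_config(f"model_repo/{name}/config.pbtxt")
    print(name, "max_batch_size =", config.max_batch_size)
    for inp in config.input:
        print(f"  input {inp.name}: dims = {list(inp.dims)}")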

THU-mjx commented 10 months ago

Here are the configs. preprocessing/config.pbtxt:

name: "preprocessing"
backend: "python"
max_batch_size: 1
input [
    {
        name: "QUERY"
        data_type: TYPE_STRING
        dims: [ -1 ]
    },
    {
        name: "BAD_WORDS_DICT"
        data_type: TYPE_STRING
        dims: [ -1 ]
    },
    {
        name: "STOP_WORDS_DICT"
        data_type: TYPE_STRING
        dims: [ -1 ]
    },
    {
        name: "REQUEST_OUTPUT_LEN"
        data_type: TYPE_UINT32
        dims: [ -1 ]
    }
]
output [
    {
        name: "INPUT_ID"
        data_type: TYPE_INT32
        dims: [ -1 ]
    },
    {
        name: "LFA_IDX"
        data_type: TYPE_INT32
        dims: [ -1 ]
    },
    {
        name: "REQUEST_INPUT_LEN"
        data_type: TYPE_INT32
        dims: [ 1 ]
    },
    {
        name: "BAD_WORDS_IDS"
        data_type: TYPE_INT32
        dims: [ 2, -1 ]
    },
    {
        name: "STOP_WORDS_IDS"
        data_type: TYPE_INT32
        dims: [ 2, -1 ]
    },
    {
        name: "REQUEST_OUTPUT_LEN"
        data_type: TYPE_UINT32
        dims: [ -1 ]
    }
]

parameters {
  key: "tokenizer_dir"
  value: {
    string_value: "/temp_data/LLM_test/yuan_model/hf_tokenizer"
  }
}

parameters {
  key: "tokenizer_type"
  value: {
    string_value: "llama"
  }
}

instance_group [
    {
        count: 1
        kind: KIND_CPU
    }
]

postprocessing/config.pbtxt:

name: "postprocessing"
backend: "python"
max_batch_size: 1
input [
  {
    name: "TOKENS_BATCH"
    data_type: TYPE_INT32
    dims: [ -1, -1 ]
  }
]
output [
  {
    name: "OUTPUT"
    data_type: TYPE_STRING
    dims: [ -1, -1 ]
  }
]

parameters {
  key: "tokenizer_dir"
  value: {
    string_value: "/temp_data/LLM_test/yuan_model/hf_tokenizer"
  }
}

parameters {
  key: "tokenizer_type"
  value: {
    string_value: "llama"
  }
}

instance_group [
    {
        count: 1
        kind: KIND_CPU
    }
]

tensorrt_llm/config.pbtxt:

backend: "tensorrtllm"
max_batch_size: 1

model_transaction_policy {
  decoupled: true
}

input [
  {
    name: "input_ids"
    data_type: TYPE_INT32
    dims: [ -1 ]
  },
  {
    name: "lfa_idx"
    data_type: TYPE_INT32
    dims: [ -1 ]
  },
  {
    name: "input_lengths"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
  },
  {
    name: "request_output_len"
    data_type: TYPE_UINT32
    dims: [ 1 ]
  },
  {
    name: "end_id"
    data_type: TYPE_UINT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "pad_id"
    data_type: TYPE_UINT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "beam_width"
    data_type: TYPE_UINT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "temperature"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "runtime_top_k"
    data_type: TYPE_UINT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "runtime_top_p"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "len_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "repetition_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "min_length"
    data_type: TYPE_UINT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "presence_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "random_seed"
    data_type: TYPE_UINT64
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "stop"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    optional: true
  },
  {
    name: "streaming"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    optional: true
  }
]
output [
  {
    name: "output_ids"
    data_type: TYPE_INT32
    dims: [ -1, -1 ]
  }
]
instance_group [
  {
    count: 1
    kind : KIND_CPU
  }
]
parameters: {
  key: "max_beam_width"
  value: {
    string_value: "1"
  }
}
parameters: {
  key: "FORCE_CPU_ONLY_INPUT_TENSORS"
  value: {
    string_value: "no"
  }
}
parameters: {
  key: "gpt_model_type"
  value: {
    string_value: "V1"
  }
}
parameters: {
  key: "gpt_model_path"
  value: {
    string_value: "/temp_data/LLM_test/yuan_model/trt_engines/fp16/yuan_kv_cache"
  }
}
parameters: {
  key: "max_tokens_in_paged_kv_cache"
  value: {
    string_value: "${max_tokens_in_paged_kv_cache}"
  }
}
parameters: {
  key: "batch_scheduler_policy"
  value: {
    string_value: "${batch_scheduler_policy}"
  }
}
parameters: {
  key: "kv_cache_free_gpu_mem_fraction"
  value: {
    string_value: "${kv_cache_free_gpu_mem_fraction}"
  }
}
parameters: {
  key: "max_num_sequences"
  value: {
    string_value: "${max_num_sequences}"
  }
}
parameters: {
  key: "enable_trt_overlap"
  value: {
    string_value: "${enable_trt_overlap}"
  }
}
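
(Note: the last five parameters above are still literal ${...} placeholders. These are normally substituted before launching the server, e.g. with the repository's tools/fill_template.py script; the command below assumes the standard repo layout, and the values shown are purely illustrative:

python3 tools/fill_template.py -i tensorrt_llm/config.pbtxt "max_tokens_in_paged_kv_cache:2560,batch_scheduler_policy:max_utilization,kv_cache_free_gpu_mem_fraction:0.9,max_num_sequences:1,enable_trt_overlap:false"

If the ${...} strings are left in place, the backend may not parse these parameters as intended, so this is worth double-checking.)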

ensemble/config.pbtxt:

name: "ensemble"
platform: "ensemble"
max_batch_size: 1
input [
  {
    name: "text_input"
    data_type: TYPE_STRING
    dims: [ -1 ]
  },
  {
    name: "max_tokens"
    data_type: TYPE_UINT32
    dims: [ -1 ]
  },
  {
   name: "bad_words"
   data_type: TYPE_STRING
   dims: [ -1 ]
  },
  {
   name: "stop_words"
   data_type: TYPE_STRING
   dims: [ -1 ]
  },
  {
    name: "end_id"
    data_type: TYPE_UINT32
    dims: [ 1 ]
    optional: true
  },
  {
    name: "pad_id"
    data_type: TYPE_UINT32
    dims: [ 1 ]
    optional: true
  },
  {
    name: "top_k"
    data_type: TYPE_UINT32
    dims: [ 1 ]
    optional: true
  },
  {
    name: "top_p"
    data_type: TYPE_FP32
    dims: [ 1 ]
    optional: true
  },
  {
    name: "temperature"
    data_type: TYPE_FP32
    dims: [ 1 ]
    optional: true
  },
  {
    name: "length_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    optional: true
  },
  {
    name: "repetition_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    optional: true
  },
  {
    name: "min_length"
    data_type: TYPE_UINT32
    dims: [ 1 ]
    optional: true
  },
  {
    name: "presence_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    optional: true
  },
  {
    name: "random_seed"
    data_type: TYPE_UINT64
    dims: [ 1 ]
    optional: true
  },
  {
    name: "beam_width"
    data_type: TYPE_UINT32
    dims: [ 1 ]
    optional: true
  },
  {
    name: "stream"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    optional: true
  }
]
output [
  {
    name: "text_output"
    data_type: TYPE_STRING
    dims: [ -1, -1 ]
  }
]
ensemble_scheduling {
  step [
    {
      model_name: "preprocessing"
      model_version: -1
      input_map {
        key: "QUERY"
        value: "text_input"
      }
      input_map {
        key: "REQUEST_OUTPUT_LEN"
        value: "max_tokens"
      }
      input_map {
        key: "BAD_WORDS_DICT"
        value: "bad_words"
      }
      input_map {
        key: "STOP_WORDS_DICT"
        value: "stop_words"
      }
      output_map {
        key: "REQUEST_INPUT_LEN"
        value: "_REQUEST_INPUT_LEN"
      }
      output_map {
        key: "INPUT_ID"
        value: "_INPUT_ID"
      }
      output_map {
        key: "LFA_IDX"
        value: "_LFA_IDX"
      }
      output_map {
        key: "REQUEST_OUTPUT_LEN"
        value: "_REQUEST_OUTPUT_LEN"
      }
    },
    {
      model_name: "tensorrt_llm"
      model_version: -1
      input_map {
        key: "input_ids"
        value: "_INPUT_ID"
      }
      input_map {
        key: "lfa_idx"
        value: "_LFA_IDX"
      }
      input_map {
        key: "input_lengths"
        value: "_REQUEST_INPUT_LEN"
      }
      input_map {
        key: "request_output_len"
        value: "_REQUEST_OUTPUT_LEN"
      }
      input_map {
          key: "end_id"
          value: "end_id"
      }
      input_map {
          key: "pad_id"
          value: "pad_id"
      }
      input_map {
          key: "runtime_top_k"
          value: "top_k"
      }
      input_map {
          key: "runtime_top_p"
          value: "top_p"
      }
      input_map {
          key: "temperature"
          value: "temperature"
      }
      input_map {
          key: "len_penalty"
          value: "length_penalty"
      }
      input_map {
          key: "repetition_penalty"
          value: "repetition_penalty"
      }
      input_map {
          key: "min_length"
          value: "min_length"
      }
      input_map {
          key: "presence_penalty"
          value: "presence_penalty"
      }
      input_map {
          key: "random_seed"
          value: "random_seed"
      }
      input_map {
          key: "beam_width"
          value: "beam_width"
      }
      input_map {
          key: "streaming"
          value: "stream"
      }
      output_map {
        key: "output_ids"
        value: "_TOKENS_BATCH"
      }
    },
    {
      model_name: "postprocessing"
      model_version: -1
      input_map {
        key: "TOKENS_BATCH"
        value: "_TOKENS_BATCH"
      }
      output_map {
        key: "OUTPUT"
        value: "text_output"
      }
    }
  ]
}
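
As a quick cross-check against the files above, the config that Triton actually loaded for each step can be fetched from the running server via the standard model-configuration endpoint, e.g.:

curl localhost:8000/v2/models/ensemble/config
curl localhost:8000/v2/models/tensorrt_llm/config
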
THU-mjx commented 10 months ago

And here is preprocessing/model.py:

import csv
import json
from typing import List

import numpy as np
import torch
import triton_python_backend_utils as pb_utils
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer

class TritonPythonModel:
    """Your Python model must use the same class name. Every Python model
    that is created must have "TritonPythonModel" as the class name.
    """

    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.
        Implementing `initialize` function is optional. This function allows
        the model to initialize any state associated with this model.
        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """
        # Parse model configs
        model_config = json.loads(args['model_config'])
        tokenizer_dir = model_config['parameters']['tokenizer_dir'][
            'string_value']
        tokenizer_type = model_config['parameters']['tokenizer_type'][
            'string_value']

        if tokenizer_type == 't5':
            self.tokenizer = T5Tokenizer(vocab_file=tokenizer_dir,
                                         padding_side='left')
        elif tokenizer_type == 'auto':
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir,
                                                           padding_side='left', trust_remote_code=True)
        elif tokenizer_type == 'llama':
            self.tokenizer = LlamaTokenizer.from_pretrained(
                tokenizer_dir, trust_remote_code=True, legacy=False,
                padding_side='left')
            self.tokenizer.add_tokens([
                '<sep>', '<pad>', '<mask>', '<predict>', '<FIM_SUFFIX>',
                '<FIM_PREFIX>', '<FIM_MIDDLE>', '<commit_before>',
                '<commit_msg>', '<commit_after>', '<jupyter_start>',
                '<jupyter_text>', '<jupyter_code>', '<jupyter_output>',
                '<empty_output>'
            ], special_tokens=True)
        #elif tokenizer_type == 'yuan':
        #    self.tokenizer = LlamaTokenizer.from_pretrained(
        #        tokenizer_dir,trust_remote_code=True, legacy=False, padding_side='left')
        else:
            raise AttributeError(
                f'Unexpected tokenizer type: {tokenizer_type}')
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.pad_id = self.tokenizer.encode(self.tokenizer.pad_token,
                                            add_special_tokens=False)[0]

        # Parse model output configs and convert Triton types to numpy types
        input_names = [
            "INPUT_ID", "LFA_IDX", "REQUEST_INPUT_LEN", "BAD_WORDS_IDS", "STOP_WORDS_IDS",
        ]
        for input_name in input_names:
            setattr(
                self,
                input_name.lower() + "_dtype",
                pb_utils.triton_string_to_numpy(
                    pb_utils.get_output_config_by_name(
                        model_config, input_name)['data_type']))

    def execute(self, requests):
        """`execute` must be implemented in every Python model. `execute`
        function receives a list of pb_utils.InferenceRequest as the only
        argument. This function is called when an inference is requested
        for this model. Depending on the batching configuration (e.g. Dynamic
        Batching) used, `requests` may contain multiple requests. Every
        Python model, must create one pb_utils.InferenceResponse for every
        pb_utils.InferenceRequest in `requests`. If there is an error, you can
        set the error argument when creating a pb_utils.InferenceResponse.
        Parameters
        ----------
        requests : list
          A list of pb_utils.InferenceRequest
        Returns
        -------
        list
          A list of pb_utils.InferenceResponse. The length of this list must
          be the same as `requests`
        """

        responses = []

        # Every Python backend must iterate over every one of the requests
        # and create a pb_utils.InferenceResponse for each of them.
        for idx, request in enumerate(requests):
            # Get input tensors
            query = pb_utils.get_input_tensor_by_name(request,
                                                      'QUERY').as_numpy()
            request_output_len = pb_utils.get_input_tensor_by_name(
                request, 'REQUEST_OUTPUT_LEN').as_numpy()

            bad_words_dict = pb_utils.get_input_tensor_by_name(
                request, 'BAD_WORDS_DICT').as_numpy()
            stop_words_dict = pb_utils.get_input_tensor_by_name(
                request, 'STOP_WORDS_DICT').as_numpy()

            # Preprocessing input data.
            input_id, lfa_idx, request_input_len = self._create_request(query)

            bad_words = self._to_word_list_format(bad_words_dict)
            stop_words = self._to_word_list_format(stop_words_dict)

            # Create output tensors. You need pb_utils.Tensor
            # objects to create pb_utils.InferenceResponse.
            input_id_tensor = pb_utils.Tensor(
                'INPUT_ID',
                np.array(input_id).astype(self.input_id_dtype))
            lfa_idx_tensor = pb_utils.Tensor(
                'LFA_IDX',
                np.array(lfa_idx).astype(self.input_id_dtype))
            request_input_len_tensor = pb_utils.Tensor(
                'REQUEST_INPUT_LEN',
                np.array(request_input_len).astype(
                    self.request_input_len_dtype))
            request_output_len_tensor = pb_utils.Tensor(
                'REQUEST_OUTPUT_LEN', request_output_len)
            bad_words_ids_tensor = pb_utils.Tensor('BAD_WORDS_IDS', bad_words)
            stop_words_ids_tensor = pb_utils.Tensor('STOP_WORDS_IDS',
                                                    stop_words)

            # Create InferenceResponse. You can set an error here in case
            # there was a problem with handling this inference request.
            # Below is an example of how you can set errors in inference
            # response:
            #
            # pb_utils.InferenceResponse(
            #    output_tensors=..., TritonError("An error occurred"))
            inference_response = pb_utils.InferenceResponse(output_tensors=[
                input_id_tensor, lfa_idx_tensor, bad_words_ids_tensor, stop_words_ids_tensor,
                request_input_len_tensor, request_output_len_tensor
            ])
            responses.append(inference_response)

        # You should return a list of pb_utils.InferenceResponse. Length
        # of this list must match the length of `requests` list.
        return responses

    def finalize(self):
        """`finalize` is called only once when the model is being unloaded.
        Implementing `finalize` function is optional. This function allows
        the model to perform any necessary clean ups before exit.
        """
        print('Cleaning up...')

    def _create_request(self, query):
        """
            query : batch string (2D numpy array)
        """
        start_ids = [
            torch.IntTensor(self.tokenizer.encode(s[0].decode()))
            for s in query
        ]
        start_lengths = torch.IntTensor([[len(ids)] for ids in start_ids])

        start_ids = pad_sequence(start_ids,
                                 batch_first=True,
                                 padding_value=self.pad_id)
        # input_len = min(start_lengths)
        #attn_mask = torch.ones((batch_size, input_len, input_len)).tril()
        lfa_idx = torch.IntTensor(range(start_ids.shape[1]))  # [0, 1, ..., seq_len - 1]; 1-D, no batch dimension
        return start_ids, lfa_idx, start_lengths

    def _to_word_list_format(self, word_dict: List[List[str]]):
        '''
        Format of word_dict:
            len(word_dict) should be the same as batch_size.
            word_dict[i] means the words for batch i.
            len(word_dict[i]) must be 1, i.e. it contains exactly one string,
            which may hold several phrases separated by ",".
            For example, if word_dict[2] = " I am happy, I am sad", then this
            function will return the ids for the two short phrases
            " I am happy" and " I am sad".
        '''
        assert self.tokenizer is not None, "need to set tokenizer"

        flat_ids = []
        offsets = []
        for word_dict_item in word_dict:
            item_flat_ids = []
            item_offsets = []

            if isinstance(word_dict_item[0], bytes):
                word_dict_item = [word_dict_item[0].decode()]

            words = list(csv.reader(word_dict_item))[0]
            for word in words:
                ids = self.tokenizer.encode(word)

                if len(ids) == 0:
                    continue

                item_flat_ids += ids
                item_offsets.append(len(ids))

            flat_ids.append(np.array(item_flat_ids))
            offsets.append(np.cumsum(np.array(item_offsets)))

        pad_to = max(1, max(len(ids) for ids in flat_ids))

        for i, (ids, offs) in enumerate(zip(flat_ids, offsets)):
            flat_ids[i] = np.pad(ids, (0, pad_to - len(ids)),
                                 constant_values=0)
            offsets[i] = np.pad(offs, (0, pad_to - len(offs)),
                                constant_values=-1)

        return np.array([flat_ids, offsets], dtype="int32").transpose(
            (1, 0, 2))
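
Looking at _create_request in light of the error: start_ids is padded to a 2-D [batch, seq_len] tensor, while lfa_idx is built as a 1-D tensor of length seq_len with no batch dimension. A minimal standalone reproduction of that shape logic (hypothetical 6-token prompt, no tokenizer required):

import torch
from torch.nn.utils.rnn import pad_sequence

# Stand-in for the tokenizer output: one prompt of six token ids.
start_ids = [torch.IntTensor([1, 2, 3, 4, 5, 6])]
start_lengths = torch.IntTensor([[len(ids)] for ids in start_ids])
start_ids = pad_sequence(start_ids, batch_first=True, padding_value=0)
lfa_idx = torch.IntTensor(range(start_ids.shape[1]))

print(start_ids.shape)      # torch.Size([1, 6])  -> batch dimension present
print(lfa_idx.shape)        # torch.Size([6])     -> no batch dimension
print(start_lengths.shape)  # torch.Size([1, 1])

Since Triton treats the first dimension of every input as the batch dimension when max_batch_size > 0, a 1-D lfa_idx of length 6 would be interpreted as batch size 6 next to input_ids with batch size 1, which would be consistent with the "batch size does not match" error reported above.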