triton-inference-server / server

The Triton Inference Server provides an optimized cloud and edge inferencing solution.
https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/index.html

Triton limitations when deploying gpt-like generative models #5354

Open piglaker opened 1 year ago

piglaker commented 1 year ago

As an NLP practitioner, I'm trying to deploy a GPT-2-like generative model using Triton:

  1. This type of generative model usually relies on a generate method in addition to its own forward pass; that is, it involves a for-loop and conditional logic. Each forward pass outputs one token, which is concatenated onto the previous input to form the next input.
  2. In addition, since the context length of a large language model usually reaches 1024, running a full forward pass over the whole sequence at every step makes the computation very expensive. Therefore, in frameworks such as Hugging Face Transformers, the model also returns its cached hidden states (named past_key_values here) as an extra output, which are fed back as input to the next forward pass. This with-past method can significantly improve performance when the sequence is long (see the sketch after this list).
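
For context, here is a minimal sketch of the with-past generation loop described in (2), using Hugging Face Transformers; the gpt2 checkpoint, greedy sampling, and 20-token limit are only illustrative:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
past_key_values = None

with torch.no_grad():
    for _ in range(20):  # generate 20 new tokens
        if past_key_values is None:
            outputs = model(input_ids=input_ids, use_cache=True)
        else:
            # With the KV cache, only the newest token is fed to the model.
            outputs = model(input_ids=input_ids[:, -1:],
                            past_key_values=past_key_values,
                            use_cache=True)
        past_key_values = outputs.past_key_values
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)

print(tokenizer.decode(input_ids[0]))
```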

In our deployment, we first tried the Python backend + ONNX approach to implement the first point above inside the Python backend. But when we implemented the second point, we hit the problem that the overhead of InferenceRequest was very high. We suspect that past_key_values is being transported back and forth between CPU and GPU, which adds overhead (in this example, past_key_values is [1, 24, 50, 256] per tensor × 34 (layers) × 2 (keys and values) × 4 bytes per element).
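
As a rough back-of-envelope check (assuming 4-byte fp32 elements, which is consistent with the 30~80 MB figure mentioned later in this thread):

```python
elements_per_tensor = 1 * 24 * 50 * 256        # shape [1, 24, 50, 256] -> 307,200 elements
bytes_per_tensor = elements_per_tensor * 4     # ~1.2 MB at 4 bytes/element (fp32)
total_bytes = bytes_per_tensor * 34 * 2        # 34 layers x (key + value) ~= 84 MB
print(bytes_per_tensor / 1e6, total_bytes / 1e6)
```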

We also noticed that Triton's design keeps different models isolated from one another, which seems unfriendly to this kind of mixed CPU/GPU workflow. How can we solve this problem?

nv-kmcgill53 commented 1 year ago

CC: @Tabrizian and @GuanLuo

For (1) I think using Triton's BLS feature would be useful here. We have a stable diffusion example which might help with ideas.

piglaker commented 1 year ago

Yeah, I used BLS in the Python backend; below is the pseudocode:

```python
import triton_python_backend_utils as pb_utils
from transformers import AutoTokenizer


class TritonPythonModel:
    def initialize(self, args):
        self.tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-16B-mono")

    def preprocess(self, requests):
        # Use the tokenizer to convert the request text into input_ids.
        ...

    def topk_search(self, input_ids):
        for _ in range(max_iterations):
            # BLS call into the ONNX model; only the input ids are sent each step.
            inference_request = pb_utils.InferenceRequest(
                model_name=model_name, inputs=inputs)
            inference_response = inference_request.exec()

            # Some sampling strategies here: top_k, top_p, ...

            # Concatenate the newly generated ids onto the inputs for the next step.

        # Decode the generated ids here.
        return result

    def execute(self, requests):
        inputs = self.preprocess(requests)
        results = self.topk_search(inputs)

        # Merge the results into responses.
        return responses

    def finalize(self):
        print("Cleaning up...")
```

This satisfies my requirement (1) on the fp16 ONNX model, and the performance is good. In this example, only the input_ids need to be passed to the inference model at each step. But when we tried to implement the second point, the topk_search code changed: we now need to pass the past keys and values to the model. The pseudocode is as follows:

```python
# The past key/value tensors are now passed as additional inputs on every step.
inference_request = pb_utils.InferenceRequest(
    model_name=model_name, inputs=inputs + past_key_values)
inference_response = inference_request.exec()
```

The past key/value tensors total about 30~80 MB, and we speculate this caused a lot of unexpected problems and obvious performance degradation; for example, generation was about 50-200 times slower than the Hugging Face Transformers PyTorch model.
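If the extra cost really does come from staging the cache through host memory, one option (not part of the original report) is to keep the past key/value tensors on the GPU between BLS calls via DLPack. A minimal sketch, assuming hypothetical tensor names (past_i / present_i / logits) that would have to match the ONNX model's config, and assuming the model configuration allows GPU input tensors:

```python
import triton_python_backend_utils as pb_utils
from torch.utils.dlpack import from_dlpack, to_dlpack

# Hypothetical tensor names; they must match the ONNX model's config.pbtxt.
PAST_NAMES = [f"past_{i}" for i in range(34 * 2)]
PRESENT_NAMES = [f"present_{i}" for i in range(34 * 2)]


def bls_step_with_past(model_name, input_ids_gpu, past_gpu):
    # Wrap GPU torch tensors as Triton tensors without copying them to the host.
    inputs = [pb_utils.Tensor.from_dlpack("input_ids", to_dlpack(input_ids_gpu))]
    inputs += [pb_utils.Tensor.from_dlpack(name, to_dlpack(t))
               for name, t in zip(PAST_NAMES, past_gpu)]

    response = pb_utils.InferenceRequest(
        model_name=model_name,
        requested_output_names=["logits"] + PRESENT_NAMES,
        inputs=inputs,
    ).exec()
    if response.has_error():
        raise pb_utils.TritonModelException(response.error().message())

    # Unwrap the outputs back into torch tensors, again via DLPack
    # (no host copy if the backend returned them on the GPU).
    logits = from_dlpack(
        pb_utils.get_output_tensor_by_name(response, "logits").to_dlpack())
    new_past = [from_dlpack(
        pb_utils.get_output_tensor_by_name(response, name).to_dlpack())
        for name in PRESENT_NAMES]
    return logits, new_past
```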

Tabrizian commented 1 year ago

It could be that you are running out of cuda memory pool and that's why the inference performance is suffering. Can you share the logs for the second use-case? Does increasing the --cuda-memory-pool-byte-size resolve the issue?
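
For reference, the flag takes a `<device-id>:<byte-size>` pair per GPU; an illustrative invocation (model repository path and pool size are placeholders):

```
tritonserver --model-repository=/models \
    --cuda-memory-pool-byte-size=0:1073741824   # 1 GB CUDA pool on GPU 0
```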

piglaker commented 1 year ago

Everything in the log is normal: no errors are reported and there is no OOM. We tried setting --cuda-memory-pool-byte-size to a very large size, but it didn't help. We then wrote some timing metrics and inserted them, and observed that the InferenceRequest latency randomly explodes from 0.5 s to 2~5 s, or the inference simply runs at a very slow speed (2~10 s).
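
A minimal sketch of this kind of per-call timing around the BLS request (illustrative only, not the exact metrics code):

```python
import time

start = time.perf_counter()
inference_response = inference_request.exec()
elapsed = time.perf_counter() - start
# Log per-call latency so slow BLS requests show up in the server log.
print(f"BLS exec() took {elapsed:.3f} s", flush=True)
```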