Notes on possible improvements:
A small trick to boost performance: in beam search, all beam branches actually share the same prompt context, and that part does not need to be shuffled at every timestep. This can reduce the shuffle overhead and save some memory.
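A minimal sketch of the idea, assuming the KV cache is split into a prompt part and a generated part (the function name and shapes below are illustrative, not the actual cache layout):

```python
import jax.numpy as jnp

def reorder_kv_cache(prompt_cache, gen_cache, parent_idx):
    # prompt_cache: [num_beams, prompt_len, ...] -- identical across beams,
    # so it never needs to be gathered (it could even be stored once with a
    # leading dimension of 1 and broadcast).
    # gen_cache:    [num_beams, gen_len, ...]    -- per-beam generated part.
    # parent_idx:   [num_beams] parent beam chosen for each surviving beam.
    gen_cache = jnp.take(gen_cache, parent_idx, axis=0)  # reorder only this part
    return prompt_cache, gen_cache
```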
In order to achieve state-of-the-art serving performance on OPT/GPT, we need to develop the following features, sorted by priority.
Task 1: Align single-GPU decoding performance with FasterTransformer.
Task 1.1: Generate a kernel performance table
In terms of single-GPU decoding performance, on an OPT model with 2.7B parameters at `bs=1`, JAX achieves ~65 tokens/s on an A100, while FT achieves ~78 tokens/s on a V100. Many of our components are under-optimized when running inference. Several TODOs:
Task 1.2: Optimize the masked attention computation
Our masked attention performs unnecessary computation: at each decoding step, we attend over the full `seq_len` instead of only the previously generated tokens. This is a consequence of the static compilation design. We need to eliminate this wasted computation.
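For concreteness, a sketch of the current pattern (names and shapes are illustrative): attention over a statically shaped cache touches all `max_seq_len` positions and only masks out the invalid ones afterwards, so everything beyond `cur_len` is wasted work.

```python
import jax
import jax.numpy as jnp

def decode_step_attention(q, k_cache, v_cache, cur_len):
    # q:        [num_heads, head_dim]              -- query for the new token
    # k_cache:  [max_seq_len, num_heads, head_dim]
    # v_cache:  [max_seq_len, num_heads, head_dim]
    # cur_len:  number of valid positions (prompt + tokens generated so far)
    scores = jnp.einsum('hd,shd->hs', q, k_cache) / jnp.sqrt(q.shape[-1])
    # Because shapes are static, the einsum above already computed scores for
    # all max_seq_len positions; positions >= cur_len are only masked here.
    valid = jnp.arange(k_cache.shape[0]) < cur_len        # [max_seq_len]
    scores = jnp.where(valid[None, :], scores, -1e9)
    probs = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum('hs,shd->hd', probs, v_cache)
```

Restricting the computation to `cur_len` (for example, by bucketing the cache length and compiling one executable per bucket) is one way to remove the waste.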
Task 1.3: Use training-like computation on the prompt (prefix)
We should compile at least two executables: one that performs a training-like forward pass over the prompt, and one for decoding. Currently we run the decoding path for every prompt token one at a time, which is slow.
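A minimal sketch of the split, where `model_forward` is a hypothetical model function (not Alpa's actual API) that optionally takes a KV cache:

```python
import jax

@jax.jit
def prefill(params, prompt_ids):
    # Training-like forward pass: all prompt tokens in parallel; returns the
    # logits of the last prompt position and the filled KV cache.
    logits, kv_cache = model_forward(params, prompt_ids)
    return logits[:, -1, :], kv_cache

@jax.jit
def decode_step(params, token_ids, kv_cache):
    # Incremental forward pass: one new token per sequence, reusing the cache.
    logits, kv_cache = model_forward(params, token_ids, kv_cache=kv_cache)
    return logits[:, -1, :], kv_cache
```

Each function is compiled once; the prompt is then processed in a single `prefill` call instead of one `decode_step` call per prompt token.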
Task 2: Align intra-op parallelism performance in decoding with FasterTransformer.
Assuming Task #1 is done, we then try to match FT's performance when auto-sharding is turned on.
As a reference, on the 2.7B model with `bs = 1`, our latency does not drop as much as FT's when more GPUs are added. This might be a simple fallback in the auto-sharding solver at batch size = 1, or a more serious problem. We need to fix it and match or outperform FT in terms of latency reduction when adding more GPUs.
Task 3: Enable full inter-op parallelism and #mb > 1
Task 3.1: Add some basic features
Currently, we can do device-placement-like inter-op parallelism at `#microbatches = 1`; we need some additional engineering to support full inter-op parallelism and `#mb > 1`. This development may not improve latency, but it will boost throughput significantly.
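A toy illustration of why `#mb > 1` helps throughput (pure Python, not Alpa's pipeline runtime): with one microbatch, only one pipeline stage is busy per clock tick, while with several microbatches the stages overlap.

```python
def pipeline_schedule(num_stages, num_microbatches):
    # For each clock tick, list the (stage, microbatch) pairs running in
    # parallel under a simple GPipe-style schedule.
    ticks = []
    for t in range(num_stages + num_microbatches - 1):
        ticks.append([(s, t - s) for s in range(num_stages)
                      if 0 <= t - s < num_microbatches])
    return ticks

# 4 stages, 1 microbatch: one busy stage per tick (latency only, no overlap).
# 4 stages, 4 microbatches: the middle ticks keep all 4 stages busy.
print(pipeline_schedule(4, 1))
print(pipeline_schedule(4, 4))
```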
Task 3.2: Reduce Ray scheduling overheads as much as possible
We have identified many Ray scheduling overheads; they are not substantial during training but become critical during inference. We should find ways to reduce these scheduling overheads as much as possible.
I have enumerated below all the sources of Ray overhead I can think of. We need to benchmark the severity of these overheads and see which ones are substantial in our targeted use cases (e.g., 175B models, long decoding, long prompts), because not all of them are easy to hide.
`shard_args` at each decoding step. Note that in training, `shard_args` is called only once per iteration.
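A generic JAX illustration of the remedy for this particular overhead (not Alpa's `shard_args` API): place the decoding state on the device once, before the loop, so each step only passes device-resident arrays instead of re-sharding host inputs.

```python
import jax
import jax.numpy as jnp

@jax.jit
def decode_step(state, token):
    # `state` stays on the device across the whole loop; `token` is the only
    # new input per step. The body is a stand-in for the real decoder step.
    new_state = state + token
    return new_state, new_state.sum()

state = jax.device_put(jnp.zeros((8, 1024)))   # shipped to the device once
for t in range(16):
    state, score = decode_step(state, jnp.float32(t))
```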
Task 4: Support beam search, output scores/hidden states, and other related text generation utilities
Currently, we only support top-p sampling with a single output sentence. We need to support `beam_size > 1`. A related, near-complete PR for beam search: #491. One good reference is OpenAI's text generation interface and features.
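For reference, a sketch of the standard beam-search selection step (the textbook algorithm, not necessarily how PR #491 implements it):

```python
import jax
import jax.numpy as jnp

def beam_search_step(log_probs, beam_scores, beam_size):
    # log_probs:   [beam_size, vocab_size] next-token log-probabilities
    # beam_scores: [beam_size] cumulative log-probability of each beam
    vocab_size = log_probs.shape[-1]
    # Score every (beam, token) continuation, then keep the best beam_size.
    candidates = (beam_scores[:, None] + log_probs).reshape(-1)
    top_scores, top_idx = jax.lax.top_k(candidates, beam_size)
    parent_beam = top_idx // vocab_size   # which beam each survivor extends
    next_token = top_idx % vocab_size     # which token it appends
    return top_scores, parent_beam, next_token
```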
Task 5: Study and improve the current batching mechanism
The batching in serving is complicated because it requires at least two levels of batching:
Batching incoming requests
Currently, we set two thresholds:
`TIMEOUT`: a time window.
`MAX_TOKENS`: the maximal number of tokens allowed to be processed in one "job launch".
The server listens for requests until either `TIMEOUT` is reached or the collected tokens reach `MAX_TOKENS`, and then sends them as one job to the backend for computation. There are potentially other, better mechanisms for batching requests that arrive stochastically from users; I'll post several related papers later.
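A minimal sketch of the policy described above; the queue/request interface and the threshold values are illustrative, not the actual server code.

```python
import time
from queue import Queue, Empty

TIMEOUT = 0.1        # seconds to wait before launching a job (illustrative)
MAX_TOKENS = 2048    # token budget per job launch (illustrative)

def collect_batch(request_queue: Queue):
    # Gather requests until the time window closes or the token budget fills.
    batch, num_tokens = [], 0
    deadline = time.time() + TIMEOUT
    while time.time() < deadline and num_tokens < MAX_TOKENS:
        try:
            req = request_queue.get(timeout=max(0.0, deadline - time.time()))
        except Empty:
            break
        batch.append(req)
        num_tokens += len(req["prompt_tokens"])
    return batch     # sent to the backend as one job
```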
Dynamically batched computation
Suppose we are given a batch of sentences to decode. Each sentence in the batch has a different length and different user-requested parameters (`top_p`, `beam_size`, `max_decoding_length`, etc.). How should we enable batched computation of these sentences?
Our current inference system status: `batch_size = 1`.
Some rough ideas for improvement: support `batch_size = 1, 2, 4`, etc.
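One way to read this (my assumption of the intent): pre-compile executables for a small set of batch sizes and pad each incoming batch up to the nearest bucket, so arbitrary request counts reuse a fixed set of compiled shapes. A sketch:

```python
BATCH_BUCKETS = (1, 2, 4, 8)   # batch sizes with pre-compiled executables

def pick_bucket(num_requests: int) -> int:
    # Round up to the nearest pre-compiled batch size; larger batches would
    # be split across several job launches.
    for b in BATCH_BUCKETS:
        if num_requests <= b:
            return b
    return BATCH_BUCKETS[-1]

# e.g. 3 pending requests -> run the batch_size=4 executable with one padded
# (masked-out) slot.
assert pick_bucket(3) == 4
```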