alpa-projects / alpa

Training and serving large-scale neural networks with auto parallelization.
https://alpa.ai

[Roadmap] Inference performance (for OPT/GPT) #534

Open zhisbug opened 2 years ago

zhisbug commented 2 years ago

In order to achieve state-of-the-art serving performance on OPT/GPT, we need to develop the following features, sorted by priority.

Task 1: Align single-GPU decoding performance with FasterTransformer.

Task 1.1: Generate a kernel performance table

For single-GPU decoding on an OPT-2.7B model with bs=1, JAX achieves ~65 tokens/s on an A100, while FasterTransformer (FT) achieves ~78 tokens/s on a V100.
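For Task 1.1, a simple way to fill in such a table is to time the jitted single-token decode step directly. A rough sketch (the `decode_step` signature is a placeholder, not our actual benchmark code; it is assumed to return jax arrays):

```python
import time

def benchmark_decoding(decode_step, cache, token, num_steps=128):
    """Rough steady-state tokens/s of a single-token decode step at bs=1.

    Assumes decode_step(cache, token) -> (cache, next_token) is already jitted.
    """
    # Warm-up call so compilation time is excluded from the measurement.
    cache, token = decode_step(cache, token)
    token.block_until_ready()

    start = time.perf_counter()
    for _ in range(num_steps):
        cache, token = decode_step(cache, token)
    token.block_until_ready()
    return num_steps / (time.perf_counter() - start)
```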

Many of our components are under-optimized when running inference. Several TODOs:

Task 1.2: Optimize the masked attention computation

Our masked attention performs unnecessary computation: at each decoding step, we attend over the full seq_len instead of only the previously generated tokens. This is a consequence of the static compiler design. We need to restrict the attention computation to the valid prefix at each step, as sketched below.
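A minimal sketch of the problem (single head, hypothetical names, not our actual kernel): with a statically shaped KV cache, we compute scores over the whole `max_len` and mask out the invalid tail at every step, whereas only the first `cur_len` entries actually matter.

```python
import jax
import jax.numpy as jnp

def masked_decode_attention(q, k_cache, v_cache, cur_len):
    """One decoding step of single-head attention over a static-length KV cache.

    q:        (d_head,)          query of the current token
    k_cache:  (max_len, d_head)  keys; only the first cur_len rows are valid
    v_cache:  (max_len, d_head)  values; same layout
    """
    scores = k_cache @ q / jnp.sqrt(q.shape[-1])   # computed over all max_len rows
    # Static shapes force us to compute every score and then hide the invalid
    # ones with a mask; attending over only cur_len rows would avoid the waste.
    valid = jnp.arange(k_cache.shape[0]) < cur_len
    scores = jnp.where(valid, scores, -jnp.inf)
    return jax.nn.softmax(scores) @ v_cache        # (d_head,)
```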

Task 1.3: Use training-like computation on the prompt (prefix)

We should compile at least two executables: one that performs a training-like forward pass over the prompt, and another for decoding. Currently we run the decoding path for each token in the prompt, which is slow.
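A rough sketch of the split (greedy decoding for brevity; `prefill_fn` and `decode_fn` stand for the two separately compiled executables and are assumptions, not existing Alpa APIs):

```python
import jax.numpy as jnp

def generate(prefill_fn, decode_fn, params, prompt_ids, num_new_tokens):
    """Run the prompt through one training-like forward pass, then decode token by token.

    prefill_fn(params, prompt_ids) -> (kv_cache, last_logits)   # compiled once for the prompt
    decode_fn(params, kv_cache, token) -> (kv_cache, logits)    # compiled once for decoding
    """
    # One executable processes the entire prompt in a single forward pass...
    cache, logits = prefill_fn(params, prompt_ids)
    token = jnp.argmax(logits, axis=-1)
    outputs = [token]
    # ...and a second, much cheaper executable generates one token per step.
    for _ in range(num_new_tokens - 1):
        cache, logits = decode_fn(params, cache, token)
        token = jnp.argmax(logits, axis=-1)
        outputs.append(token)
    return jnp.stack(outputs, axis=-1)
```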

Task 2: Align intra-op parallelism performance in decoding with FasterTransformer.

Assuming Task 1 is done, we then try to match FT's performance when auto sharding is turned on.

As a reference, on the 2.7B model with bs = 1:

This might be a simple fallback issue in the auto-sharding solver at batch size = 1, or a more serious problem. We need to fix it and match or outperform FT in latency reduction when adding more GPUs.

Task 3: Enable full inter-op parallelism and #mb > 1

3.1 Add some basic features

Currently, we can do device-placement-like inter-op parallelism at #microbatches = 1; we need some additional engineering work to support full inter-op parallelism with #mb > 1.

This development may not improve latency but will boost throughput significantly.
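A back-of-the-envelope sketch of why (idealized, ignoring pipeline bubbles beyond fill and drain): with S stages and m microbatches, a pipelined pass takes roughly (S + m - 1) stage-times instead of S stage-times per microbatch, so the per-microbatch cost amortizes as m grows.

```python
def pipeline_time(num_stages, num_microbatches, stage_time):
    """Idealized pipelined execution time: fill + drain + one slot per microbatch."""
    return (num_stages + num_microbatches - 1) * stage_time

# With 4 stages and a 10 ms per-stage time:
#   m=1 -> 40 ms total, 40.0 ms per microbatch (stages mostly idle)
#   m=8 -> 110 ms total, 13.8 ms per microbatch (much higher throughput)
for m in (1, 8):
    t = pipeline_time(4, m, 10.0)
    print(f"m={m}: {t:.0f} ms total, {t / m:.1f} ms per microbatch")
```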

3.2 Reduce Ray scheduling overheads as much as possible

We have identified many Ray scheduling overheads; they are not substantial during training but become critical during inference. We should find ways to reduce these scheduling overheads as much as possible.

I have enumerated all the sources of Ray overhead I can think of below. We need to benchmark their severity and see which ones are substantial in our targeted use cases (e.g., 175B, long decoding steps, long prompts), because not all of them are easy to hide.
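As a starting point for the benchmarking, a no-op remote task gives a lower bound on the per-call scheduling cost (a sketch; real Alpa runtime calls carry more data and thus more overhead):

```python
import time
import ray

ray.init()

@ray.remote
def noop():
    return None

# Warm up so worker startup time is excluded.
ray.get(noop.remote())

n = 1000
start = time.perf_counter()
ray.get([noop.remote() for _ in range(n)])
elapsed = time.perf_counter() - start
print(f"~{elapsed / n * 1e6:.0f} us of Ray overhead per no-op task")
```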

Task 4: support beam search, output scores/hidden states, and other related text generation utilities

Currently, we only support top-p sampling with a single output sentence. We need to support beam_size > 1. A related, near-complete PR for beam search: #491
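For reference, a single expansion step of beam search looks roughly like this (a sketch for one sequence; #491 is the authoritative implementation):

```python
import jax
import jax.numpy as jnp

def beam_search_step(log_probs, beam_scores, beam_size):
    """One expansion step of beam search for a single sequence.

    log_probs:   (beam_size, vocab)  next-token log-probabilities per live beam
    beam_scores: (beam_size,)        cumulative log-probabilities of the live beams
    Returns the parent beam index, chosen token, and new score for each new beam.
    """
    vocab = log_probs.shape[-1]
    # Add running scores and flatten all beam/token candidates into one list.
    candidates = (beam_scores[:, None] + log_probs).reshape(-1)
    new_scores, flat_idx = jax.lax.top_k(candidates, beam_size)
    parent_beam = flat_idx // vocab   # which beam each winner extends
    token = flat_idx % vocab          # which token it appends
    return parent_beam, token, new_scores
```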

One good reference is OpenAI's text generation interface and features.

Task 5: Study and improve the current batching mechanism

The batching in serving is complicated because it requires at least two levels of batching:

Batching incoming requests

Currently, we set two thresholds:

  1. TIMEOUT: a time window
  2. MAX_TOKENS: the maximum number of tokens allowed to be processed in one "job launch".

The server listens for requests until either TIMEOUT expires or the collected tokens reach MAX_TOKENS, and then sends them as one job to the backend for computation.
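A minimal version of that server loop (hypothetical request fields; queueing and error handling omitted):

```python
import queue
import time

TIMEOUT = 0.1        # seconds to wait before flushing a partial batch
MAX_TOKENS = 4096    # cap on total tokens per "job launch"

def collect_batch(request_queue):
    """Collect requests until TIMEOUT expires or MAX_TOKENS is reached."""
    batch, num_tokens = [], 0
    deadline = time.time() + TIMEOUT
    while time.time() < deadline and num_tokens < MAX_TOKENS:
        try:
            req = request_queue.get(timeout=max(0.0, deadline - time.time()))
        except queue.Empty:
            break
        batch.append(req)
        num_tokens += len(req.prompt_tokens)   # hypothetical per-request field
    return batch   # sent to the backend as one job
```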

There are potentially better mechanisms for batching user requests that arrive stochastically; I'll post several related papers later.

Dynamically batched computation

Suppose we are given a batch of sentences to decode. Each sentence in the batch has a different length and different user-requested parameters (top_p, beam_size, max_decoding_length, etc.).

How should we enable batched computation of these sentences?
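The simplest baseline is to pad every sentence to the longest one in the batch and mask out the padding (per-request parameters like top_p still have to be applied row by row). A sketch:

```python
import numpy as np

def pad_batch(token_id_lists, pad_id=1):
    """Pad variable-length token sequences to a common length and build a mask.

    Returns (batch, mask) where mask[i, j] == 1 iff position j of sentence i is real.
    The wasted compute grows with the length variance in the batch, which is
    exactly what a smarter batching scheme should avoid.
    """
    max_len = max(len(ids) for ids in token_id_lists)
    batch = np.full((len(token_id_lists), max_len), pad_id, dtype=np.int32)
    mask = np.zeros_like(batch)
    for i, ids in enumerate(token_id_lists):
        batch[i, : len(ids)] = np.asarray(ids, dtype=np.int32)
        mask[i, : len(ids)] = 1
    return batch, mask
```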

Our current inference system status:

Some rough ideas for improvement:

merrymercy commented 2 years ago

Notes on possible improvements:

zhuohan123 commented 2 years ago

A small trick to boost performance: in beam search, all beam branches share the same prompt context, so the prompt part of the cache doesn't need to be shuffled at every timestep. This reduces the shuffle overhead and saves some memory.
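A sketch of what that looks like at the KV-cache level (hypothetical shapes and names): store the prompt part of the cache once, and gather only the generated part when beams are reordered.

```python
def reorder_kv_cache(prompt_kv, generated_kv, parent_beam):
    """Reorder the per-beam KV cache after one beam-search step.

    prompt_kv:    (1, prompt_len, d)        stored once, shared by every beam
    generated_kv: (beam_size, gen_len, d)   per-beam cache of generated tokens
    parent_beam:  (beam_size,)              parent indices from the beam-search step

    The prompt cache is never replicated or shuffled; only the (much smaller)
    generated part is gathered, saving shuffle time and memory.
    """
    return prompt_kv, generated_kv[parent_beam]
```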