huggingface / text-generation-inference

Large Language Model Text Generation Inference
http://hf.co/docs/text-generation-inference
Apache License 2.0

Improve inference speed of Santacoder and Starcoder (and others) #376

Open jlamypoirier opened 1 year ago

jlamypoirier commented 1 year ago

I did some extensive investigation, testing and benchmarking, and determined that the following is needed to speed up inference for the Bigcode models (and most other text-gen-inference models):

  1. Use FlashAttention for prefill only. This is recommended by the authors because the FlashAttention kernel relies on a high query length to achieve good parallelization, and because using it in decode needs a lot of extra work on the inputs/outputs/KV caches for each token.
  2. Vectorize as many pre/post-processing operations as possible, i.e. avoid loops (especially for cuda ops). The warpers / logit processors have already been vectorized in #317, and the rest of causal_lm has a prototype implementation in #272 (flash_causal_lm is harder to vectorize, but according to the point above causal_lm should be preferable).
  3. Perform some form of KV cache pre-allocation and key-length padding to a multiple of 8. A complete, static pre-allocated tensor adds complications because of the need to concatenate/filter batches, but it's easy to pre-allocate only a few tokens in advance so the slow concatenation runs every N tokens instead of on every token. (Again, this is not doable with FlashAttention.) Padding the key length to a multiple of 8 also provides a large speedup, so N=8 is a bare minimum (though higher is better); see the first sketch after this list.
  4. Compute the details (logprobs, prefill data, etc.) only when requested (#288). These take a lot of time and force computing the whole model head (see 5. below), but the results are almost always thrown away.
  5. Compute the model head only for the last token in prefill (unless details are actually requested). This saves some time and, more importantly, avoids a memory bottleneck; see the second sketch after this list.
  6. Use deterministic (seeded) generation only when a seed is provided. Seeded sampling has to run in a loop because PyTorch doesn't support vectorized generators, so unseeded requests should keep a single vectorized sampling call.
  7. Trim the Python code. Avoid unnecessary function calls (inline when possible), attribute lookups, etc., as these end up contributing a lot to the CPU latency. Avoid subclassing nn.Module because it adds a lot of overhead (hooks) on __call__ and __getattr__. In tests I was able to reduce the Santacoder minimum latency by more than 20% this way.
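
As a rough illustration of point 3, here is a minimal sketch (not the text-generation-inference code; shapes and names are made up) of growing a pre-allocated KV cache by N positions at a time, so the expensive concatenation runs every N decode steps instead of every step:

```python
import torch

PAD_MULTIPLE = 8  # pad the key length to a multiple of 8

def append_kv(cache: torch.Tensor, new_kv: torch.Tensor, used: int):
    """cache:  [batch, allocated_len, heads, head_dim] pre-allocated buffer
    new_kv: [batch, 1, heads, head_dim] KV of the newly decoded token
    used:   number of positions currently filled in the cache"""
    if used == cache.shape[1]:
        # Buffer is full: grow by PAD_MULTIPLE positions at once, so the
        # costly concatenation only happens every PAD_MULTIPLE steps.
        pad = torch.zeros(cache.shape[0], PAD_MULTIPLE, *cache.shape[2:],
                          dtype=cache.dtype, device=cache.device)
        cache = torch.cat([cache, pad], dim=1)
    cache[:, used] = new_kv.squeeze(1)
    return cache, used + 1
```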

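Similarly, a minimal sketch of point 5, assuming a padded [batch, seq_len, hidden] layout (illustrative only, not TGI's actual tensors): slice the hidden states to the last position before the LM head, so the full [batch, seq_len, vocab] logits tensor is never materialized unless prefill details were requested.

```python
import torch
import torch.nn as nn

def prefill_logits(hidden_states: torch.Tensor, lm_head: nn.Linear,
                   need_prefill_details: bool = False) -> torch.Tensor:
    # hidden_states: [batch, seq_len, hidden_size]
    if need_prefill_details:
        # Full logits, only when prompt logprobs etc. are requested.
        return lm_head(hidden_states)          # [batch, seq_len, vocab_size]
    # Otherwise project only the last position of each sequence, which is all
    # that is needed to sample the first generated token.
    return lm_head(hidden_states[:, -1:, :])   # [batch, 1, vocab_size]
```
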
Future work (more investigation needed):

  1. Try and compare more fused kernels. For fused softmax, compare the JIT version (used in #272) and Megatron's implementation (probably better). Compare fused and standard layer norm (the results below seem to go against fused). Try a fused dense layer (with GELU) in the MLP (or try JIT?).
  2. Reduce memory allocations by pre-allocating and/or reusing tensors. The main obstacle is that many operations still don't support the out argument, so some (easy) C++ work would be needed.
  3. Write the CPU-intensive part (Block) in C++. This would not be too hard and would help a lot with the latency of smaller models, but may not be needed if cuda graphs are used.
  4. Add support for cuda graphs, at least for decode. I already showed them to work with dynamic shapes (using a lot of graphs), and they give a big speedup for Santacoder (and a small one for Starcoder), but they add complications for batch concatenate / filter due to the static KV cache location. An option would be to always decode with the same batch size (or a few pre-determined values, e.g. powers of 2) to avoid costly shuffling of the data on every filter; this should be OK since the (Santacoder) decode latency is mostly independent of the batch size anyway. A generic capture/replay sketch follows this list.
  5. Look more into tensor parallelism. I know it's already implemented in text-gen-inference, but I haven't looked into it myself.
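
To make the cuda graphs point above more concrete, here is a generic capture/replay sketch for a fixed batch size, following the standard PyTorch CUDA graph pattern (this is not the bigcode-project inference_runner code; decode_fn and the static buffers are placeholders):

```python
import torch

def make_graphed_decode(decode_fn, static_input: torch.Tensor):
    """Capture one decode step into a CUDA graph and return a replay wrapper.
    decode_fn must read/write tensors at fixed addresses (e.g. a statically
    allocated KV cache), which is the batch concatenate / filter complication
    mentioned above."""
    # Warm up on a side stream before capture, as required for graph capture.
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        for _ in range(3):
            decode_fn(static_input)
    torch.cuda.current_stream().wait_stream(side_stream)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = decode_fn(static_input)

    def replay(new_input: torch.Tensor) -> torch.Tensor:
        static_input.copy_(new_input)  # write into the captured input buffer
        graph.replay()                 # re-run the recorded kernels
        return static_output.clone()   # copy out before the next replay

    return replay
```

Decoding always at one of a few pre-determined batch sizes, as suggested above, would mean one captured graph per size instead of reshuffling the batch on every filter.
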
jlamypoirier commented 1 year ago

Some benchmarking results, comparing several implementations:

  1. flash: flash_santacoder, the current implementation.
  2. causal: The gpt_bigcode model from HF transformers, run with causal_lm.
  3. vector: The gpt_bigcode model from HF transformers, run with vectorized_causal_lm from #272. (Opt. 2 above).
  4. bigcode: The gpt_bigcode model from the Bigcode transformers repo, with minor adaptations and trimming to work with text-gen-inference and vectorized_causal_lm (Opt. 1, 3, 4, 5, 6).
  5. bigcode2: bigcode with some additional optimizations taken from flash_santacoder, mainly the FastLinear and FastLayerNorm layers. Also some simplifications on the attention mask.
  6. bigcode3: bigcode2 with trimmed Python code (Opt. 7).

Note: flash and causal are based on commit 5a58226 (May 16th), so they may be missing the latest optimizations. Also note that the curves are smoothed; without key-length padding (causal and vector) they otherwise oscillate wildly.

Santacoder decode

[Plots: Santacoder decode latency, batch sizes 1 / 32 / 256, up to 2040 tokens]

Santacoder prefill

[Plots: Santacoder prefill latency, batch sizes 1 / 32 / 256, up to 2040 tokens]

jlamypoirier commented 1 year ago

Starcoder decode

[Plots: Starcoder decode latency, batch sizes 1 / 32 / 256, up to 8190 tokens]

Starcoder prefill

[Plots: Starcoder prefill latency, batch sizes 1 / 32, up to 8190 tokens]

huyphan168 commented 1 year ago

@jlamypoirier Thanks for the great investigation.

> Add support for cuda graphs, at least for decode. I already showed them to work with dynamic shapes (using a lot of graphs), and they give a big speedup for Santacoder (and a small one for Starcoder), but they add complications for batch concatenate / filter due to the static KV cache location. An option would be to always decode with the same batch size (or a few pre-determined values, e.g. powers of 2) to avoid costly shuffling of the data on every filter; this should be OK since the (Santacoder) decode latency is mostly independent of the batch size anyway.

Can you show me where you implemented the CUDA graphs with dynamic shapes for SantaCoder? I wonder how it is implemented.

jlamypoirier commented 1 year ago

> Can you show me where you implemented the CUDA graphs with dynamic shapes for SantaCoder? I wonder how it is implemented.

Sorry for the late response; you can find my (messy) implementation at https://github.com/bigcode-project/transformers/blob/main/src/transformers/models/gpt_bigcode/inference_runner.py. Note that this version supports dynamic key lengths but not dynamic batch sizes.

aliswel-mt commented 1 year ago

@jlamypoirier Amazing reports! May I ask whether the sequence length indicates max_new_tokens? I got pretty high latency (about 4 s) for Starcoder when I set max_new_tokens=128.

jlamypoirier commented 1 year ago

> Amazing reports! May I ask whether the sequence length indicates max_new_tokens? I got pretty high latency (about 4 s) for Starcoder when I set max_new_tokens=128.

It's the time to generate one token. For the full generation time you need to add the prefill latency for the context length plus the decode latency for each step in range(context_length, context_length + max_new_tokens).
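
In other words, something like the following hypothetical helper, where prefill_latency(n) and decode_latency(n) stand for the measured latencies at key length n (e.g. read off the plots above):

```python
def total_generation_time(context_length: int, max_new_tokens: int,
                          prefill_latency, decode_latency) -> float:
    """Total request latency = one prefill over the context plus one decode
    step per new token, each at a growing key length."""
    total = prefill_latency(context_length)
    for key_length in range(context_length, context_length + max_new_tokens):
        total += decode_latency(key_length)
    return total
```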

truenorth8 commented 10 months ago

@jlamypoirier These are great suggestions. Have any of these found their way upstream? If not, is your version available anywhere?

edit: especially curious about

> Compute the model head only for the last token in prefill (unless details are requested). This saves some time and more importantly avoids a memory bottleneck.

github-actions[bot] commented 1 week ago

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.