NVIDIA / TensorRT-LLM

TensorRT-LLM provides users with an easy-to-use Python API to define Large Language Models (LLMs) and build TensorRT engines that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. TensorRT-LLM also contains components to create Python and C++ runtimes that execute those TensorRT engines.
https://nvidia.github.io/TensorRT-LLM
Apache License 2.0

KV cache re-use impact on average sequence latency? #2456

Open mkserge opened 3 days ago

mkserge commented 3 days ago

Hello,

I am benchmarking KV cache re-use with the Mistral 7B model using tensor parallelism across 8 A100 GPUs (A100-SXM4-40GB).

My instruction prompt is fixed at 1214 tokens, and the maximum sequence length is 1357 tokens (input + output).

From the graph, throughput at a given latency threshold increases significantly, which makes sense, but I am a bit surprised by the much smaller gain in latency at low request rates. For example, at request_rate = 1, average sequence latency only goes from 115.35 ms down to 94.56 ms when re-using the KV cache. Isn't this gain small, considering that a very large chunk of the input prompt is cached?

Results: (attached graph of throughput vs. average sequence latency, with and without KV cache re-use)
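
For concreteness, the quoted figures can be checked with a quick back-of-the-envelope in plain Python (the latencies are the averages quoted above; the output-token budget follows from the prompt length and max sequence length):

  # Quick arithmetic on the figures quoted above.
  prompt_tokens = 1214                              # fixed instruction prompt
  max_seq_len = 1357                                # input + output
  max_output_tokens = max_seq_len - prompt_tokens   # at most 143 generated tokens

  latency_no_reuse_ms = 115.35                      # avg sequence latency, KV cache re-use off
  latency_reuse_ms = 94.56                          # avg sequence latency, KV cache re-use on

  saving_ms = latency_no_reuse_ms - latency_reuse_ms
  saving_pct = 100 * saving_ms / latency_no_reuse_ms

  print(f"max output tokens per request: {max_output_tokens}")      # 143
  print(f"latency saving: {saving_ms:.2f} ms ({saving_pct:.1f}%)")  # 20.79 ms (18.0%)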

For reference, I build the model with

trtllm-build \
  --checkpoint_dir /mistral-models/mistral-7b-instruct-v0.3-hf-trt-bf16-tp8 \
  --output_dir /mistral-models/mistral-7b-instruct-v0.3-hf-trt-engines/bf16/8-gpu-kv-cache/ \
  --gemm_plugin auto \
  --tokens_per_block 128 \
  --use_paged_context_fmha enable \
  --multiple_profiles enable \
  --max_seq_len 1357

and benchmark it using

  mpirun -n 8 --allow-run-as-root ./gptManagerBenchmark \
    --engine_dir /mistral-models/mistral-7b-instruct-v0.3-hf-trt-engines/bf16/8-gpu-kv-cache/ \
    --request_rate 1 \
    --enable_kv_cache_reuse enable \
    --dataset /mistral-runs/benchmarks/data/tokens/dataset_tokens.jsonl \
    --output_csv ./request_rate_tp_8_kv_cache_reuse_1.csv \
    --max_num_samples 100;

ttim commented 16 hours ago

@mkserge At such low batch sizes, the time needed for prefill is very small, hence you don't see a big difference in performance. Prefill is usually compute bound, and for your setup (8x A100) the theoretical peak is around 2.5 PFLOPS. The FLOPs needed for a ~1200-token prefill of a 7B model are around 8.5 TFLOPs, meaning the theoretical time for a batch-size-1 prefill is around 3.5 ms. In reality it is more than this, but it gives a good idea of the magnitude. So until your prefill token count becomes significant, you won't see a big benefit from caching the prefill.
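
A minimal sketch of that back-of-the-envelope in plain Python. It follows the convention used above (parameter count × prompt tokens for the FLOP count, and ~2.5 PFLOPS as the assumed aggregate BF16 tensor-core peak of 8x A100-SXM4); the more common 2 × params × tokens count would roughly double the figures without changing the conclusion:

  # Order-of-magnitude estimate of batch-size-1 prefill time on 8x A100-SXM4-40GB.
  n_params = 7e9        # Mistral 7B parameter count (approximate)
  prompt_tokens = 1214  # fixed instruction prompt from the benchmark above
  peak_flops = 2.5e15   # assumed aggregate BF16 tensor-core peak (~8 x 312 TFLOPS)

  # FLOPs per prefill, using params * tokens as in the comment above;
  # the common 2 * params * tokens convention gives roughly twice this.
  prefill_flops = n_params * prompt_tokens                  # ~8.5e12

  ideal_prefill_ms = prefill_flops / peak_flops * 1e3
  print(f"prefill FLOPs:      {prefill_flops:.2e}")         # ~8.50e+12
  print(f"ideal prefill time: {ideal_prefill_ms:.1f} ms")   # ~3.4 ms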