vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: Available context (gpu blocks) gets halved by pipeline parallel size #7039

Open charai-frontend opened 1 month ago

charai-frontend commented 1 month ago

Your current environment

vLLM main branch commit c8a7e932

šŸ› Describe the bug

When using --pipeline_parallel_size=2, vLLM throws an error if the prompt uses more than half of the available tokens.

vLLM reports a capacity of 1650 blocks / 26.4k tokens when loading the model (--max_model_len was set to 24000):

INFO 08-01 15:24:16 distributed_gpu_executor.py:56] # GPU blocks: 1650, # CPU blocks: 0

But when sending any prompt with >13k tokens, it throws an "input prompt is too long" error:

WARNING 08-01 15:24:58 scheduler.py:706] Input prompt (512 tokens) is too long and exceeds the capacity of block_manager
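
For context, assuming the default block size of 16 tokens (the block size is not stated in this report), the block counts convert to token capacities as follows:

# Back-of-the-envelope conversion, assuming --block-size 16 (the default),
# which is an assumption not stated in the report above.
block_size = 16

reported_blocks = 1650                   # "# GPU blocks: 1650" at startup
print(reported_blocks * block_size)      # 26400 -> the "26.4k tokens" figure

effective_blocks = reported_blocks // 2  # the halved value seen later (825)
print(effective_blocks * block_size)     # 13200 -> why >13k-token prompts fail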

Adding these logging statements:

diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py
index e29eba37..3bf57230 100644
--- a/vllm/core/block_manager_v1.py
+++ b/vllm/core/block_manager_v1.py
@@ -224,6 +224,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
     ) -> None:
         self.block_size = block_size
         self.num_total_gpu_blocks = num_gpu_blocks
+        logger.info(f"self.num_total_gpu_blocks = {num_gpu_blocks}")
         self.num_total_cpu_blocks = num_cpu_blocks

         if enable_caching and sliding_window is not None:
@@ -286,6 +287,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         # Use watermark to avoid frequent cache eviction.
         if (self.num_total_gpu_blocks - num_required_blocks <
                 self.watermark_blocks):
+            logger.info(f"total_gpu_blocks({self.num_total_gpu_blocks}) - num_required_blocks({num_required_blocks}) < watermark_blocks({self.watermark_blocks})")
             return AllocStatus.NEVER
         if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
             return AllocStatus.OK

This shows that the check is done against 825 blocks (13.2k tokens), half of the capacity reported at startup, and a request with a ~21.8k-token prompt fails:

INFO 08-01 15:36:54 distributed_gpu_executor.py:56] # GPU blocks: 1650, # CPU blocks: 0
[...]
INFO 08-01 15:37:31 model_runner.py:1219] Graph capturing finished in 37 secs.
INFO 08-01 15:37:31 block_manager_v1.py:227] self.num_total_gpu_blocks = 825
INFO 08-01 15:37:31 block_manager_v1.py:227] self.num_total_gpu_blocks = 825
[...]
INFO:     127.0.0.1:38182 - "POST /v1/chat/completions HTTP/1.1" 200 OK
INFO 08-01 15:38:17 block_manager_v1.py:290] total_gpu_blocks(825) - num_required_blocks(1363) < watermark_blocks(8)
WARNING 08-01 15:38:17 scheduler.py:706] Input prompt (512 tokens) is too long and exceeds the capacity of block_manager
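
Plugging the logged numbers into that check, a minimal sketch (again assuming a 16-token block size) reproduces the rejection:

# Minimal sketch of the can_allocate check using the logged values.
# Block size 16 is an assumption; the other numbers come from the log.
import math

block_size = 16
num_total_gpu_blocks = 825    # per-block-manager value from the log
watermark_blocks = 8
prompt_tokens = 21800         # the ~21.8k-token prompt

num_required_blocks = math.ceil(prompt_tokens / block_size)   # 1363
if num_total_gpu_blocks - num_required_blocks < watermark_blocks:
    print("AllocStatus.NEVER -> 'Input prompt ... is too long'")  # this branch is taken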

youkaichao commented 1 month ago

cc @andoorve

andoorve commented 1 month ago

This is kind of expected behavior based on what our implementation of PP aims to do. We report 1650 blocks because that is the total number of blocks available on your GPU. However, this total gets divided into 2 KV cache sections so that multiple request streams can be served at the same time; this is what allows pipelining of request streams with load balancing between the two. The reporting could possibly be improved to reflect this.
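
Roughly, the split being described looks like the sketch below (illustrative only, not the actual engine code; it assumes the reported total is divided evenly across one block manager per request stream):

# Illustrative sketch, not the actual vLLM code: the block count reported
# at startup is shared by one block manager per pipeline-parallel stream.
total_gpu_blocks = 1650
pipeline_parallel_size = 2

blocks_per_stream = total_gpu_blocks // pipeline_parallel_size
print(blocks_per_stream)   # 825 -> matches self.num_total_gpu_blocks in the log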

The situation you are talking about (a very long prompt) is possible with some changes. You can submit a feature request for it if there's a very clear use case. However, it hasn't been a main focus until now, since basically only one of those very long prompts could be resident in the cache at a time. That would mean essentially no pipelining, and you might want to see if tensor parallelism serves your use case better.
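
For the single-very-long-prompt case, a hedged example of what trying tensor parallelism instead could look like with the offline API (placeholder model name; sizes copied from this report):

# Illustrative only: with tensor parallelism the KV cache is not split
# per request stream, so a single long prompt can use the full capacity.
from vllm import LLM

llm = LLM(
    model="your-model",        # placeholder
    tensor_parallel_size=2,    # instead of pipeline_parallel_size=2
    max_model_len=24000,
)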