vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: chunked prefill scheduler uses up swap on many n>=2 requests #5578

Open toslunar opened 1 month ago

toslunar commented 1 month ago

Your current environment

The output of `python collect_env.py`

🐛 Describe the bug

Sending many requests with n>=2 (or best_of>=2) fills up the CPU KV cache; this happens more often when chunked prefill is enabled.

_schedule_chunked_prefill schedules new prefills even when there are swapped sequence groups (https://github.com/vllm-project/vllm/blob/v0.5.0.post1/vllm/core/scheduler.py#L871-L873), while _schedule_default does not (https://github.com/vllm-project/vllm/blob/v0.5.0.post1/vllm/core/scheduler.py#L763-L766).
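For context, the two paths differ roughly as sketched below (a paraphrase of the linked lines in v0.5.0.post1, not a verbatim copy):

```python
# _schedule_default: new prefills are scheduled only when nothing is swapped out
# (roughly what the linked L763-L766 does).
if not self.swapped:
    remaining_waiting, prefills = self._schedule_prefills(
        self.waiting, budget, curr_loras, enable_chunking=False)

# _schedule_chunked_prefill: new prefills are scheduled unconditionally,
# even while swapped sequence groups are waiting (the linked L871-L873).
remaining_waiting, prefills = self._schedule_prefills(
    self.waiting, budget, curr_loras, enable_chunking=True)
```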

To reproduce,

import vllm
print(vllm.__version__)
from vllm import LLM, SamplingParams

long_text = open(vllm.core.scheduler.__file__).read()
prompts = [f"```python\n" + long_text[i:i+1000] for i in range(10000)]

llm = LLM(
    model="facebook/opt-125m",
    enable_chunked_prefill=True,
    disable_log_stats=False,
    max_num_batched_tokens=4096,
    num_gpu_blocks_override=8192,
)

sampling_params = SamplingParams(max_tokens=1000, n=8)
llm.generate(prompts, sampling_params)

The run consumes the CPU KV cache (Running: 39 reqs, Swapped: 129 reqs at the end) and eventually aborts for lack of CPU swap space:

output

```
0.5.0.post1
/home/kataoka/venv1/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
INFO 06-16 21:23:26 config.py:707] Chunked prefill is enabled (EXPERIMENTAL).
INFO 06-16 21:23:26 llm_engine.py:161] Initializing an LLM engine (v0.5.0.post1) with config: model='facebook/opt-125m', speculative_config=None, tokenizer='facebook/opt-125m', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0, served_model_name=facebook/opt-125m)
INFO 06-16 21:23:31 weight_utils.py:218] Using model weights format ['*.bin']
INFO 06-16 21:23:32 model_runner.py:160] Loading model weights took 0.2389 GB
INFO 06-16 21:23:32 llm_engine.py:317] Overriding num_gpu_blocks=127899 with num_gpu_blocks_override=8192
INFO 06-16 21:23:32 gpu_executor.py:83] # GPU blocks: 8192, # CPU blocks: 7281
INFO 06-16 21:23:35 model_runner.py:889] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 06-16 21:23:35 model_runner.py:893] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 06-16 21:23:39 model_runner.py:965] Graph capturing finished in 4 secs.
Processed prompts: 0%| | 0/10000 [00:00
[rank0]:     llm.generate(prompts, sampling_params)
[rank0]:   File "/home/kataoka/venv1/lib/python3.10/site-packages/vllm/utils.py", line 691, in inner
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home/kataoka/venv1/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 304, in generate
[rank0]:     outputs = self._run_engine(use_tqdm=use_tqdm)
[rank0]:   File "/home/kataoka/venv1/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 556, in _run_engine
[rank0]:     step_outputs = self.llm_engine.step()
[rank0]:   File "/home/kataoka/venv1/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 765, in step
[rank0]:     seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
[rank0]:   File "/home/kataoka/venv1/lib/python3.10/site-packages/vllm/core/scheduler.py", line 948, in schedule
[rank0]:     scheduler_outputs = self._schedule()
[rank0]:   File "/home/kataoka/venv1/lib/python3.10/site-packages/vllm/core/scheduler.py", line 921, in _schedule
[rank0]:     return self._schedule_chunked_prefill()
[rank0]:   File "/home/kataoka/venv1/lib/python3.10/site-packages/vllm/core/scheduler.py", line 857, in _schedule_chunked_prefill
[rank0]:     remaining_running, running_scheduled = self._schedule_running(
[rank0]:   File "/home/kataoka/venv1/lib/python3.10/site-packages/vllm/core/scheduler.py", line 434, in _schedule_running
[rank0]:     preempted_mode = self._preempt(victim_seq_group,
[rank0]:   File "/home/kataoka/venv1/lib/python3.10/site-packages/vllm/core/scheduler.py", line 1101, in _preempt
[rank0]:     self._preempt_by_swap(seq_group, blocks_to_swap_out)
[rank0]:   File "/home/kataoka/venv1/lib/python3.10/site-packages/vllm/core/scheduler.py", line 1122, in _preempt_by_swap
[rank0]:     self._swap_out(seq_group, blocks_to_swap_out)
[rank0]:   File "/home/kataoka/venv1/lib/python3.10/site-packages/vllm/core/scheduler.py", line 1142, in _swap_out
[rank0]:     raise RuntimeError(
[rank0]: RuntimeError: Aborted due to the lack of CPU swap space. Please increase the swap space to avoid this error.
Processed prompts: 1%| | 63/10000 [00:47<2:05:35, 1.32it/s, est. speed input: 480.87 toks/s, output: 8954.46 toks/s]
```
simon-mo commented 1 month ago

@rkooo567 any possible causes?

toslunar commented 1 month ago

To make my suggestion clear,

```diff
-        # Schedule new prefills.
-        remaining_waiting, prefills = self._schedule_prefills(
-            self.waiting, budget, curr_loras, enable_chunking=True)
+        if len(remaining_swapped) == 0:
+            # Schedule new prefills.
+            remaining_waiting, prefills = self._schedule_prefills(
+                self.waiting, budget, curr_loras, enable_chunking=True)
```

applied at https://github.com/vllm-project/vllm/blob/v0.5.0.post1/vllm/core/scheduler.py#L871-L873 fixes the issue.

However, the condition `len(remaining_swapped) == 0` looks too strict and may hurt performance when most of the requests have n == best_of == 1. Something like "CPU KV cache usage < 50%" could be better.
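A rough sketch of such a gate (hypothetical, not an actual patch; it assumes the scheduler can read `self.cache_config.num_cpu_blocks` and `self.block_manager.get_num_free_cpu_blocks()`, the counters the engine already uses for its stats logging):

```python
# Hypothetical alternative to `len(remaining_swapped) == 0`: allow new prefills
# only while CPU KV cache (swap) usage stays below a threshold.
CPU_KV_USAGE_LIMIT = 0.5  # "CPU KV cache usage < 50%"

num_cpu_blocks = self.cache_config.num_cpu_blocks or 0
if num_cpu_blocks > 0:
    free_cpu_blocks = self.block_manager.get_num_free_cpu_blocks()
    cpu_kv_usage = 1.0 - free_cpu_blocks / num_cpu_blocks
else:
    cpu_kv_usage = 0.0  # no CPU swap space configured

if cpu_kv_usage < CPU_KV_USAGE_LIMIT:
    # Schedule new prefills only while swap pressure is low.
    remaining_waiting, prefills = self._schedule_prefills(
        self.waiting, budget, curr_loras, enable_chunking=True)
```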

rkooo567 commented 1 month ago

I think n>1 creates more sequences, so it is more likely to trigger swapping/preemption (there is higher pressure on the KV cache). Checking `remaining_swapped == 0` actually makes sense to me. We should prioritize swapped requests over prefills anyway (and once all swapped requests are scheduled, remaining_swapped becomes 0 anyway). @toslunar would you like to create a PR?
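For a back-of-envelope sense of that pressure in the repro above (the block size and prompt length here are assumed values, not taken from the report):

```python
# Rough per-request KV-cache demand in the repro, under assumed numbers.
block_size = 16        # assumed KV block size in tokens (vLLM default)
gpu_blocks = 8192      # num_gpu_blocks_override from the repro
gpu_token_slots = gpu_blocks * block_size  # 131,072 GPU KV token slots

prompt_tokens = 300    # rough guess for a 1000-character slice of scheduler.py
max_tokens = 1000      # from SamplingParams
n = 8                  # sequences generated per request

tokens_per_request = n * (prompt_tokens + max_tokens)  # ~10,400 slots when fully grown
print(gpu_token_slots // tokens_per_request)           # only ~12 such requests fit at once
```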

toslunar commented 1 month ago

Thank you @rkooo567. It makes sense.

I created a PR. The diff is slightly different from the one in my previous comment.