@rkooo567 any possible causes?
To make my suggestion clear, applying the following change at https://github.com/vllm-project/vllm/blob/v0.5.0.post1/vllm/core/scheduler.py#L871-L873 fixes the issue:
-        # Schedule new prefills.
-        remaining_waiting, prefills = self._schedule_prefills(
-            self.waiting, budget, curr_loras, enable_chunking=True)
+        if len(remaining_swapped) == 0:
+            # Schedule new prefills.
+            remaining_waiting, prefills = self._schedule_prefills(
+                self.waiting, budget, curr_loras, enable_chunking=True)
However, the condition `if len(remaining_swapped) == 0` looks too strict and may hurt performance when most of the requests have `n == best_of == 1`. Something like "CPU KV cache usage < 50%" could be better.
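For comparison, here is a minimal sketch of the two gating conditions mentioned above, written as standalone predicates. The function names, parameters, and the block counts in the usage lines are hypothetical and only illustrate the idea; they are not part of vLLM's scheduler API. The total of 7281 CPU blocks is taken from the log in this issue.

```python
def prefill_allowed_strict(num_remaining_swapped: int) -> bool:
    """Strict gate: allow new prefills only if no swapped seq group remains."""
    return num_remaining_swapped == 0


def prefill_allowed_by_cpu_usage(num_used_cpu_blocks: int,
                                 num_total_cpu_blocks: int,
                                 threshold: float = 0.5) -> bool:
    """Relaxed gate: allow new prefills while CPU KV cache usage is below threshold.

    All arguments are hypothetical inputs; a real scheduler would have to
    obtain the block counts from its block manager.
    """
    if num_total_cpu_blocks <= 0:
        raise ValueError("num_total_cpu_blocks must be positive")
    usage = num_used_cpu_blocks / num_total_cpu_blocks
    return usage < threshold


# Usage with numbers loosely based on the log (7281 CPU blocks in total).
TOTAL_CPU_BLOCKS = 7281
light_load = prefill_allowed_by_cpu_usage(1871, TOTAL_CPU_BLOCKS)  # about 25.7% used
heavy_load = prefill_allowed_by_cpu_usage(6626, TOTAL_CPU_BLOCKS)  # about 91.0% used
strict_with_backlog = prefill_allowed_strict(29)  # 29 swapped requests remain
```

Under these assumptions, the relaxed gate still admits prefills at light CPU cache usage, where the strict gate would block them as soon as a single request is swapped out.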
I think n>1 creates more sequences, so it is more likely to trigger swap/preemption (because there is higher pressure on the KV cache). Checking remaining_swapped==0 actually makes sense to me. We should prioritize swapped requests over prefills anyway (and if all swapped requests are scheduled, the remaining swapped list becomes empty anyway). @toslunar would you like to create a PR?
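The prioritization argument can be illustrated with a toy model. This is a sketch under strong assumptions, not vLLM's scheduler: requests are hypothetical `(name, blocks_needed)` tuples, GPU memory is a single free-block counter, and `schedule_step` is an invented helper. It shows how an ungated prefill can take the blocks that a still-swapped request is waiting for.

```python
from collections import deque


def schedule_step(swapped, waiting, free_blocks, gate_prefill):
    """Run one toy scheduling step and return (swapped_in, prefills, free_blocks).

    swapped and waiting are deques of (name, blocks_needed), served in FCFS order.
    If gate_prefill is True, prefills run only when no swapped request remains.
    """
    swapped_in, prefills = [], []
    # Swap in requests from the head of the queue while they fit.
    while swapped and swapped[0][1] <= free_blocks:
        name, blocks = swapped.popleft()
        free_blocks -= blocks
        swapped_in.append(name)
    # Schedule new prefills, optionally gated on an empty swapped queue.
    if not gate_prefill or not swapped:
        while waiting and waiting[0][1] <= free_blocks:
            name, blocks = waiting.popleft()
            free_blocks -= blocks
            prefills.append(name)
    return swapped_in, prefills, free_blocks


def make_queues():
    """Build identical toy queues for both runs."""
    return deque([("s1", 4), ("s2", 8)]), deque([("w1", 2)])


ungated = schedule_step(*make_queues(), free_blocks=6, gate_prefill=False)
gated = schedule_step(*make_queues(), free_blocks=6, gate_prefill=True)
print(ungated)  # (['s1'], ['w1'], 0)
print(gated)    # (['s1'], [], 2)
```

In the ungated run, the new prefill `w1` consumes the last free blocks while `s2` is still swapped out. In the gated run, those blocks stay free, so `s2` can be swapped in once running requests release enough memory.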
Thank you @rkooo567. It makes sense.
I created a PR. The diff is slightly different than my previous comment.
This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!
This issue has been automatically closed due to inactivity. Please feel free to reopen if you feel it is still relevant. Thank you!
Your current environment
🐛 Describe the bug
Sending many `n>=2` (or `best_of>=2`) requests fills up the CPU KV cache, more often if chunked prefill is enabled.
`_schedule_chunked_prefill` schedules prefills even if there are swapped seq groups (https://github.com/vllm-project/vllm/blob/v0.5.0.post1/vllm/core/scheduler.py#L871-L873), while `_schedule_default` does not (https://github.com/vllm-project/vllm/blob/v0.5.0.post1/vllm/core/scheduler.py#L763-L766).
To reproduce,
consumes CPU KV cache (`Running: 39 reqs, Swapped: 129 reqs` in the end)
output
```
0.5.0.post1
/home/kataoka/venv1/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
INFO 06-16 21:23:26 config.py:707] Chunked prefill is enabled (EXPERIMENTAL).
INFO 06-16 21:23:26 llm_engine.py:161] Initializing an LLM engine (v0.5.0.post1) with config: model='facebook/opt-125m', speculative_config=None, tokenizer='facebook/opt-125m', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0, served_model_name=facebook/opt-125m)
INFO 06-16 21:23:31 weight_utils.py:218] Using model weights format ['*.bin']
INFO 06-16 21:23:32 model_runner.py:160] Loading model weights took 0.2389 GB
INFO 06-16 21:23:32 llm_engine.py:317] Overriding num_gpu_blocks=127899 with num_gpu_blocks_override=8192
INFO 06-16 21:23:32 gpu_executor.py:83] # GPU blocks: 8192, # CPU blocks: 7281
INFO 06-16 21:23:35 model_runner.py:889] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 06-16 21:23:35 model_runner.py:893] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 06-16 21:23:39 model_runner.py:965] Graph capturing finished in 4 secs.
Processed prompts: 0%| | 0/10000 [00:00, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
INFO 06-16 21:23:45 metrics.py:341] Avg prompt throughput: 707.6 tokens/s, Avg generation throughput: 15.2 tokens/s, Running: 12 reqs, Swapped: 0 reqs, Pending: 9988 reqs, GPU KV cache usage: 3.4%, CPU KV cache usage: 0.0%.
INFO 06-16 21:23:50 metrics.py:341] Avg prompt throughput: 1640.1 tokens/s, Avg generation throughput: 15576.1 tokens/s, Running: 35 reqs, Swapped: 0 reqs, Pending: 9965 reqs, GPU KV cache usage: 68.1%, CPU KV cache usage: 0.0%.
WARNING 06-16 21:23:53 scheduler.py:1089] Sequence group 37 is preempted by PreemptionMode.SWAP mode because there is not enough KV cache space. This can affect the end-to-end performance. Increase gpu_memory_utilization or tensor_parallel_size to provide more KV cache memory. total_num_cumulative_preemption=1
INFO 06-16 21:23:55 metrics.py:341] Avg prompt throughput: 1913.5 tokens/s, Avg generation throughput: 15094.3 tokens/s, Running: 33 reqs, Swapped: 29 reqs, Pending: 9938 reqs, GPU KV cache usage: 99.6%, CPU KV cache usage: 25.7%.
WARNING 06-16 21:23:57 scheduler.py:1089] Sequence group 75 is preempted by PreemptionMode.SWAP mode because there is not enough KV cache space. This can affect the end-to-end performance. Increase gpu_memory_utilization or tensor_parallel_size to provide more KV cache memory. total_num_cumulative_preemption=51
INFO 06-16 21:24:00 metrics.py:341] Avg prompt throughput: 3919.1 tokens/s, Avg generation throughput: 9969.3 tokens/s, Running: 30 reqs, Swapped: 87 reqs, Pending: 9883 reqs, GPU KV cache usage: 99.5%, CPU KV cache usage: 75.7%.
WARNING 06-16 21:24:01 scheduler.py:1089] Sequence group 123 is preempted by PreemptionMode.SWAP mode because there is not enough KV cache space. This can affect the end-to-end performance. Increase gpu_memory_utilization or tensor_parallel_size to provide more KV cache memory. total_num_cumulative_preemption=101
Processed prompts: 0%| | 24/10000 [00:18<1:18:53, 2.11it/s, est. speed input: 463.46 toks/s, output: 8753.30 toks/s]
INFO 06-16 21:24:05 metrics.py:341] Avg prompt throughput: 929.1 tokens/s, Avg generation throughput: 11742.2 tokens/s, Running: 36 reqs, Swapped: 69 reqs, Pending: 9870 reqs, GPU KV cache usage: 68.0%, CPU KV cache usage: 36.8%.
Processed prompts: 0%| | 32/10000 [00:24<1:34:23, 1.76it/s, est. speed input: 465.61 toks/s, output: 8705.86 toks/s]
INFO 06-16 21:24:10 metrics.py:341] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 14147.8 tokens/s, Running: 37 reqs, Swapped: 61 reqs, Pending: 9870 reqs, GPU KV cache usage: 88.3%, CPU KV cache usage: 33.1%.
Processed prompts: 0%| | 34/10000 [00:29<3:36:54, 1.31s/it, est. speed input: 407.31 toks/s, output: 7561.83 toks/s]
INFO 06-16 21:24:15 metrics.py:341] Avg prompt throughput: 2706.1 tokens/s, Avg generation throughput: 13395.4 tokens/s, Running: 34 reqs, Swapped: 100 reqs, Pending: 9832 reqs, GPU KV cache usage: 99.2%, CPU KV cache usage: 70.4%.
WARNING 06-16 21:24:15 scheduler.py:1089] Sequence group 165 is preempted by PreemptionMode.SWAP mode because there is not enough KV cache space. This can affect the end-to-end performance. Increase gpu_memory_utilization or tensor_parallel_size to provide more KV cache memory. total_num_cumulative_preemption=151
Processed prompts: 0%| | 47/10000 [00:35<42:17, 3.92it/s, est. speed input: 473.45 toks/s, output: 8904.44 toks/s]
INFO 06-16 21:24:20 metrics.py:341] Avg prompt throughput: 3737.0 tokens/s, Avg generation throughput: 8906.6 tokens/s, Running: 37 reqs, Swapped: 137 reqs, Pending: 9779 reqs, GPU KV cache usage: 76.6%, CPU KV cache usage: 77.9%.
Processed prompts: 1%| | 58/10000 [00:36<31:11, 5.31it/s, est. speed input: 562.67 toks/s, output: 10551.00 toks/s]
INFO 06-16 21:24:25 metrics.py:341] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 14857.3 tokens/s, Running: 36 reqs, Swapped: 125 reqs, Pending: 9779 reqs, GPU KV cache usage: 76.2%, CPU KV cache usage: 71.7%.
Processed prompts: 1%| | 61/10000 [00:43<2:07:09, 1.30it/s, est. speed input: 497.61 toks/s, output: 9320.60 toks/s]
WARNING 06-16 21:24:29 scheduler.py:1089] Sequence group 222 is preempted by PreemptionMode.SWAP mode because there is not enough KV cache space. This can affect the end-to-end performance. Increase gpu_memory_utilization or tensor_parallel_size to provide more KV cache memory. total_num_cumulative_preemption=201
INFO 06-16 21:24:30 metrics.py:341] Avg prompt throughput: 627.9 tokens/s, Avg generation throughput: 14683.6 tokens/s, Running: 39 reqs, Swapped: 129 reqs, Pending: 9770 reqs, GPU KV cache usage: 99.3%, CPU KV cache usage: 91.0%.
Processed prompts: 1%| | 63/10000 [00:46<2:32:04, 1.09it/s, est. speed input: 480.87 toks/s, output: 8954.46 toks/s]
[rank0]: Traceback (most recent call last):
[rank0]: File "/home/kataoka/Untitled3.py", line 36, in