vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: AssertionError when using automatic prefix caching and prompt_logprobs #8268

Open · novoselrok opened this issue 2 months ago

novoselrok commented 2 months ago

Your current environment

The output of `python collect_env.py` ```text Collecting environment information... PyTorch version: 2.4.0+cu121 Is debug build: False CUDA used to build PyTorch: 12.1 ROCM used to build PyTorch: N/A OS: Debian GNU/Linux 11 (bullseye) (x86_64) GCC version: (Debian 10.2.1-6) 10.2.1 20210110 Clang version: Could not collect CMake version: version 3.30.2 Libc version: glibc-2.31 Python version: 3.9.19 | packaged by conda-forge | (main, Mar 20 2024, 12:50:21) [GCC 12.3.0] (64-bit runtime) Python platform: Linux-5.10.0-30-cloud-amd64-x86_64-with-glibc2.31 Is CUDA available: True CUDA runtime version: 11.8.89 CUDA_MODULE_LOADING set to: LAZY GPU models and configuration: GPU 0: NVIDIA A100-SXM4-40GB GPU 1: NVIDIA A100-SXM4-40GB GPU 2: NVIDIA A100-SXM4-40GB GPU 3: NVIDIA A100-SXM4-40GB GPU 4: NVIDIA A100-SXM4-40GB GPU 5: NVIDIA A100-SXM4-40GB GPU 6: NVIDIA A100-SXM4-40GB GPU 7: NVIDIA A100-SXM4-40GB Nvidia driver version: 525.105.17 cuDNN version: Probably one of the following: /usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn.so.8.9.0 /usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.9.0 /usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.9.0 /usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.9.0 /usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.9.0 /usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.9.0 /usr/local/cuda-11.8/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.9.0 HIP runtime version: N/A MIOpen runtime version: N/A Is XNNPACK available: True CPU: Architecture: x86_64 CPU op-mode(s): 32-bit, 64-bit Byte Order: Little Endian Address sizes: 46 bits physical, 48 bits virtual CPU(s): 96 On-line CPU(s) list: 0-95 Thread(s) per core: 2 Core(s) per socket: 24 Socket(s): 2 NUMA node(s): 2 Vendor ID: GenuineIntel CPU family: 6 Model: 85 Model name: Intel(R) Xeon(R) CPU @ 2.20GHz Stepping: 7 CPU MHz: 2200.226 BogoMIPS: 4400.45 Hypervisor vendor: KVM Virtualization type: full L1d cache: 1.5 MiB L1i cache: 1.5 MiB L2 cache: 48 MiB L3 cache: 77 MiB NUMA node0 CPU(s): 0-23,48-71 NUMA node1 CPU(s): 24-47,72-95 Vulnerability Gather data sampling: Not affected Vulnerability Itlb multihit: Not affected Vulnerability L1tf: Not affected Vulnerability Mds: Not affected Vulnerability Meltdown: Not affected Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown Vulnerability Reg file data sampling: Not affected Vulnerability Retbleed: Mitigation; Enhanced IBRS Vulnerability Spec rstack overflow: Not affected Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence Vulnerability Srbds: Not affected Vulnerability Tsx async abort: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervi Versions of relevant libraries: [pip3] numpy==1.26.4 [pip3] nvidia-cublas-cu12==12.1.3.1 [pip3] nvidia-cuda-cupti-cu12==12.1.105 [pip3] nvidia-cuda-nvrtc-cu12==12.1.105 [pip3] 
nvidia-cuda-runtime-cu12==12.1.105 [pip3] nvidia-cudnn-cu12==9.1.0.70 [pip3] nvidia-cufft-cu12==11.0.2.54 [pip3] nvidia-curand-cu12==10.3.2.106 [pip3] nvidia-cusolver-cu12==11.4.5.107 [pip3] nvidia-cusparse-cu12==12.1.0.106 [pip3] nvidia-ml-py==12.555.43 [pip3] nvidia-nccl-cu12==2.20.5 [pip3] nvidia-nvjitlink-cu12==12.5.40 [pip3] nvidia-nvtx-cu12==12.1.105 [pip3] onnxruntime==1.18.1 [pip3] pyzmq==26.0.3 [pip3] sentence-transformers==3.0.1 [pip3] torch==2.4.0 [pip3] torchao==0.1 [pip3] torchtune==0.2.0.dev20240625+cpu [pip3] torchvision==0.19.0 [pip3] transformers==4.43.4 [pip3] triton==3.0.0 [conda] numpy 1.26.4 pypi_0 pypi [conda] nvidia-cublas-cu12 12.1.3.1 pypi_0 pypi [conda] nvidia-cuda-cupti-cu12 12.1.105 pypi_0 pypi [conda] nvidia-cuda-nvrtc-cu12 12.1.105 pypi_0 pypi [conda] nvidia-cuda-runtime-cu12 12.1.105 pypi_0 pypi [conda] nvidia-cudnn-cu12 9.1.0.70 pypi_0 pypi [conda] nvidia-cufft-cu12 11.0.2.54 pypi_0 pypi [conda] nvidia-curand-cu12 10.3.2.106 pypi_0 pypi [conda] nvidia-cusolver-cu12 11.4.5.107 pypi_0 pypi [conda] nvidia-cusparse-cu12 12.1.0.106 pypi_0 pypi [conda] nvidia-ml-py 12.555.43 pypi_0 pypi [conda] nvidia-nccl-cu12 2.20.5 pypi_0 pypi [conda] nvidia-nvjitlink-cu12 12.5.40 pypi_0 pypi [conda] nvidia-nvtx-cu12 12.1.105 pypi_0 pypi [conda] pyzmq 26.0.3 pypi_0 pypi [conda] sentence-transformers 3.0.1 pypi_0 pypi [conda] torch 2.4.0 pypi_0 pypi [conda] torchao 0.1 pypi_0 pypi [conda] torchtune 0.2.0.dev20240625+cpu pypi_0 pypi [conda] torchvision 0.19.0 pypi_0 pypi [conda] transformers 4.43.4 pypi_0 pypi [conda] triton 3.0.0 pypi_0 pypi ROCM Version: Could not collect Neuron SDK Version: N/A vLLM Version: 0.6.0@32e7db25365415841ebc7c4215851743fbb1bad1 vLLM Build Flags: CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled GPU Topology: GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 CPU Affinity NUMA Affinity GPU0 X NV12 NV12 NV12 NV12 NV12 NV12 NV12 0-23,48-71 0 GPU1 NV12 X NV12 NV12 NV12 NV12 NV12 NV12 0-23,48-71 0 GPU2 NV12 NV12 X NV12 NV12 NV12 NV12 NV12 0-23,48-71 0 GPU3 NV12 NV12 NV12 X NV12 NV12 NV12 NV12 0-23,48-71 0 GPU4 NV12 NV12 NV12 NV12 X NV12 NV12 NV12 24-47,72-95 1 GPU5 NV12 NV12 NV12 NV12 NV12 X NV12 NV12 24-47,72-95 1 GPU6 NV12 NV12 NV12 NV12 NV12 NV12 X NV12 24-47,72-95 1 GPU7 NV12 NV12 NV12 NV12 NV12 NV12 NV12 X 24-47,72-95 1 Legend: X = Self SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI) NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU) PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge) PIX = Connection traversing at most a single PCIe bridge NV# = Connection traversing a bonded set of # NVLinks ```

🐛 Describe the bug

I'm having issues using automatic prefix caching together with the prompt_logprobs option. The first call to the generate method goes through, but the second call fails with an AssertionError.

Reproduction code:

from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"
model = LLM(model_path, tensor_parallel_size=8, dtype="bfloat16", gpu_memory_utilization=0.8, enable_prefix_caching=True)

sampling_params = SamplingParams(prompt_logprobs=1, max_tokens=1)
tokenizer = AutoTokenizer.from_pretrained(model_path)

chat_prompts = tokenizer.apply_chat_template([[{"role": "user", "content": "Test 1"}]], tokenize=False)
output = model.generate(chat_prompts, sampling_params, use_tqdm=False)

print("OK")

chat_prompts = tokenizer.apply_chat_template([[{"role": "user", "content": "Test 2"}]], tokenize=False)
output = model.generate(chat_prompts, sampling_params, use_tqdm=False) # ERROR!

Full stack trace:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[2], line 10
      7 print("OK")
      9 chat_prompts = tokenizer.apply_chat_template([[{"role": "user", "content": "Test 2"}]], tokenize=False)
---> 10 output = model.generate(chat_prompts, sampling_params, use_tqdm=False) # ERROR!

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/vllm/utils.py:1032, in deprecate_kwargs.<locals>.wrapper.<locals>.inner(*args, **kwargs)
   1025             msg += f" {additional_message}"
   1027         warnings.warn(
   1028             DeprecationWarning(msg),
   1029             stacklevel=3,  # The inner function takes up one level
   1030         )
-> 1032 return fn(*args, **kwargs)

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/vllm/entrypoints/llm.py:347, in LLM.generate(self, prompts, sampling_params, prompt_token_ids, use_tqdm, lora_request, prompt_adapter_request, guided_options_request)
    338     sampling_params = SamplingParams()
    340 self._validate_and_add_requests(
    341     inputs=inputs,
    342     params=sampling_params,
    343     lora_request=lora_request,
    344     prompt_adapter_request=prompt_adapter_request,
    345     guided_options=guided_options_request)
--> 347 outputs = self._run_engine(use_tqdm=use_tqdm)
    348 return LLMEngine.validate_outputs(outputs, RequestOutput)

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/vllm/entrypoints/llm.py:704, in LLM._run_engine(self, use_tqdm)
    702 total_out_toks = 0
    703 while self.llm_engine.has_unfinished_requests():
--> 704     step_outputs = self.llm_engine.step()
    705     for output in step_outputs:
    706         if output.finished:

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/vllm/engine/llm_engine.py:1551, in LLMEngine.step(self)
   1547 if allow_async_output_proc:
   1548     execute_model_req.async_callback = self.async_callbacks[
   1549         virtual_engine]
-> 1551 output = self.model_executor.execute_model(
   1552     execute_model_req=execute_model_req)
   1554 # We need to do this here so that last step's sampled_token_ids can
   1555 # be passed to the next iteration for PP.
   1556 if self.scheduler_config.is_multi_step:

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/vllm/executor/distributed_gpu_executor.py:78, in DistributedGPUExecutor.execute_model(self, execute_model_req)
     72     self.parallel_worker_tasks = self._run_workers(
     73         "start_worker_execution_loop",
     74         async_run_tensor_parallel_workers_only=True,
     75         **self.extra_execute_model_run_workers_kwargs)
     77 # Only the driver worker returns the sampling results.
---> 78 driver_outputs = self._driver_execute_model(execute_model_req)
     79 assert driver_outputs is not None
     80 return driver_outputs

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/vllm/executor/multiproc_gpu_executor.py:162, in MultiprocessingGPUExecutor._driver_execute_model(self, execute_model_req)
    154 def _driver_execute_model(
    155     self, execute_model_req: Optional[ExecuteModelRequest]
    156 ) -> Optional[List[SamplerOutput]]:
    157     """Run execute_model in the driver worker.
    158 
    159     Passing None will cause the driver to stop the model execution
    160     loop running in each of the remote workers.
    161     """
--> 162     return self.driver_worker.execute_model(execute_model_req)

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/vllm/worker/worker_base.py:327, in LocalOrDistributedWorkerBase.execute_model(self, execute_model_req)
    322     if (self.observability_config is not None
    323             and self.observability_config.collect_model_execute_time):
    324         orig_model_execute_time = intermediate_tensors.tensors.get(
    325             "model_execute_time", torch.tensor(0)).item()
--> 327 output = self.model_runner.execute_model(
    328     model_input=model_input,
    329     kv_caches=self.kv_cache[worker_input.virtual_engine]
    330     if self.kv_cache is not None else None,
    331     intermediate_tensors=intermediate_tensors,
    332     num_steps=num_steps,
    333     **kwargs,
    334 )
    336 model_execute_time = time.perf_counter() - start_time
    337 if not get_pp_group().is_last_rank:
    338     # output is IntermediateTensors

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/vllm/worker/model_runner.py:1493, in ModelRunner.execute_model(self, model_input, kv_caches, intermediate_tensors, num_steps)
   1490     model_input.async_callback()
   1492 # Sample the next token.
-> 1493 output: SamplerOutput = self.model.sample(
   1494     logits=logits,
   1495     sampling_metadata=model_input.sampling_metadata,
   1496 )
   1497 if (self.observability_config is not None
   1498         and self.observability_config.collect_model_forward_time
   1499         and output is not None):
   1500     model_forward_end.synchronize()

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/vllm/model_executor/models/llama.py:447, in LlamaForCausalLM.sample(self, logits, sampling_metadata)
    442 def sample(
    443     self,
    444     logits: torch.Tensor,
    445     sampling_metadata: SamplingMetadata,
    446 ) -> Optional[SamplerOutput]:
--> 447     next_tokens = self.sampler(logits, sampling_metadata)
    448     return next_tokens

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/vllm/model_executor/layers/sampler.py:305, in Sampler.forward(self, logits, sampling_metadata)
    301 if not sampling_metadata.skip_sampler_cpu_output:
    302     # Pythonize logprobs now (GPU -> CPU); do not defer.
    303     assert not isinstance(maybe_deferred_sample_results,
    304                           SampleResultArgsType)
--> 305     prompt_logprobs, sample_logprobs = get_logprobs(
    306         logprobs, sampling_metadata, maybe_deferred_sample_results)
    308 return _build_sampler_output(
    309     maybe_deferred_sample_results,
    310     sampling_metadata,
   (...)
    313     on_device_tensors=on_device_tensors,
    314     skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)

File /opt/conda/envs/notebooks/lib/python3.9/site-packages/vllm/model_executor/layers/sampler.py:1079, in get_logprobs(logprobs, sampling_metadata, sample_results)
   1074             largest_num_logprobs = max(largest_num_logprobs,
   1075                                        sampling_params.logprobs)
   1077         use_beam_search = use_beam_search or sampling_params.use_beam_search
-> 1079     assert len(next_token_ids) == len(query_indices)
   1081 if len(query_indices) == 0:
   1082     empty_sampled_logprob: SampleLogprobs = []

AssertionError: 

hibukipanim commented 1 month ago

Probably a similar issue to https://github.com/vllm-project/vllm/issues/5344 (the same assert fails).

some more related issues come up when searching for next_token_ids: https://github.com/vllm-project/vllm/issues?q=is%3Aissue+is%3Aopen+next_token_ids

drubinstein commented 3 weeks ago

Not sure if it's any help, but I simplified the example a little bit. If the number of tokens in the prefix is > 16 and there's a full cache hit, then the assertion will trigger.

from vllm import LLM, SamplingParams, TokensPrompt

model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"

model = LLM(model_path, tensor_parallel_size=1, dtype="bfloat16", gpu_memory_utilization=0.8, enable_prefix_caching=True, enable_chunked_prefill=True,)
sampling_params = SamplingParams(prompt_logprobs=1,  max_tokens=1)

# works
# prompt = TokensPrompt(prompt_token_ids=list(range(16)))
# model.generate(prompt, sampling_params, use_tqdm=False)
# print("OK")
# model.generate(prompt, sampling_params, use_tqdm=False)
# print("OK")

# fails
prompt = TokensPrompt(prompt_token_ids=list(range(17)))
x = model.generate(prompt, sampling_params, use_tqdm=False)
print("OK")
y = model.generate(prompt, sampling_params, use_tqdm=False)
print("OK")
drubinstein commented 3 weeks ago

Another update: it looks like the crash is related to the block size. If the number of tokens in the cached prefix is greater than the block size, then the assertion will be hit. 16 is the default block size, which is why I saw it first. As shown in the example below, if I use a block size of 32, then I can increase the length of the TokensPrompt to 32.

Examples:

from vllm import LLM, SamplingParams, TokensPrompt

model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"

model = LLM(
    model_path,
    tensor_parallel_size=1,
    dtype="bfloat16",
    gpu_memory_utilization=0.8,
    enable_prefix_caching=True,
    enable_chunked_prefill=True,
    block_size=32
)
sampling_params = SamplingParams(prompt_logprobs=1, max_tokens=1)

# works
prompt = TokensPrompt(prompt_token_ids=list(range(31)))
x = model.generate(prompt, sampling_params, use_tqdm=False)
print(x[0].prompt_logprobs)
y = model.generate(prompt, sampling_params, use_tqdm=False)
print(y[0].prompt_logprobs)

# fails
prompt = TokensPrompt(prompt_token_ids=list(range(33)))
x = model.generate(prompt, sampling_params, use_tqdm=False)
print(x[0].prompt_logprobs)
y = model.generate(prompt, sampling_params, use_tqdm=False)
print(y[0].prompt_logprobs)

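In the meantime, a possible interim workaround (an untested sketch, assuming the assertion only fires when prefix caching, prompt_logprobs, and a full cache hit combine as observed above) is to leave prefix caching off for requests that need prompt_logprobs:

from vllm import LLM, SamplingParams, TokensPrompt

# Untested sketch: same repro as above, but with prefix caching disabled,
# trading cache reuse for not hitting the assertion.
model = LLM(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    tensor_parallel_size=1,
    dtype="bfloat16",
    gpu_memory_utilization=0.8,
    enable_prefix_caching=False,  # avoid the full-cache-hit path entirely
    enable_chunked_prefill=True,
)
sampling_params = SamplingParams(prompt_logprobs=1, max_tokens=1)

prompt = TokensPrompt(prompt_token_ids=list(range(33)))
x = model.generate(prompt, sampling_params, use_tqdm=False)
print(x[0].prompt_logprobs)
y = model.generate(prompt, sampling_params, use_tqdm=False)
print(y[0].prompt_logprobs)
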
drubinstein commented 2 weeks ago

Can you try out the new version of vLLM (0.6.3.post1)? I believe #9034 may have fixed this error by correctly populating Sequence.
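
A quick check along these lines (just a sketch, reusing the minimal repro from above after upgrading) should show whether the error is gone:

# Sketch of a regression check against a newer vLLM
# (e.g. pip install -U "vllm==0.6.3.post1" first).
import vllm
from vllm import LLM, SamplingParams, TokensPrompt

print(vllm.__version__)

model = LLM(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    enable_prefix_caching=True,
    enable_chunked_prefill=True,
)
sampling_params = SamplingParams(prompt_logprobs=1, max_tokens=1)

# Prompt longer than one block (default block_size=16), so the second
# call is a full prefix-cache hit -- the case that used to assert.
prompt = TokensPrompt(prompt_token_ids=list(range(17)))
model.generate(prompt, sampling_params, use_tqdm=False)
model.generate(prompt, sampling_params, use_tqdm=False)
print("no AssertionError")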

yejingfu commented 1 week ago

#9034 does not fix the issue; I patched that PR in but can still reproduce the problem.

drubinstein commented 1 week ago

Unfortunately, I saw the same. I think I got lucky when it worked out.

ccolas commented 1 week ago

Posted a fix in #3251 that solves some of the problems (maybe enough for you), but not all: https://github.com/vllm-project/vllm/issues/3251#issuecomment-2448963097

Hope it helps

hibukipanim commented 1 week ago

@ccolas this looks great. Can you please consider opening a PR with this fix? 🙏