vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: Empty prompt kills vllm server (AsyncEngineDeadError: Background loop is stopped.) #7283

Open shimizust opened 1 month ago

shimizust commented 1 month ago

Your current environment

python3 collect_env.py 
Collecting environment information...
PyTorch version: 2.3.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.30.1
Libc version: glibc-2.31

Python version: 3.10.14 (main, Apr  6 2024, 18:45:05) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.138.1-4.cm2-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100-SXM4-80GB
Nvidia driver version: 525.85.12
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      48 bits physical, 48 bits virtual
CPU(s):                             256
On-line CPU(s) list:                0-255
Thread(s) per core:                 2
Core(s) per socket:                 64
Socket(s):                          2
NUMA node(s):                       8
Vendor ID:                          AuthenticAMD
CPU family:                         25
Model:                              1
Model name:                         AMD EPYC 7763 64-Core Processor
Stepping:                           1
Frequency boost:                    enabled
CPU MHz:                            3096.040
CPU max MHz:                        3529.0520
CPU min MHz:                        1500.0000
BogoMIPS:                           4900.05
Virtualization:                     AMD-V
L1d cache:                          4 MiB
L1i cache:                          4 MiB
L2 cache:                           64 MiB
L3 cache:                           512 MiB
NUMA node0 CPU(s):                  0-15,128-143
NUMA node1 CPU(s):                  16-31,144-159
NUMA node2 CPU(s):                  32-47,160-175
NUMA node3 CPU(s):                  48-63,176-191
NUMA node4 CPU(s):                  64-79,192-207
NUMA node5 CPU(s):                  80-95,208-223
NUMA node6 CPU(s):                  96-111,224-239
NUMA node7 CPU(s):                  112-127,240-255
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET, no microcode
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca

Versions of relevant libraries:
[pip3] flashinfer==0.0.9+cu121torch2.3
[pip3] numpy==1.26.4
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] pyzmq==26.0.3
[pip3] torch==2.3.1
[pip3] torchvision==0.18.1
[pip3] transformers==4.43.2
[pip3] triton==2.3.1
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.5.3.post1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    CPU Affinity    NUMA Affinity
GPU0     X      48-63,176-191   3

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

🐛 Describe the bug

Spin up the vLLM server in a pod using the vLLM base image (vllm/vllm-openai:v0.5.3.post1):

python3 -m vllm.entrypoints.openai.api_server --model $MODEL_PATH --port 8000 --trust-remote-code

where $MODEL_PATH points to some model. I've tried gpt2-medium and Meta-Llama-3-8B.

Generation works fine, but passing an empty prompt immediately kills the server, and the engine does not recover:

curl -X 'POST' \
  'http://localhost:8080/v1/completions' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "model": "/models/gpt2-medium",
  "prompt": "",
  "max_tokens": 50,
  "temperature": 0.9,
  "top_p": 1
}'
INFO 08-07 20:13:55 logger.py:36] Received request cmpl-662a1639b80b47539cc099bdf6347cc0-0: prompt: '', params: SamplingParams(n=1, best_of=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.9, top_p=0.9, top_k=20, min_p=0.0, seed=None, use_beam_search=False, length_penalty=1.0, early_stopping=False, stop=[], stop_token_ids=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=100, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None), prompt_token_ids: [], lora_request: None, prompt_adapter_request: None.
INFO 08-07 20:13:55 async_llm_engine.py:173] Added request cmpl-662a1639b80b47539cc099bdf6347cc0-0.
ERROR 08-07 20:13:55 async_llm_engine.py:56] Engine background task failed
ERROR 08-07 20:13:55 async_llm_engine.py:56] Traceback (most recent call last):
ERROR 08-07 20:13:55 async_llm_engine.py:56]   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 46, in _log_task_completion
ERROR 08-07 20:13:55 async_llm_engine.py:56]     return_value = task.result()
ERROR 08-07 20:13:55 async_llm_engine.py:56]   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 637, in run_engine_loop
ERROR 08-07 20:13:55 async_llm_engine.py:56]     result = task.result()
ERROR 08-07 20:13:55 async_llm_engine.py:56]   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 580, in engine_step
ERROR 08-07 20:13:55 async_llm_engine.py:56]     request_outputs = await self.engine.step_async(virtual_engine)
ERROR 08-07 20:13:55 async_llm_engine.py:56]   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 238, in step_async
ERROR 08-07 20:13:55 async_llm_engine.py:56]     virtual_engine].schedule()
ERROR 08-07 20:13:55 async_llm_engine.py:56]   File "/usr/local/lib/python3.10/dist-packages/vllm/core/scheduler.py", line 978, in schedule
ERROR 08-07 20:13:55 async_llm_engine.py:56]     scheduler_outputs = self._schedule()
ERROR 08-07 20:13:55 async_llm_engine.py:56]   File "/usr/local/lib/python3.10/dist-packages/vllm/core/scheduler.py", line 953, in _schedule
ERROR 08-07 20:13:55 async_llm_engine.py:56]     return self._schedule_default()
ERROR 08-07 20:13:55 async_llm_engine.py:56]   File "/usr/local/lib/python3.10/dist-packages/vllm/core/scheduler.py", line 795, in _schedule_default
ERROR 08-07 20:13:55 async_llm_engine.py:56]     remaining_waiting, prefills = self._schedule_prefills(
ERROR 08-07 20:13:55 async_llm_engine.py:56]   File "/usr/local/lib/python3.10/dist-packages/vllm/core/scheduler.py", line 690, in _schedule_prefills
ERROR 08-07 20:13:55 async_llm_engine.py:56]     num_new_tokens = self._get_num_new_tokens(seq_group,
ERROR 08-07 20:13:55 async_llm_engine.py:56]   File "/usr/local/lib/python3.10/dist-packages/vllm/core/scheduler.py", line 1232, in _get_num_new_tokens
ERROR 08-07 20:13:55 async_llm_engine.py:56]     assert num_new_tokens > 0
ERROR 08-07 20:13:55 async_llm_engine.py:56] AssertionError
Exception in callback functools.partial(<function _log_task_completion at 0x7b3f96884790>, error_callback=<bound method AsyncLLMEngine._error_callback of <vllm.engine.async_llm_engine.AsyncLLMEngine object at 0x7b3f7e2f7be0>>)
INFO 08-07 20:13:55 async_llm_engine.py:180] Aborted request cmpl-662a1639b80b47539cc099bdf6347cc0-0.
handle: <Handle functools.partial(<function _log_task_completion at 0x7b3f96884790>, error_callback=<bound method AsyncLLMEngine._error_callback of <vllm.engine.async_llm_engine.AsyncLLMEngine object at 0x7b3f7e2f7be0>>)>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 46, in _log_task_completion
    return_value = task.result()
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 637, in run_engine_loop
    result = task.result()
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 580, in engine_step
    request_outputs = await self.engine.step_async(virtual_engine)
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 238, in step_async
    virtual_engine].schedule()
  File "/usr/local/lib/python3.10/dist-packages/vllm/core/scheduler.py", line 978, in schedule
    scheduler_outputs = self._schedule()
  File "/usr/local/lib/python3.10/dist-packages/vllm/core/scheduler.py", line 953, in _schedule
    return self._schedule_default()
  File "/usr/local/lib/python3.10/dist-packages/vllm/core/scheduler.py", line 795, in _schedule_default
    remaining_waiting, prefills = self._schedule_prefills(
  File "/usr/local/lib/python3.10/dist-packages/vllm/core/scheduler.py", line 690, in _schedule_prefills
    num_new_tokens = self._get_num_new_tokens(seq_group,
INFO:     100.99.173.128:58497 - "POST /v1/completions HTTP/1.1" 500 Internal Server Error
  File "/usr/local/lib/python3.10/dist-packages/vllm/core/scheduler.py", line 1232, in _get_num_new_tokens
    assert num_new_tokens > 0
AssertionError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "uvloop/cbhandles.pyx", line 63, in uvloop.loop.Handle._run
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 58, in _log_task_completion
    raise AsyncEngineDeadError(
vllm.engine.async_llm_engine.AsyncEngineDeadError: Task finished unexpectedly. This should never happen! Please open an issue on Github. See stack trace above for the actual cause.
ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/uvicorn/protocols/http/httptools_impl.py", line 399, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
  File "/usr/local/lib/python3.10/dist-packages/uvicorn/middleware/proxy_headers.py", line 70, in __call__
    return await self.app(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/fastapi/applications.py", line 1054, in __call__
    await super().__call__(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/applications.py", line 123, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/middleware/errors.py", line 186, in __call__
    raise exc
  File "/usr/local/lib/python3.10/dist-packages/starlette/middleware/errors.py", line 164, in __call__
    await self.app(scope, receive, _send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/middleware/cors.py", line 85, in __call__
    await self.app(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/middleware/exceptions.py", line 65, in __call__
    await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/_exception_handler.py", line 64, in wrapped_app
    raise exc
  File "/usr/local/lib/python3.10/dist-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    await app(scope, receive, sender)
  File "/usr/local/lib/python3.10/dist-packages/starlette/routing.py", line 756, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/routing.py", line 776, in app
    await route.handle(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/routing.py", line 297, in handle
    await self.app(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/routing.py", line 77, in app
    await wrap_app_handling_exceptions(app, request)(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/_exception_handler.py", line 64, in wrapped_app
    raise exc
  File "/usr/local/lib/python3.10/dist-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    await app(scope, receive, sender)
  File "/usr/local/lib/python3.10/dist-packages/starlette/routing.py", line 72, in app
    response = await func(request)
  File "/usr/local/lib/python3.10/dist-packages/fastapi/routing.py", line 278, in app
    raw_response = await run_endpoint_function(
  File "/usr/local/lib/python3.10/dist-packages/fastapi/routing.py", line 191, in run_endpoint_function
    return await dependant.call(**values)
  File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/openai/api_server.py", line 144, in create_completion
    generator = await openai_serving_completion.create_completion(
  File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/openai/serving_completion.py", line 175, in create_completion
    async for i, res in result_generator:
  File "/usr/local/lib/python3.10/dist-packages/vllm/utils.py", line 329, in consumer
    raise e
  File "/usr/local/lib/python3.10/dist-packages/vllm/utils.py", line 320, in consumer
    raise item
  File "/usr/local/lib/python3.10/dist-packages/vllm/utils.py", line 304, in producer
    async for item in iterator:
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 772, in generate
    async for output in self._process_request(
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 888, in _process_request
    raise e
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 884, in _process_request
    async for request_output in stream:
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 93, in __anext__
    raise result
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 46, in _log_task_completion
    return_value = task.result()
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 637, in run_engine_loop
    result = task.result()
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 580, in engine_step
    request_outputs = await self.engine.step_async(virtual_engine)
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 238, in step_async
    virtual_engine].schedule()
  File "/usr/local/lib/python3.10/dist-packages/vllm/core/scheduler.py", line 978, in schedule
    scheduler_outputs = self._schedule()
  File "/usr/local/lib/python3.10/dist-packages/vllm/core/scheduler.py", line 953, in _schedule
    return self._schedule_default()
  File "/usr/local/lib/python3.10/dist-packages/vllm/core/scheduler.py", line 795, in _schedule_default
    remaining_waiting, prefills = self._schedule_prefills(
  File "/usr/local/lib/python3.10/dist-packages/vllm/core/scheduler.py", line 690, in _schedule_prefills
    num_new_tokens = self._get_num_new_tokens(seq_group,
  File "/usr/local/lib/python3.10/dist-packages/vllm/core/scheduler.py", line 1232, in _get_num_new_tokens
    assert num_new_tokens > 0
AssertionError
INFO 08-07 20:13:57 metrics.py:396] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 1 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO:     100.96.103.1:43785 - "GET /health HTTP/1.1" 500 Internal Server Error
ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/uvicorn/protocols/http/httptools_impl.py", line 399, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
  File "/usr/local/lib/python3.10/dist-packages/uvicorn/middleware/proxy_headers.py", line 70, in __call__
    return await self.app(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/fastapi/applications.py", line 1054, in __call__
    await super().__call__(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/applications.py", line 123, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/middleware/errors.py", line 186, in __call__
    raise exc
  File "/usr/local/lib/python3.10/dist-packages/starlette/middleware/errors.py", line 164, in __call__
    await self.app(scope, receive, _send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/middleware/cors.py", line 85, in __call__
    await self.app(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/middleware/exceptions.py", line 65, in __call__
    await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/_exception_handler.py", line 64, in wrapped_app
    raise exc
  File "/usr/local/lib/python3.10/dist-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    await app(scope, receive, sender)
  File "/usr/local/lib/python3.10/dist-packages/starlette/routing.py", line 756, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/routing.py", line 776, in app
    await route.handle(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/routing.py", line 297, in handle
    await self.app(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/routing.py", line 77, in app
    await wrap_app_handling_exceptions(app, request)(scope, receive, send)
  File "/usr/local/lib/python3.10/dist-packages/starlette/_exception_handler.py", line 64, in wrapped_app
    raise exc
  File "/usr/local/lib/python3.10/dist-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    await app(scope, receive, sender)
  File "/usr/local/lib/python3.10/dist-packages/starlette/routing.py", line 72, in app
    response = await func(request)
  File "/usr/local/lib/python3.10/dist-packages/fastapi/routing.py", line 278, in app
    raw_response = await run_endpoint_function(
  File "/usr/local/lib/python3.10/dist-packages/fastapi/routing.py", line 191, in run_endpoint_function
    return await dependant.call(**values)
  File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/openai/api_server.py", line 88, in health
    await openai_serving_chat.engine.check_health()
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 950, in check_health
    raise AsyncEngineDeadError("Background loop is stopped.")
vllm.engine.async_llm_engine.AsyncEngineDeadError: Background loop is stopped.

Expected Behavior

If an empty prompt is not allowed, I would expect a 400 (invalid input) response rather than a 500 that takes down the server.
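
For illustration, a minimal sketch of the kind of guard I'd expect at the API layer, rejecting the request before it ever reaches the engine. The handler and request model below are hypothetical, not vLLM's actual serving code:

# Hypothetical pre-validation sketch (not vLLM's implementation): reject an
# empty prompt with HTTP 400 instead of letting the scheduler assertion
# kill the engine's background loop.
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()

class CompletionRequest(BaseModel):
    model: str
    prompt: str = ""
    max_tokens: int = 16

@app.post("/v1/completions")
async def create_completion(req: CompletionRequest) -> dict:
    # An empty prompt should be a client error, not an engine crash.
    if not req.prompt:
        raise HTTPException(status_code=400,
                            detail="prompt must contain at least one token")
    # ... hand the request off to the engine as usual ...
    return {"choices": [{"text": "<generated text>"}]}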

mgoin commented 1 month ago

This also happens for the offline LLM entrypoint:

>>> from vllm import LLM
>>> model = LLM("gpt2")
>>> model.generate("")
Processed prompts:   0%|                                                                                                                          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s][rank0]: Traceback (most recent call last):
[rank0]:   File "<stdin>", line 1, in <module>
[rank0]:   File "/home/mgoin/code/vllm/vllm/utils.py", line 996, in inner
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home/mgoin/code/vllm/vllm/entrypoints/llm.py", line 339, in generate
[rank0]:     outputs = self._run_engine(use_tqdm=use_tqdm)
[rank0]:   File "/home/mgoin/code/vllm/vllm/entrypoints/llm.py", line 620, in _run_engine
[rank0]:     step_outputs = self.llm_engine.step()
[rank0]:   File "/home/mgoin/code/vllm/vllm/engine/llm_engine.py", line 1287, in step
[rank0]:     0].schedule()
[rank0]:   File "/home/mgoin/code/vllm/vllm/core/scheduler.py", line 963, in schedule
[rank0]:     scheduler_outputs = self._schedule()
[rank0]:   File "/home/mgoin/code/vllm/vllm/core/scheduler.py", line 938, in _schedule
[rank0]:     return self._schedule_default()
[rank0]:   File "/home/mgoin/code/vllm/vllm/core/scheduler.py", line 798, in _schedule_default
[rank0]:     prefills = self._schedule_prefills(budget,
[rank0]:   File "/home/mgoin/code/vllm/vllm/core/scheduler.py", line 696, in _schedule_prefills
[rank0]:     num_new_tokens = self._get_num_new_tokens(seq_group,
[rank0]:   File "/home/mgoin/code/vllm/vllm/core/scheduler.py", line 1234, in _get_num_new_tokens
[rank0]:     assert num_new_tokens > 0
[rank0]: AssertionError
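
Until this is handled inside the engine, a caller-side guard works around it. A rough sketch, assuming LLM.get_tokenizer() is available (it is in recent vLLM versions): drop any prompt that tokenizes to zero tokens before calling generate():

# Workaround sketch: filter out prompts that tokenize to zero tokens so the
# scheduler never sees an empty sequence (its assert is on num_new_tokens > 0).
from vllm import LLM

llm = LLM("gpt2")
tokenizer = llm.get_tokenizer()

prompts = ["Hello, my name is", "", "The capital of France is"]
valid_prompts = [p for p in prompts if len(tokenizer(p).input_ids) > 0]

outputs = llm.generate(valid_prompts)
for out in outputs:
    print(repr(out.prompt), "->", out.outputs[0].text)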
youkaichao commented 1 month ago

Just curious: I thought an LLM prompt always starts with a start-of-sentence token? What does an empty prompt mean in that case?

mgoin commented 1 month ago

@youkaichao this depends on the tokenizer. I just tested Llama 3.1 8B Instruct and it doesn't hit this issue because its tokenizer adds a BOS token:

>>> from vllm import LLM
>>> model = LLM("meta-llama/Meta-Llama-3.1-8B-Instruct")
>>> model.generate("")
Processed prompts: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  7.37it/s, est. speed input: 7.38 toks/s, output: 118.09 toks/s]
[RequestOutput(request_id=0, prompt='', prompt_token_ids=[128000], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text='Cloud bathroom mirror is designed to bring a fresh perspective to bathroom essentials. Its fog', token_ids=(16440, 15197, 18327, 374, 6319, 311, 4546, 264, 7878, 13356, 311, 15197, 59886, 13, 11699, 31349), cumulative_logprob=None, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=RequestMetrics(arrival_time=1723074368.5084124, last_token_time=1723074368.5084124, first_scheduled_time=1723074368.5245223, first_token_time=1723074368.5398314, time_in_queue=0.016109943389892578, finished_time=1723074368.6597202), lora_request=None)]

You can see that the prompt is empty but the prompt token IDs are not: prompt='', prompt_token_ids=[128000].

Either way, I think we should return an empty response, or otherwise follow what OpenAI does for an empty prompt. Crashing the LLM or the server is not good behavior.

youkaichao commented 1 month ago

Crashing LLM or the server is not good behavior

Agreed. We should never let a user request crash the engine.

shimizust commented 1 month ago

Good catch that this depends on the tokenizer. The models I tested do not have a BOS token defined in tokenizer_config.json.
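
For anyone checking their own model, a quick sketch using the Hugging Face tokenizer directly shows whether an empty string yields any prompt tokens (the Llama checkpoint is gated, so it needs HF access):

# Sketch: a tokenizer that prepends a BOS token produces non-empty input_ids
# for "", so the scheduler assertion is never hit; gpt2-style tokenizers
# return an empty list and trigger the crash described above.
from transformers import AutoTokenizer

for name in ["gpt2", "meta-llama/Meta-Llama-3.1-8B-Instruct"]:
    tok = AutoTokenizer.from_pretrained(name)
    print(name, "->", tok("").input_ids)  # expect [] vs. [128000]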

pseudotensor commented 3 weeks ago

Same as this: https://github.com/vllm-project/vllm/issues/7632