vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: Speculative decoding server: `ValueError: could not broadcast input array from shape (513,) into shape (512,)` #5563

Open jeffreyling opened 2 weeks ago

jeffreyling commented 2 weeks ago

Your current environment

Collecting environment information...
PyTorch version: 2.3.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-112-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.5.119
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3

Nvidia driver version: 535.161.08
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      46 bits physical, 57 bits virtual
Byte Order:                         Little Endian
CPU(s):                             96
On-line CPU(s) list:                0-95
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Xeon(R) Platinum 8468
CPU family:                         6
Model:                              143
Thread(s) per core:                 1
Core(s) per socket:                 48
Socket(s):                          2
Stepping:                           8
BogoMIPS:                           4200.00
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
L1d cache:                          4.5 MiB (96 instances)
L1i cache:                          3 MiB (96 instances)
L2 cache:                           192 MiB (96 instances)
L3 cache:                           210 MiB (2 instances)
NUMA node(s):                       8
NUMA node0 CPU(s):                  0-11
NUMA node1 CPU(s):                  12-23
NUMA node2 CPU(s):                  24-35
NUMA node3 CPU(s):                  36-47
NUMA node4 CPU(s):                  48-59
NUMA node5 CPU(s):                  60-71
NUMA node6 CPU(s):                  72-83
NUMA node7 CPU(s):                  84-95
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] torch==2.3.0
[pip3] transformers==4.41.2
[pip3] triton==2.3.0
[conda] No relevant packages
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.5.0
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    NIC0    NIC1    NIC2    NIC3    NIC4    NIC5    NIC6    NIC7    NIC8    NIC9    CPU Affinity    NUMA Affinity     GPU NUMA ID
GPU0     X      NV18    NV18    NV18    NV18    NV18    NV18    NV18    PIX     PIX     SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS     0-11    0        N/A
GPU1    NV18     X      NV18    NV18    NV18    NV18    NV18    NV18    SYS     SYS     PIX     SYS     SYS     SYS     SYS     SYS     SYS     SYS     24-35   2        N/A
GPU2    NV18    NV18     X      NV18    NV18    NV18    NV18    NV18    SYS     SYS     SYS     PIX     SYS     SYS     SYS     SYS     SYS     SYS     36-47   3        N/A
GPU3    NV18    NV18    NV18     X      NV18    NV18    NV18    NV18    SYS     SYS     SYS     SYS     PIX     SYS     SYS     SYS     SYS     SYS     12-23   1        N/A
GPU4    NV18    NV18    NV18    NV18     X      NV18    NV18    NV18    SYS     SYS     SYS     SYS     SYS     PIX     PIX     SYS     SYS     SYS     48-59   4        N/A
GPU5    NV18    NV18    NV18    NV18    NV18     X      NV18    NV18    SYS     SYS     SYS     SYS     SYS     SYS     SYS     PIX     SYS     SYS     72-83   6        N/A
GPU6    NV18    NV18    NV18    NV18    NV18    NV18     X      NV18    SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS     PIX     SYS     84-95   7        N/A
GPU7    NV18    NV18    NV18    NV18    NV18    NV18    NV18     X      SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS     PIX     60-71   5        N/A
NIC0    PIX     SYS     SYS     SYS     SYS     SYS     SYS     SYS      X      PIX     SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS
NIC1    PIX     SYS     SYS     SYS     SYS     SYS     SYS     SYS     PIX      X      SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS
NIC2    SYS     PIX     SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS      X      SYS     SYS     SYS     SYS     SYS     SYS     SYS
NIC3    SYS     SYS     PIX     SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS      X      SYS     SYS     SYS     SYS     SYS     SYS
NIC4    SYS     SYS     SYS     PIX     SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS      X      SYS     SYS     SYS     SYS     SYS
NIC5    SYS     SYS     SYS     SYS     PIX     SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS      X      PIX     SYS     SYS     SYS
NIC6    SYS     SYS     SYS     SYS     PIX     SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS     PIX      X      SYS     SYS     SYS
NIC7    SYS     SYS     SYS     SYS     SYS     PIX     SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS      X      SYS     SYS
NIC8    SYS     SYS     SYS     SYS     SYS     SYS     PIX     SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS      X      SYS
NIC9    SYS     SYS     SYS     SYS     SYS     SYS     SYS     PIX     SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS     SYS      X

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_0
  NIC1: mlx5_1
  NIC2: mlx5_2
  NIC3: mlx5_3
  NIC4: mlx5_4
  NIC5: mlx5_5
  NIC6: mlx5_6
  NIC7: mlx5_7
  NIC8: mlx5_8
  NIC9: mlx5_9

🐛 Describe the bug

I am running into an issue with the vLLM server in speculative decoding mode. The server is launched with the following command on an 8xH100 machine for a Mixtral 8x22B model:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
python3 -u -m vllm.entrypoints.openai.api_server \
--host 0.0.0.0 \
--model $MODEL_NAME \
--tensor-parallel-size 8 \
--served-model-name "8x22_custom" \
--tokenizer $TOKENIZER_NAME \
--max-model-len 16384 \
--max-num-batched-tokens 16384 \
--speculative-model "[ngram]" \
--num-speculative-tokens 128 \
--ngram-prompt-lookup-max 32 \
--ngram-prompt-lookup-min 16 \
--speculative-max-model-len 16000 \
--use-v2-block-manager \
--enable-prefix-caching \
--disable-log-requests

After running several queries, the server runs into an error and does not recover. The error takes some time to appear, presumably because the bug only triggers once the KV cache is populated.

  File "/home/walden/mistral-finetune/.venv/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 503, in engine_step
    request_outputs = await self.engine.step_async()
  File "/home/walden/mistral-finetune/.venv/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 235, in step_async
    output = await self.model_executor.execute_model_async(
  File "/home/walden/mistral-finetune/.venv/lib/python3.10/site-packages/vllm/executor/distributed_gpu_executor.py", line 166, in execute_model_async
    return await self._driver_execute_model_async(execute_model_req)
  File "/home/walden/mistral-finetune/.venv/lib/python3.10/site-packages/vllm/executor/multiproc_gpu_executor.py", line 149, in _driver_execute_model_async
    return await self.driver_exec_model(execute_model_req)
  File "/usr/lib/python3.10/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/home/walden/mistral-finetune/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/walden/mistral-finetune/.venv/lib/python3.10/site-packages/vllm/spec_decode/spec_decode_worker.py", line 291, in execute_model
    return self._run_speculative_decoding_step(execute_model_req,
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/walden/mistral-finetune/.venv/lib/python3.10/site-packages/vllm/spec_decode/spec_decode_worker.py", line 389, in _run_speculative_decoding_step
    proposal_scores = self.scorer.score_proposals(
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/walden/mistral-finetune/.venv/lib/python3.10/site-packages/vllm/spec_decode/batch_expansion.py", line 81, in score_proposals
    target_sampler_output = self._scorer_worker.execute_model(
  File "/home/walden/mistral-finetune/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/walden/mistral-finetune/.venv/lib/python3.10/site-packages/vllm/worker/worker.py", line 272, in execute_model
    output = self.model_runner.execute_model(seq_group_metadata_list,
  File "/home/walden/mistral-finetune/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/walden/mistral-finetune/.venv/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 724, in execute_model
    ) = self.prepare_input_tensors(seq_group_metadata_list)
  File "/home/walden/mistral-finetune/.venv/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 670, in prepare_input_tensors
    ) = self._prepare_model_input(seq_group_metadata_list)
  File "/home/walden/mistral-finetune/.venv/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 516, in _prepare_model_input
    input_block_tables[i, :len(block_table)] = block_table
ValueError: could not broadcast input array from shape (513,) into shape (512,)

It seems to be an off-by-one error coming from the speculative decoding code.
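The failing line copies a variable-length block table into a row of a preallocated, fixed-width NumPy array. A minimal sketch of that mechanism, assuming only NumPy; the function name and sizes are hypothetical, chosen to match the shapes in the error. The slice `:len(block_table)` silently clamps to the row width, so a block table just one block longer than the row produces exactly this ValueError.

```python
from typing import Optional

import numpy as np


def fill_block_table(table_width: int, num_blocks: int) -> Optional[str]:
    """Copy a block table into a fixed-width row, mirroring the failing line.

    Hypothetical helper: returns the ValueError message if the copy fails,
    otherwise None.
    """
    # Preallocated table with one row; its width plays the role of
    # "max context len // block size".
    graph_block_tables = np.zeros((1, table_width), dtype=np.int32)
    input_block_tables = graph_block_tables[:1]
    block_table = list(range(num_blocks))  # hypothetical physical block ids
    try:
        # Slicing past the end clamps to table_width columns, so a longer
        # list cannot be broadcast into the row.
        input_block_tables[0, :len(block_table)] = block_table
    except ValueError as err:
        return str(err)
    return None


print(fill_block_table(512, 512))  # None
print(fill_block_table(512, 513))
# could not broadcast input array from shape (513,) into shape (512,)
```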

Let me know if more information is needed.

cadedaniel commented 2 weeks ago

Thanks for creating the issue. Two questions:

  1. Does the problem still occur if prefix caching is disabled?
  2. Does the problem still occur if CUDA graphs are disabled?

jeffreyling commented 2 weeks ago

@cadedaniel Thanks for the quick response!

  1. I tried without --enable-prefix-caching and it eventually ran into the same error.
  2. Then I tried without --enable-prefix-caching and with --enforce-eager. This didn't error on the set of queries I ran.

cadedaniel commented 2 weeks ago

Thanks for trying those out so fast :)

OK, the issue is very likely caused by CUDA graphs + batch expansion. This should be fixed, but since spec decode performance isn't good yet, it won't be prioritized until the performance work is done.

FYI @LiuXiaoxuanPKU, another issue with batch expansion + CUDA graphs

Adhyyan1252 commented 2 weeks ago

Do you recommend just using --enforce-eager until this is fixed?

cadedaniel commented 2 weeks ago

yep.

cadedaniel commented 2 weeks ago

If you are blocked by this issue, the fix shouldn't be very hard. I think we simply need to configure the CUDA graph max batch size to include the expanded batch size.
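A rough sketch of why batch expansion strains CUDA graph capture, under two assumptions not confirmed in this thread: batch expansion scores k proposal tokens by turning each sequence into k + 1 target-model sequences, and the largest batch size captured as a CUDA graph is 256. The helper names are hypothetical.

```python
# Assumption: largest batch size captured as a CUDA graph.
MAX_CAPTURED_BATCH_SIZE = 256


def expanded_batch_size(num_seqs: int, num_spec_tokens: int) -> int:
    """Target-model batch size after batch expansion.

    Assumes each speculating sequence expands into num_spec_tokens + 1 rows.
    """
    return num_seqs * (num_spec_tokens + 1)


def max_seqs_for_graph(num_spec_tokens: int,
                       max_captured: int = MAX_CAPTURED_BATCH_SIZE) -> int:
    """Largest number of speculating sequences whose expanded batch still fits."""
    return max_captured // (num_spec_tokens + 1)


print(expanded_batch_size(2, 128))  # 258, already above 256
print(max_seqs_for_graph(128))  # 1
print(max_seqs_for_graph(5))  # 42
```

Under these assumptions, --num-speculative-tokens 128 leaves room for only one speculating sequence in a captured graph, which is the gap that raising the captured batch size would close.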

Adhyyan1252 commented 2 weeks ago

The code that is breaking is:

if use_captured_graph:
    # The shape of graph_block_tables is
    # [max batch size, max context len // block size].
    input_block_tables = self.graph_block_tables[:batch_size]
    for i, block_table in enumerate(block_tables):
        if block_table:
            input_block_tables[i, :len(block_table)] = block_table

The issue is that len(block_table) > input_block_tables.shape[1], and the second dimension corresponds to max context len // block size. Am I misunderstanding how this is a batch-size issue rather than a context-length issue?
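The context-length reading can be checked with block arithmetic. A minimal sketch, assuming two values the thread does not state: a KV-cache block size of 16 tokens and a captured-graph "max context len" of 8192 tokens (believed to be vLLM 0.5.0's defaults for --block-size and --max-seq-len-to-capture, but treat both as assumptions). Under them, the table from the code comment has exactly 512 columns, and a 513-block table needs only one token slot more than fits.

```python
import math

# Assumed values, not stated in the thread.
BLOCK_SIZE = 16
MAX_SEQ_LEN_TO_CAPTURE = 8192


def graph_table_width(max_len: int, block_size: int) -> int:
    """Columns in the preallocated table: ceil(max_len / block_size)."""
    return math.ceil(max_len / block_size)


def blocks_needed(num_slots: int, block_size: int) -> int:
    """Blocks required to cover num_slots KV-cache token slots."""
    return math.ceil(num_slots / block_size)


print(graph_table_width(MAX_SEQ_LEN_TO_CAPTURE, BLOCK_SIZE))  # 512
print(blocks_needed(8192, BLOCK_SIZE))  # 512, still fits
print(blocks_needed(8193, BLOCK_SIZE))  # 513, one column too many
```

So a 513-block table means the sequence's blocks cover between 8193 and 8208 token slots, which points at the table width rather than the batch dimension.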

cadedaniel commented 2 weeks ago

Good point. I wonder why this is specific to spec decode, then.

Does the sequence length plus proposal length go over the max model length?

Adhyyan1252 commented 2 weeks ago

Does the sequence length plus proposal length go over the max model length?

That was our suspicion as well, so we set speculative-max-model-len below max-model-len - num-speculative-tokens, but that doesn't seem to prevent the issue.

--max-model-len 16384 \
--speculative-max-model-len 16000 \
--speculative-model "[ngram]" \
--num-speculative-tokens 128 \
--ngram-prompt-lookup-max 32 \
--ngram-prompt-lookup-min 16 \
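The margin described above can be verified with the flag values from the launch command. A minimal sketch; the helper names are hypothetical, and it assumes (from our reading of the flag's description) that sequences longer than --speculative-max-model-len skip speculation, so the longest speculating sequence has 16000 tokens.

```python
# Flag values copied from the launch command in this issue.
MAX_MODEL_LEN = 16384
SPECULATIVE_MAX_MODEL_LEN = 16000
NUM_SPECULATIVE_TOKENS = 128


def worst_case_len(seq_len: int, num_spec_tokens: int) -> int:
    """Token positions a sequence spans once a full proposal is appended."""
    return seq_len + num_spec_tokens


def fits_max_model_len(seq_len: int, num_spec_tokens: int,
                       max_model_len: int) -> bool:
    """True if seq_len plus a full proposal stays within max_model_len."""
    return worst_case_len(seq_len, num_spec_tokens) <= max_model_len


print(worst_case_len(SPECULATIVE_MAX_MODEL_LEN, NUM_SPECULATIVE_TOKENS))  # 16128
print(MAX_MODEL_LEN - NUM_SPECULATIVE_TOKENS)  # 16256
print(fits_max_model_len(SPECULATIVE_MAX_MODEL_LEN, NUM_SPECULATIVE_TOKENS,
                         MAX_MODEL_LEN))  # True
```

The margin holds against max-model-len, but this check says nothing about the width of the captured-graph block table, whose "max context len" in the model_runner.py comment is a separate limit.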
njhill commented 2 weeks ago

@Adhyyan1252, could you check whether you also get this error with vLLM 0.4.3?

jeffreyling commented 2 weeks ago

@njhill, we originally ran into this error with 0.4.3, before we tried upgrading to 0.5.0.

Ximingwang-09 commented 5 days ago

Try adding --max-seq-len-to-capture set equal to max_model_len.

Adhyyan1252 commented 4 days ago

Still the same issue