
[Bug]: VLLM does not support EAGLE Spec Decode when deploying EAGLE-Qwen2-7B-Instruct model #8849


crownz248 commented 2 days ago

Your current environment

The output of `python collect_env.py`

```text
PyTorch version: 2.4.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.30.2
Libc version: glibc-2.35

Python version: 3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-113-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 4090
GPU 1: NVIDIA GeForce RTX 4090
GPU 2: NVIDIA GeForce RTX 4090
GPU 3: NVIDIA GeForce RTX 4090
GPU 4: NVIDIA GeForce RTX 4090
GPU 5: NVIDIA GeForce RTX 4090
GPU 6: NVIDIA GeForce RTX 4090
GPU 7: NVIDIA GeForce RTX 4090

Nvidia driver version: 550.90.07
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 124
On-line CPU(s) list: 0-123
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7542 32-Core Processor
CPU family: 23
Model: 49
Thread(s) per core: 2
Core(s) per socket: 31
Socket(s): 2
Stepping: 0
BogoMIPS: 5799.99
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid amd_dcm tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext perfctr_core ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr wbnoinvd virt_ssbd arat umip rdpid arch_capabilities
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 3.9 MiB (62 instances)
L1i cache: 3.9 MiB (62 instances)
L2 cache: 31 MiB (62 instances)
L3 cache: 256 MiB (16 instances)
NUMA node(s): 8
NUMA node0 CPU(s): 0-7,64-71
NUMA node1 CPU(s): 8-15,72-79
NUMA node2 CPU(s): 16-23,80-87
NUMA node3 CPU(s): 24-31,88-95
NUMA node4 CPU(s): 32-39,96-103
NUMA node5 CPU(s): 40-47,104-111
NUMA node6 CPU(s): 48-55,112-119
NUMA node7 CPU(s): 56-63,120-123
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] flashinfer==0.1.5+cu124torch2.4
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-ml-py==12.535.161
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] nvidia-nvjitlink-cu12==12.6.20
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] pyzmq==26.1.1
[pip3] torch==2.4.0
[pip3] torchvision==0.19.0
[pip3] transformers==4.43.4
[pip3] triton==3.0.0
[conda] flashinfer 0.1.5+cu124torch2.4 pypi_0 pypi
[conda] numpy 1.26.4 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.1.3.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.1.105 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.1.0.70 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.0.2.54 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.2.106 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.4.5.107 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.1.0.106 pypi_0 pypi
[conda] nvidia-ml-py 12.535.161 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.20.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.6.20 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.1.105 pypi_0 pypi
[conda] pyzmq 26.1.1 pypi_0 pypi
[conda] torch 2.4.0 pypi_0 pypi
[conda] torchvision 0.19.0 pypi_0 pypi
[conda] transformers 4.43.4 pypi_0 pypi
[conda] triton 3.0.0 pypi_0 pypi

ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.6.0@32e7db25365415841ebc7c4215851743fbb1bad1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
      GPU0  GPU1  GPU2  GPU3  GPU4  GPU5  GPU6  GPU7  CPU Affinity  NUMA Affinity  GPU NUMA ID
GPU0  X     PHB   PHB   PHB   PHB   PHB   PHB   PHB   0-123         0-7            N/A
GPU1  PHB   X     PHB   PHB   PHB   PHB   PHB   PHB   0-123         0-7            N/A
GPU2  PHB   PHB   X     PHB   PHB   PHB   PHB   PHB   0-123         0-7            N/A
GPU3  PHB   PHB   PHB   X     PHB   PHB   PHB   PHB   0-123         0-7            N/A
GPU4  PHB   PHB   PHB   PHB   X     PHB   PHB   PHB   0-123         0-7            N/A
GPU5  PHB   PHB   PHB   PHB   PHB   X     PHB   PHB   0-123         0-7            N/A
GPU6  PHB   PHB   PHB   PHB   PHB   PHB   X     PHB   0-123         0-7            N/A
GPU7  PHB   PHB   PHB   PHB   PHB   PHB   PHB   X     0-123         0-7            N/A

Legend:
  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks
```

Model Input Dumps

No response

🐛 Describe the bug

I can successfully deploy llama3-8b-instruct with EAGLE speculative decoding, but I run into a problem when deploying qwen2-7b-instruct with EAGLE.

I have converted the EAGLE-Qwen2-7B-Instruct model according to [vllm/model_executor/models/eagle.py:L126](https://github.com/vllm-project/vllm/blob/8fae5ed7f6bfd63b81310fcb24b310d9205c9687/vllm/model_executor/models/eagle.py#L126).

I tried the Python code below:

```python
from vllm import LLM

llm = LLM(
    model="/models/Qwen2-7B-Instruct",                         # target model
    dtype='bfloat16',
    tensor_parallel_size=4,
    speculative_model="/models/EAGLE-Qwen2-7B-Instruct-vllm",  # converted EAGLE draft model
    speculative_draft_tensor_parallel_size=1,
    num_speculative_tokens=1,
    use_v2_block_manager=True,
)
```
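
For reference, this is the kind of call I would use to exercise the engine once it constructs (the prompt is just a placeholder); in this case the failure happens earlier, while the draft model's weights are being loaded:

```python
from vllm import SamplingParams

# Hypothetical smoke test; never reached, since LLM(...) above fails
# while loading the draft model's weights.
outputs = llm.generate(
    ["Give me a short introduction to speculative decoding."],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```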

Instead, I encountered the error below:

`AssertionError: Attempted to load weight (torch.Size([3584])) into parameter (torch.Size([3584, 7168]))`

I looked at the code at [vllm/model_executor/models/eagle.py:L139](https://github.com/vllm-project/vllm/blob/8fae5ed7f6bfd63b81310fcb24b310d9205c9687/vllm/model_executor/models/eagle.py#L139), shown below:

```python
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        ...
        elif name.startswith("fc."):
            weight_loader = getattr(self.fc.weight, "weight_loader",
                                    default_weight_loader)
            weight_loader(self.fc.weight, loaded_weight)
        ...
```

I think the code assumes that any weight name starting with `fc.` must be `fc.weight`, but the fc layer of EAGLE-Qwen2 also has a bias, so the name can be `fc.bias` as well. That bias tensor then gets routed to `self.fc.weight`, which triggers the shape assertion above.
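
To make the shape mismatch concrete: EAGLE's fc layer maps the concatenated hidden state and token embedding (2 × 3584 = 7168 features for Qwen2-7B) down to the hidden size, so its weight is `[3584, 7168]` while its bias is `[3584]`. The snippet below is only a self-contained sketch of the idea (dispatch `fc.weight` and `fc.bias` to their matching parameters), not the actual vLLM code path, and it assumes `self.fc` is created with `bias=True`:

```python
import torch
from torch import nn


def load_fc_weight(fc: nn.Linear, name: str, loaded_weight: torch.Tensor) -> None:
    """Sketch only: route 'fc.weight' / 'fc.bias' to the matching parameter."""
    param = fc.weight if name == "fc.weight" else fc.bias
    assert param.size() == loaded_weight.size(), (
        f"Attempted to load weight ({loaded_weight.size()}) "
        f"into parameter ({param.size()})")
    param.data.copy_(loaded_weight)


# For Qwen2-7B, EAGLE's fc maps 2 * 3584 = 7168 features to 3584.
fc = nn.Linear(7168, 3584, bias=True)
print(fc.weight.shape)  # torch.Size([3584, 7168])
print(fc.bias.shape)    # torch.Size([3584])

# Routing a [3584] bias tensor to fc.weight (what the current branch does)
# trips the same assertion as in the report; routing it to fc.bias works.
load_fc_weight(fc, "fc.bias", torch.zeros(3584))
```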

I hope this can be fixed in an upcoming release!


DarkLight1337 commented 2 days ago

Can you update your vLLM version and try again? It should have been fixed by https://github.com/vllm-project/vllm/pull/8790
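
For reference, a quick way to confirm which vLLM build is installed before retrying:

```python
import vllm

# Print the installed vLLM version; upgrade (or install a build that
# includes the fix from the PR above) if it is older than that.
print(vllm.__version__)
```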