vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: First input (bf16) and second input (uint8) must have the same dtype! #6884

Open · kalocide opened this issue 1 month ago

kalocide commented 1 month ago

Your current environment

Collecting environment information...
PyTorch version: 2.3.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.30.1
Libc version: glibc-2.31

Python version: 3.10.14 (main, Apr  6 2024, 18:45:05) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.4.0-177-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A40
GPU 1: NVIDIA A40

Nvidia driver version: 535.161.08
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      52 bits physical, 57 bits virtual
CPU(s):                             96
On-line CPU(s) list:                0-95
Thread(s) per core:                 2
Core(s) per socket:                 24
Socket(s):                          2
NUMA node(s):                       2
Vendor ID:                          GenuineIntel
CPU family:                         6
Model:                              106
Model name:                         Intel(R) Xeon(R) Gold 6342 CPU @ 2.80GHz
Stepping:                           6
Frequency boost:                    enabled
CPU MHz:                            3487.740
CPU max MHz:                        2801.0000
CPU min MHz:                        800.0000
BogoMIPS:                           5600.00
Virtualization:                     VT-x
L1d cache:                          2.3 MiB
L1i cache:                          1.5 MiB
L2 cache:                           60 MiB
L3 cache:                           72 MiB
NUMA node0 CPU(s):                  0-23,48-71
NUMA node1 CPU(s):                  24-47,72-95
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed:             Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 invpcid_single ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local wbnoinvd dtherm ida arat pln pts avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid md_clear pconfig flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] flashinfer==0.0.9+cu121torch2.3
[pip3] numpy==1.26.4
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] torch==2.3.1
[pip3] torchvision==0.18.1
[pip3] transformers==4.43.2
[pip3] triton==2.3.1
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.5.3.post1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    GPU1    NIC0    NIC1    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X  PXB SYS SYS 0-23,48-71  0       N/A
GPU1    PXB  X  SYS SYS 0-23,48-71  0       N/A
NIC0    SYS SYS  X  PIX             
NIC1    SYS SYS PIX  X              

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_0
  NIC1: mlx5_1

🐛 Describe the bug

When running gradientai/Llama-3-8B-Instruct-Gradient-1048k, I get the following error. I haven't tried other models, but it happens at any --max-model-len.

My CLI args:

vllm serve gradientai/Llama-3-8B-Instruct-Gradient-1048k --dtype bfloat16 --kv-cache-dtype fp8 --tensor-parallel-size 2 --gpu_memory_utilization 0.99 --enforce-eager --max-model-len 786432

Traceback:

2024-07-29T04:04:29.046947428Z Traceback (most recent call last):
2024-07-29T04:04:29.046951373Z   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 46, in _log_task_completion
2024-07-29T04:04:29.046955333Z     return_value = task.result()
2024-07-29T04:04:29.046959063Z   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 637, in run_engine_loop
2024-07-29T04:04:29.046962783Z     result = task.result()
2024-07-29T04:04:29.046966493Z   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 580, in engine_step
2024-07-29T04:04:29.046970095Z     request_outputs = await self.engine.step_async(virtual_engine)
2024-07-29T04:04:29.046973653Z   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 253, in step_async
2024-07-29T04:04:29.046977513Z     output = await self.model_executor.execute_model_async(
2024-07-29T04:04:29.046981104Z   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/distributed_gpu_executor.py", line 175, in execute_model_async
2024-07-29T04:04:29.046984700Z     return await self._driver_execute_model_async(execute_model_req)
2024-07-29T04:04:29.046988533Z   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/multiproc_gpu_executor.py", line 210, in _driver_execute_model_async
2024-07-29T04:04:29.046992348Z     return await self.driver_exec_model(execute_model_req)
2024-07-29T04:04:29.046995802Z   File "/usr/lib/python3.10/concurrent/futures/thread.py", line 58, in run
2024-07-29T04:04:29.047008688Z     result = self.fn(*self.args, **self.kwargs)
2024-07-29T04:04:29.047012517Z   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker_base.py", line 272, in execute_model
2024-07-29T04:04:29.047016180Z     output = self.model_runner.execute_model(
2024-07-29T04:04:29.047019768Z   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
2024-07-29T04:04:29.047023399Z     return func(*args, **kwargs)
2024-07-29T04:04:29.047026957Z   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner.py", line 1314, in execute_model
2024-07-29T04:04:29.047030584Z     hidden_or_intermediate_states = model_executable(
2024-07-29T04:04:29.047035102Z   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
2024-07-29T04:04:29.047038762Z     return self._call_impl(*args, **kwargs)
2024-07-29T04:04:29.047042259Z   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
2024-07-29T04:04:29.047045884Z     return forward_call(*args, **kwargs)
2024-07-29T04:04:29.047049508Z   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/llama.py", line 422, in forward
2024-07-29T04:04:29.047053072Z     model_output = self.model(input_ids, positions, kv_caches,
2024-07-29T04:04:29.047056796Z   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
2024-07-29T04:04:29.047060388Z     return self._call_impl(*args, **kwargs)
2024-07-29T04:04:29.047063811Z   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
2024-07-29T04:04:29.047067451Z     return forward_call(*args, **kwargs)
2024-07-29T04:04:29.047070919Z   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/llama.py", line 322, in forward
2024-07-29T04:04:29.047075899Z     hidden_states, residual = layer(
2024-07-29T04:04:29.047079528Z   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
2024-07-29T04:04:29.047083159Z     return self._call_impl(*args, **kwargs)
2024-07-29T04:04:29.047086762Z   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
2024-07-29T04:04:29.047090602Z     return forward_call(*args, **kwargs)
2024-07-29T04:04:29.047094419Z   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/llama.py", line 245, in forward
2024-07-29T04:04:29.047097992Z     hidden_states = self.self_attn(
2024-07-29T04:04:29.047101646Z   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
2024-07-29T04:04:29.047105232Z     return self._call_impl(*args, **kwargs)
2024-07-29T04:04:29.047108828Z   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
2024-07-29T04:04:29.047112412Z     return forward_call(*args, **kwargs)
2024-07-29T04:04:29.047115962Z   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/llama.py", line 175, in forward
2024-07-29T04:04:29.047119490Z     attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
2024-07-29T04:04:29.047123003Z   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
2024-07-29T04:04:29.047126542Z     return self._call_impl(*args, **kwargs)
2024-07-29T04:04:29.047130067Z   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
2024-07-29T04:04:29.047133832Z     return forward_call(*args, **kwargs)
2024-07-29T04:04:29.047138187Z   File "/usr/local/lib/python3.10/dist-packages/vllm/attention/layer.py", line 97, in forward
2024-07-29T04:04:29.047141802Z     return self.impl.forward(query,
2024-07-29T04:04:29.047145562Z   File "/usr/local/lib/python3.10/dist-packages/vllm/attention/backends/xformers.py", line 598, in forward
2024-07-29T04:04:29.047149047Z     out = PagedAttention.forward_prefix(
2024-07-29T04:04:29.047152530Z   File "/usr/local/lib/python3.10/dist-packages/vllm/attention/ops/paged_attn.py", line 205, in forward_prefix
2024-07-29T04:04:29.047156127Z     context_attention_fwd(
2024-07-29T04:04:29.047168579Z   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
2024-07-29T04:04:29.047172379Z     return func(*args, **kwargs)
2024-07-29T04:04:29.047176027Z   File "/usr/local/lib/python3.10/dist-packages/vllm/attention/ops/prefix_prefill.py", line 765, in context_attention_fwd
2024-07-29T04:04:29.047179686Z     _fwd_kernel[grid](
2024-07-29T04:04:29.047183387Z   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 167, in <lambda>
2024-07-29T04:04:29.047186954Z     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
2024-07-29T04:04:29.047190407Z   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 416, in run
2024-07-29T04:04:29.047193966Z     self.cache[device][key] = compile(
2024-07-29T04:04:29.047197463Z   File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 191, in compile
2024-07-29T04:04:29.047201046Z     module = src.make_ir(options)
2024-07-29T04:04:29.047204796Z   File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 117, in make_ir
2024-07-29T04:04:29.047208281Z     return ast_to_ttir(self.fn, self, options=options)
2024-07-29T04:04:29.047211876Z   File "/usr/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py", line 1231, in ast_to_ttir
2024-07-29T04:04:29.047215447Z     raise CompilationError(fn.src, node, repr(e)) from e
2024-07-29T04:04:29.047218963Z triton.compiler.errors.CompilationError: at 114:24:        off_v = (
2024-07-29T04:04:29.047222463Z             bn[:, None] * stride_v_cache_bs +
2024-07-29T04:04:29.047225881Z             cur_kv_head * stride_v_cache_h +
2024-07-29T04:04:29.047229321Z             offs_d[None, :] * stride_v_cache_d +
2024-07-29T04:04:29.047232945Z             (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
2024-07-29T04:04:29.047236383Z         k = tl.load(K_cache + off_k,
2024-07-29T04:04:29.047240035Z                     mask=dim_mask[:, None] &
2024-07-29T04:04:29.047243581Z                     ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
2024-07-29T04:04:29.047247035Z                     other=0.0)  # [D,N]
2024-07-29T04:04:29.047250622Z 
2024-07-29T04:04:29.047254114Z         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)  # [M,N]
2024-07-29T04:04:29.047257570Z         qk += tl.dot(q, k)
2024-07-29T04:04:29.047261258Z                         ^
2024-07-29T04:04:29.047265742Z AssertionError('First input (bf16) and second input (uint8) must have the same dtype!')
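
For anyone digging into the error itself: this is a Triton compile-time check. The sketch below is a minimal, hypothetical reproduction (not vLLM's actual prefix_prefill kernel; the qk_kernel name and shapes are made up for illustration). It shows the same failure mode: with an fp8 KV cache the cached block is loaded as raw uint8, and tl.dot() refuses mixed-dtype operands, so the block has to be cast (really, dequantized) before the dot product.

# Minimal sketch of the failure mode (hypothetical, not vLLM's actual kernel).
import torch
import triton
import triton.language as tl

@triton.jit
def qk_kernel(q_ptr, k_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    idx = offs[:, None] * BLOCK + offs[None, :]
    q = tl.load(q_ptr + idx)        # bf16 query tile
    k_raw = tl.load(k_ptr + idx)    # uint8 tile (fp8 bytes in the cache)
    # tl.dot(q, k_raw) fails to compile with:
    #   "First input (bf16) and second input (uint8) must have the same dtype!"
    k = k_raw.to(tl.bfloat16)       # placeholder for a real fp8 dequantization
    qk = tl.dot(q, k)               # accumulates in fp32
    tl.store(out_ptr + idx, qk)     # implicitly downcast to the output dtype

q = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16)
k = torch.randint(0, 255, (16, 16), device="cuda", dtype=torch.uint8)
out = torch.empty(16, 16, device="cuda", dtype=torch.bfloat16)
qk_kernel[(1,)](q, k, out, BLOCK=16)
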
joe-schwartz-certara commented 1 month ago

I have the same issue on 0.5.3.post1 with the KV cache in fp8, but it went away when I switched back to a full-precision KV cache. It looks like this version of vLLM introduced a bug here; if you have the memory, removing the KV cache dtype argument might work for you as well.
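
For reference, the workaround is simply the original serve command with the fp8 KV-cache flag dropped (illustrative; keep whatever --max-model-len actually fits in memory without the fp8 cache):

vllm serve gradientai/Llama-3-8B-Instruct-Gradient-1048k --dtype bfloat16 --tensor-parallel-size 2 --gpu_memory_utilization 0.99 --enforce-eager --max-model-len 786432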

I'll also be watching this issue to see if a fix comes out and whether I was right that the fp8 KV cache is the culprit.

joe-schwartz-certara commented 1 month ago

Additional info: I tried three variants of Llama 3.1 (AWQ, GPTQ, and unquantized) and all three hit this bug until I turned off the fp8 KV cache. The other models I use with vLLM were unaffected on 0.5.3.post1.

kalocide commented 1 month ago

@joe-schwartz-certara I can second this; unsetting the KV cache dtype fixed it when I tried earlier. Unfortunately, it means I have to run the model at less than half the batch size (down from 24 to 8 before it would fit), but I'm doing very small-scale inference (me and a few friends), so it doesn't matter much for me. I can imagine this bug being frustrating in production, though.
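
For context on why capacity drops, here is a back-of-the-envelope sketch (assuming the standard Llama-3-8B attention config of 32 layers, 8 KV heads, and head_dim 128): moving the KV cache from fp8 back to bf16 doubles its per-token footprint, so roughly half as many tokens fit in the same GPU memory.

# Rough per-token KV-cache size; 32 layers / 8 KV heads / head_dim 128 assumed for Llama-3-8B.
layers, kv_heads, head_dim = 32, 8, 128
elems_per_token = 2 * layers * kv_heads * head_dim   # K and V entries per token
print(elems_per_token * 2 / 1024, "KiB per token in bf16")  # 128.0
print(elems_per_token * 1 / 1024, "KiB per token in fp8")   # 64.0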

joe-schwartz-certara commented 1 month ago

@satin-spirit After my testing, my assumption is that the vLLM team fudged something small when integrating the newest models. They are quick and smart, so we will probably get a fix soon.

cieske commented 3 weeks ago

Same bug with microsoft/Phi-3-medium-4k-instruct when using fp8_e5m2 or fp8_e4m3. Unsetting --kv-cache-dtype works.