vllm-project / vllm

[Bug]: Prefix Caching fails with fp8 quantized KV Cache #3880

Open amogkam opened 5 months ago

amogkam commented 5 months ago

Your current environment

PyTorch version: 2.1.2+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.27.9
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.1.58+-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.2.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 535.104.05
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.6
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.6
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      46 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             2
On-line CPU(s) list:                0,1
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Xeon(R) CPU @ 2.00GHz
CPU family:                         6
Model:                              85
Thread(s) per core:                 2
Core(s) per socket:                 1
Socket(s):                          1
Stepping:                           3
BogoMIPS:                           4000.33
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat md_clear arch_capabilities
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          32 KiB (1 instance)
L1i cache:                          32 KiB (1 instance)
L2 cache:                           1 MiB (1 instance)
L3 cache:                           38.5 MiB (1 instance)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0,1
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Mitigation; PTE Inversion
Vulnerability Mds:                  Vulnerable; SMT Host state unknown
Vulnerability Meltdown:             Vulnerable
Vulnerability Mmio stale data:      Vulnerable
Vulnerability Retbleed:             Vulnerable
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Vulnerable
Vulnerability Spectre v1:           Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2:           Vulnerable, IBPB: disabled, STIBP: disabled, PBRSB-eIBRS: Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Vulnerable

Versions of relevant libraries:
[pip3] numpy==1.25.2
[pip3] torch==2.1.2
[pip3] torchaudio==2.2.1+cu121
[pip3] torchdata==0.7.1
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.17.1
[pip3] torchvision==0.17.1+cu121
[pip3] triton==2.1.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.4.0.post1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X  0-1     N/A     N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

🐛 Describe the bug

Error in the Triton context-attention kernel (_fwd_kernel) when prefix caching is enabled together with an fp8-quantized KV cache (kv_cache_dtype="fp8_e5m2").

from vllm import LLM, SamplingParams

prompts = ['0UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA',
 '0UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA',
 '0UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA',
 '0UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA',
 '0UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA',
 '1UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA',
 '1UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA',
 '1UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA',
 '1UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA',
 '1UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA',
 '2UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA',
 '2UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA',
 '2UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA',
 '2UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA',
 '2UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA',
 '3UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA',
 '3UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA',
 '3UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA',
 '3UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA',
 '3UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA',
 '4UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA',
 '4UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA',
 '4UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA',
 '4UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA',
 '4UAqFzWsDK4FrUMp48Y3tT3QDgAL47D1qXIaSyZPaE1pu1lJo7XBetF5gIRHYH7LKBKxJsllLODfU25035HyRrY03K6JBO94XfLEN0pThnXuYgjRcJ40UA']

sampling_params = SamplingParams(temperature=0, skip_special_tokens=False)
# The combination of an fp8 KV cache and prefix caching triggers the kernel error below.
llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
          kv_cache_dtype="fp8_e5m2",
          max_model_len=560,
          enable_prefix_caching=True)
outputs = llm.generate(prompts, sampling_params)

Fails with

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py in ast_to_ttir(fn, signature, specialization, constants, debug, arch)
   1123     try:
-> 1124         generator.visit(fn.parse())
   1125     except CompilationError as e:

AssertionError: First input (fp16) and second input (uint8) must have the same dtype!

The above exception was the direct cause of the following exception:

CompilationError                          Traceback (most recent call last)
<string> in _fwd_kernel(Q, K, V, K_cache, V_cache, B_Loc, sm_scale, B_Start_Loc, B_Seqlen, B_Ctxlen, block_size, x, Out, stride_b_loc_b, stride_b_loc_s, stride_qbs, stride_qh, stride_qd, stride_kbs, stride_kh, stride_kd, stride_vbs, stride_vh, stride_vd, stride_obs, stride_oh, stride_od, stride_k_cache_bs, stride_k_cache_h, stride_k_cache_d, stride_k_cache_bl, stride_k_cache_x, stride_v_cache_bs, stride_v_cache_h, stride_v_cache_d, stride_v_cache_bl, num_queries_per_kv, BLOCK_M, BLOCK_DMODEL, BLOCK_N, grid, num_warps, num_stages, extern_libs, stream, warmup, device, device_type)

/usr/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py in ast_to_ttir(fn, signature, specialization, constants, debug, arch)
   1131         if node is None:
   1132             raise
-> 1133         raise CompilationError(fn.src, node, repr(e)) from e
   1134     ret = generator.module
   1135     # module takes ownership of the context

CompilationError: at 96:24:                 (offs_d[:, None] % x) * stride_k_cache_x)
        off_v = (
            bn[:, None] * stride_v_cache_bs +
            cur_kv_head * stride_v_cache_h +
            offs_d[None, :] * stride_v_cache_d +
            (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
        k = tl.load(K_cache + off_k,
                    mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
                    other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
                        ^
AssertionError('First input (fp16) and second input (uint8) must have the same dtype!')
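
Until the kernel supports fp8, the failing code path can be avoided by not enabling both features at once (see the maintainer response below); dropping either flag works. A minimal, hypothetical variation of the repro above, keeping everything else unchanged:

# Workaround sketch: keep the fp8 KV cache but leave prefix caching disabled
# (alternatively, keep enable_prefix_caching=True with the default kv_cache_dtype="auto").
llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
          kv_cache_dtype="fp8_e5m2",
          max_model_len=560,
          enable_prefix_caching=False)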
robertgshaw2-neuralmagic commented 5 months ago

To support this, we will need to rewrite the Triton kernel to support fp8.

For now, we will disable prefix caching. We should open an RFC for improvements to APC (automatic prefix caching).
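
For illustration only, a minimal sketch of what such a rewrite could involve (hypothetical kernel and variable names; assumes a Triton build that exposes tl.float8e5 and bitcast conversion): load the fp8_e5m2 cache tile as uint8, reinterpret it as fp8, and upcast to fp16 before tl.dot so both operands share a dtype.

import triton
import triton.language as tl

@triton.jit
def _dot_fp8_kv_sketch(q_ptr, k_cache_ptr, out_ptr, BLOCK: tl.constexpr):
    # Simplified square tile: q is fp16, the KV cache is fp8_e5m2 stored as uint8.
    offs = tl.arange(0, BLOCK)
    idx = offs[:, None] * BLOCK + offs[None, :]
    q = tl.load(q_ptr + idx)            # fp16 queries
    k_raw = tl.load(k_cache_ptr + idx)  # raw uint8 cache bytes
    # Reinterpret the bytes as fp8_e5m2, then upcast to fp16 so tl.dot sees
    # matching dtypes (this is the conversion the current kernel is missing).
    k = k_raw.to(tl.float8e5, bitcast=True).to(tl.float16)
    acc = tl.dot(q, k)                  # fp32 accumulator
    tl.store(out_ptr + idx, acc)

Launched with, e.g., a (16, 16) fp16 q, the cache tile viewed as uint8, an fp32 output buffer, and BLOCK=16 (tl.dot requires tile sides of at least 16).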

jon-chuang commented 1 month ago

Can be closed