vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Bug]: Unexpected prompt token logprob behaviors of llama 2 when setting echo=True for openai-api server #5334

Open fywalter opened 3 weeks ago

fywalter commented 3 weeks ago

Your current environment

The output of `python collect_env.py`

Collecting environment information...
PyTorch version: 2.3.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.29.3
Libc version: glibc-2.31

Python version: 3.9.19 (main, May 6 2024, 19:43:03) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-169-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA RTX A6000
GPU 1: NVIDIA RTX A6000
GPU 2: NVIDIA RTX A6000
GPU 3: NVIDIA RTX A6000
GPU 4: NVIDIA RTX A6000
GPU 5: NVIDIA RTX A6000
GPU 6: NVIDIA RTX A6000
GPU 7: NVIDIA RTX A6000

Nvidia driver version: 545.23.08 cuDNN version: Probably one of the following: /usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn.so.8.7.0 /usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.7.0 /usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.7.0 /usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.7.0 /usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.7.0 /usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.7.0 /usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.7.0 /usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn.so.8.9.2 /usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn.so.9.1.1 /usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.9.2 /usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.9.2 /usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.9.2 /usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.9.2 /usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.9.2 /usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.9.2 /usr/local/cuda-12.3/targets/x86_64-linux/lib/libcudnn.so.8.9.2 /usr/local/cuda-12.3/targets/x86_64-linux/lib/libcudnn.so.9.1.1 /usr/local/cuda-12.3/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.9.2 /usr/local/cuda-12.3/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.9.2 /usr/local/cuda-12.3/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.9.2 /usr/local/cuda-12.3/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.9.2 /usr/local/cuda-12.3/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.9.2 /usr/local/cuda-12.3/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.9.2 HIP runtime version: N/A MIOpen runtime version: N/A Is XNNPACK available: True

CPU: Architecture: x86_64 CPU op-mode(s): 32-bit, 64-bit Byte Order: Little Endian Address sizes: 48 bits physical, 48 bits virtual CPU(s): 192 On-line CPU(s) list: 0-191 Thread(s) per core: 2 Core(s) per socket: 48 Socket(s): 2 NUMA node(s): 2 Vendor ID: AuthenticAMD CPU family: 25 Model: 1 Model name: AMD EPYC 7643 48-Core Processor Stepping: 1 Frequency boost: enabled CPU MHz: 2770.121 CPU max MHz: 2300.0000 CPU min MHz: 1500.0000 BogoMIPS: 4600.13 Virtualization: AMD-V L1d cache: 3 MiB L1i cache: 3 MiB L2 cache: 48 MiB L3 cache: 512 MiB NUMA node0 CPU(s): 0-47,96-143 NUMA node1 CPU(s): 48-95,144-191 Vulnerability Gather data sampling: Not affected Vulnerability Itlb multihit: Not affected Vulnerability L1tf: Not affected Vulnerability Mds: Not affected Vulnerability Meltdown: Not affected Vulnerability Mmio stale data: Not affected Vulnerability Retbleed: Not affected Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected Vulnerability Srbds: Not affected Vulnerability Tsx async abort: Not affected Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] torch==2.3.0
[pip3] transformers==4.41.1
[pip3] triton==2.3.0
[pip3] vllm_nccl_cu12==2.18.1.0.4.0
[conda] numpy 1.26.4 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.20.5 pypi_0 pypi
[conda] torch 2.3.0 pypi_0 pypi
[conda] transformers 4.41.1 pypi_0 pypi
[conda] triton 2.3.0 pypi_0 pypi
[conda] vllm-nccl-cu12 2.18.1.0.4.0 pypi_0 pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.4.2
vLLM Build Flags: CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled

GPU Topology:
      GPU0  GPU1  GPU2  GPU3  GPU4  GPU5  GPU6  GPU7  CPU Affinity   NUMA Affinity  GPU NUMA ID
GPU0  X     NV4   SYS   SYS   SYS   SYS   SYS   SYS   0-47,96-143    0              N/A
GPU1  NV4   X     SYS   SYS   SYS   SYS   SYS   SYS   0-47,96-143    0              N/A
GPU2  SYS   SYS   X     NV4   SYS   SYS   SYS   SYS   0-47,96-143    0              N/A
GPU3  SYS   SYS   NV4   X     SYS   SYS   SYS   SYS   0-47,96-143    0              N/A
GPU4  SYS   SYS   SYS   SYS   X     NV4   SYS   SYS   48-95,144-191  1              N/A
GPU5  SYS   SYS   SYS   SYS   NV4   X     SYS   SYS   48-95,144-191  1              N/A
GPU6  SYS   SYS   SYS   SYS   SYS   SYS   X     SYS   48-95,144-191  1              N/A
GPU7  SYS   SYS   SYS   SYS   SYS   SYS   SYS   X     48-95,144-191  1              N/A

Legend:

X    = Self
SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX  = Connection traversing at most a single PCIe bridge
NV#  = Connection traversing a bonded set of # NVLinks

🐛 Describe the bug

I set up the OpenAI-compatible API server using:

python -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Llama-2-7b-hf \
    --max-logprobs 100 \
    --host 0.0.0.0 \
    --port xxxx

Then I try to get the logprobs of the tokens in a given prompt by setting logprobs=1 and echo=True, using the following code:

from openai import OpenAI

model = "meta-llama/Llama-2-7b-hf"
openai_api_key = "EMPTY"
# HOST_DICT (defined elsewhere) maps each model name to the base URL of the
# corresponding vLLM server, e.g. "http://<host>:<port>/v1".
client = OpenAI(
    api_key=openai_api_key,
    base_url=HOST_DICT[model],
)
logprobs = 1
prompt = "\n1. Carol has 20"
completion = client.completions.create(
    model=model,
    prompt=prompt,
    max_tokens=5,
    temperature=0,
    logprobs=logprobs,
    echo=True,
)
print("vLLM Completion text:", completion.choices[0].text)
print("vLLM tokens:", completion.choices[0].logprobs.tokens)
print("vLLM Completion logprobs:", completion.choices[0].logprobs)

The generated text looks fine, but the returned prompt tokens and logprobs look really strange:

vLLM Completion text: 
1. Carol has 2000 books in her
vLLM tokens: ['<s>', '', '\n', '1', '.', '\n Carol', '\n\n has', '\n\n\n ', '\n\n\n\n2', '\n\n\n\n\n0', '0', '0', ' books', ' in', ' her']
vLLM Completion logprobs: Logprobs(text_offset=[0, 3, 3, 4, 5, 6, 13, 19, 23, 28, 34, 35, 36, 42, 45], token_logprobs=[None, -4.1065497398376465, -4.179874420166016, -5.703912258148193, -0.9323832988739014, -10.607088088989258, -6.339632987976074, -4.3729047775268555, -1.8068292140960693, -1.6547526121139526, -2.1797218322753906, -2.6753900051116943, -2.8350002765655518, -1.0952296257019043, -0.1524326652288437], tokens=['<s>', '', '\n', '1', '.', '\n Carol', '\n\n has', '\n\n\n ', '\n\n\n\n2', '\n\n\n\n\n0', '0', '0', ' books', ' in', ' her'], top_logprobs=[None, {'': -4.1065497398376465, 'Tags': -2.5245184898376465}, {'\n': -4.179874420166016, '1': -1.4376866817474365}, {'1': -5.703912258148193, '\n': -1.1257872581481934}, {'.': -0.9323832988739014}, {'\n Carol': -10.607088088989258, '\n The': -2.757479190826416}, {'\n\n has': -6.339632987976074, '\n\nyn': -1.4314298629760742}, {'\n\n\n ': -4.3729047775268555, '\n\n\n a': -1.7049362659454346}, {'\n\n\n\n2': -1.8068292140960693, '\n\n\n\n1': -1.3771417140960693}, {'\n\n\n\n\n0': -1.6547526121139526}, {'0': -2.1797218322753906}, {'0': -2.6753900051116943}, {' books': -2.8350002765655518}, {' in': -1.0952296257019043}, {' her': -0.1524326652288437}])

It returns "tokens" like '\n\n\n\n2' , which is neither in the original prompt or even a single token in the llama-2 vocabulary. On the other hand, I tried using some other AI provider (fireworks) and the behavior is expected:

from openai import OpenAI

logprobs = 1
with open("fireworks-key.txt", "r") as f:
    fireworks_key = f.read().strip()

client_fireworks = OpenAI(
    base_url="https://api.fireworks.ai/inference/v1",
    api_key=fireworks_key,
)
prompt = "\n1. Carol has 20"

completion = client_fireworks.completions.create(
    model="accounts/fireworks/models/llama-v2-7b",
    prompt=prompt,
    max_tokens=10,
    temperature=0,
    logprobs=logprobs,
    echo=True,
)
print("Fireworks Completion text:", completion.choices[0].text)
print("Fireworks tokens:", completion.choices[0].logprobs.tokens)
print("Fireworks Completion logprobs:", completion.choices[0].logprobs)

Outputs:

Fireworks Completion text:  
1. Carol has 2000 books in her library. She has 
Fireworks tokens: ['', ' ', '\n', '1', '.', ' Carol', ' has', ' ', '2', '0', '0', '0', ' books', ' in', ' her', ' library', '.', ' She', ' has', ' ']
Fireworks Completion logprobs: Logprobs(text_offset=[0, 0, 1, 2, 3, 4, 10, 14, 15, 16, 17, 18, 19, 25, 28, 32, 40, 41, 45, 49], token_logprobs=[0.0, -3.16015625, -9.4765625, -5.703125, -0.92675781, -10.6015625, -6.33984375, -4.375, -1.80371094, -1.65429688, -2.17938995, -2.67908955, -2.83534431, -1.09457016, -0.15235297, -0.66951698, -0.24534534, -1.22056556, -1.91010427, -1.18275023], tokens=['', ' ', '\n', '1', '.', ' Carol', ' has', ' ', '2', '0', '0', '0', ' books', ' in', ' her', ' library', '.', ' She', ' has', ' '], top_logprobs=[{' ⁇ ': 0.0}, {' Tags': -2.546875}, {'1': -1.43945312}, {'\n': -1.12597656}, {'.': -0.92675781}, {' The': -2.75976562}, {'yn': -1.43457031}, {' a': -1.70605469}, {'1': -1.38183594}, {'0': -1.65429688}, {'0': -2.17938995}, {'0': -2.67908955}, {' books': -2.83534431}, {' in': -1.09457016}, {' her': -0.15235297}, {' library': -0.66951698}, {'.': -0.24534534}, {' She': -1.22056556}, {' has': -1.91010427}, {' ': -1.18275023}], token_ids=[1, 29871, 13, 29896, 29889, 8562, 756, 29871, 29906, 29900, 29900, 29900, 8277, 297, 902, 3489, 29889, 2296, 756, 29871])
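
As a sanity check, the prompt can be run through the Hugging Face tokenizer directly to see what the real Llama-2 pieces look like. This is a minimal sketch; it assumes transformers is installed and the gated meta-llama/Llama-2-7b-hf tokenizer can be downloaded:

from transformers import AutoTokenizer

# Assumes access to the gated meta-llama/Llama-2-7b-hf repository.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

prompt = "\n1. Carol has 20"
ids = tokenizer(prompt).input_ids              # BOS (id 1) is prepended by default
pieces = tokenizer.convert_ids_to_tokens(ids)  # roughly ['<s>', '▁', '<0x0A>', '1', '.', '▁Carol', ...]
print(ids)
print(pieces)

# Strings like '\n\n\n\n2' are not vocabulary entries at all:
print("\n\n\n\n2" in tokenizer.get_vocab())    # False

These pieces line up with the token_ids that Fireworks returns, while the strings returned by the vLLM server do not correspond to any single vocabulary entry.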

I tried Llama-3 with vLLM and it works correctly. Why does this happen with Llama-2? (I tried other sizes of Llama-2 and they all have this problem.)

PastelBelem8 commented 3 weeks ago

I have noticed the same issue. It's been preventing me from using vLLM in my day-to-day work.

fywalter commented 6 days ago

I think this problem is also related to #4772. After some investigation, I found that the decoded token IDs for the prompt are correct, while the decoded token strings are wrong. Since the generated tokens are detokenized correctly, I tried fixing the problem by processing the prompt tokens with the same code used for generated tokens, and this fixed it. Specifically, I modified decode_prompt_logprobs_inplace in vllm/transformers_utils/detokenizer.py:

    def decode_prompt_logprobs_inplace(
            self, seq_group: SequenceGroup,
            prompt_logprobs: List[Optional[Dict[int, Logprob]]]) -> None:
        """Decodes the logprobs for the prompt of a sequence group.

        Args:
            seq_group: The sequence group to decode.
            prompt_logprobs: The logprobs to decode.

        Returns:
            The prompt logprobs with the decoded tokens.
        """
        prms = seq_group.sampling_params
        # We can pick any sequence for the prompt.
        seq = next(iter(seq_group.seqs_dict.values()))
        # Only prompt, without the generated token.
        all_token_ids = seq.get_token_ids()
        prompt_token_ids = all_token_ids[:-1]
        tokenizer = self.get_tokenizer_for_seq(seq)
        prefix_offset = 0
        read_offset = 0
        next_iter_prefix_offset = 0
        next_iter_read_offset = 0
        next_iter_tokens = []
        prev_tokens = None
        for token_position, prompt_logprobs_for_token in enumerate(
                prompt_logprobs):
            if not prompt_logprobs_for_token:
                continue
            for token_id, sample_logprob in prompt_logprobs_for_token.items():
                if (sample_logprob.decoded_token is None
                        and token_id != INVALID_TOKEN_ID):
                    # Attempt to fix the bug: prompt logprobs are decoded
                    # incorrectly for Llama-2 models. It looks like the BOS
                    # token is not included in the sliced prompt_token_ids,
                    # so prev_tokens and the offsets end up wrong. Since the
                    # generated tokens are decoded correctly, reuse the same
                    # approach as decode_sequence_inplace for Llama-2.
                    #
                    # Original code:
                    # prompt_token_ids_with_token = (
                    #     prompt_token_ids[:token_position] + [token_id])
                    if "llama-2" in tokenizer.name_or_path.lower():
                        # Include the BOS token in the slice and recompute
                        # prev_tokens and the offsets from the prompt ids.
                        prompt_token_ids_with_token = (
                            prompt_token_ids[:token_position + 1] + [token_id])
                        (prev_tokens, prefix_offset,
                         read_offset) = convert_prompt_ids_to_tokens(
                             tokenizer=tokenizer,
                             prompt_ids=prompt_token_ids[:token_position + 1],
                             skip_special_tokens=prms.skip_special_tokens,
                         )
                    else:
                        prompt_token_ids_with_token = (
                            prompt_token_ids[:token_position] + [token_id])
                    (new_tokens, new_text, new_prefix_offset,
                     new_read_offset) = detokenize_incrementally(
                         tokenizer=tokenizer,
                         all_input_ids=prompt_token_ids_with_token,
                         prev_tokens=prev_tokens,
                         prefix_offset=prefix_offset,
                         read_offset=read_offset,
                         skip_special_tokens=prms.skip_special_tokens,
                         spaces_between_special_tokens=prms.
                         spaces_between_special_tokens,
                     )
                    sample_logprob.decoded_token = new_text

                    # Use the offsets & prev tokens corresponding to
                    # real tokens to ensure detokenization is consistent
                    # with the actual prompt.
                    if token_id == all_token_ids[token_position]:
                        next_iter_prefix_offset = new_prefix_offset
                        next_iter_read_offset = new_read_offset
                        next_iter_tokens = new_tokens

            # Advance to the next token position.
            prefix_offset = next_iter_prefix_offset
            read_offset = next_iter_read_offset
            if prev_tokens is None:
                prev_tokens = next_iter_tokens
            else:
                prev_tokens.extend(next_iter_tokens)

After fixing the bug, the decoded prompt tokens and logprobs match the original prompt (see the attached screenshot).
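
The fix can also be double-checked without going through the API server by requesting prompt logprobs from the offline LLM entry point. The sketch below assumes vLLM 0.4.x argument names (max_logprobs on the engine, prompt_logprobs in SamplingParams):

from vllm import LLM, SamplingParams

# Offline sanity check for the prompt logprobs (no API server needed).
# Assumes vLLM 0.4.x: `max_logprobs` is an engine argument and
# `prompt_logprobs` is a SamplingParams field.
llm = LLM(model="meta-llama/Llama-2-7b-hf", max_logprobs=100)
params = SamplingParams(temperature=0, max_tokens=5, logprobs=1, prompt_logprobs=1)

out = llm.generate(["\n1. Carol has 20"], params)[0]
# out.prompt_logprobs has one entry per prompt token (None for the first one);
# each entry maps token_id -> Logprob(logprob=..., decoded_token=...).
for position, entry in enumerate(out.prompt_logprobs):
    if entry is None:
        continue
    for token_id, logprob in entry.items():
        print(position, token_id, repr(logprob.decoded_token), logprob.logprob)

With the patch, the printed decoded_token values should look like real tokenizer pieces rather than strings such as '\n\n\n\n2'.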

It is not exactly clear to me why different code paths are used for processing prompt tokens and generated tokens. Should I open a pull request?