sgl-project / sglang

SGLang is a fast serving framework for large language models and vision language models.
https://sglang.readthedocs.io/en/latest/
Apache License 2.0

[Bug] incorrect input_tokens_logprob slicing in RuntimeEndpoint.select method #1257

Open jeffrey-fong opened 2 weeks ago

jeffrey-fong commented 2 weeks ago

### Checklist

### Describe the bug

The `RuntimeEndpoint.select` method does not seem to slice the correct part of the prompt, so it passes the wrong subsequence of `input_token_logprobs` to the `choices_method` class. Specifically, in this line, setting `logprob_start_len` to `prompt_len - 2` results in one preceding token being included in the `input_token_logprobs` passed to the `choices_method`. This is detrimental to all `ChoicesSamplingMethod` subclasses, and especially to `GreedyTokenSelection`, since it selects greedily based on the first logprob in the subsequence. May I know the reason for using `prompt_len - 2` instead of `prompt_len`? I'm curious whether this is indeed a bug or whether I just hit an edge case.
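To make the off-by-one concrete, here is a minimal sketch under one assumption drawn from this thread rather than from the SGLang source: the returned `input_token_logprobs` cover the tokens after position `logprob_start_len`, so `logprob_start_len = prompt_len - 1` makes the first entry the first choice token. The helpers `scored_tokens` and `drop_prompt_entries` and the token strings are hypothetical.

```python
def scored_tokens(tokens, logprob_start_len):
    """Tokens whose logprobs come back under the assumed convention.

    Assumption (hypothetical model, not SGLang's code): the logprob of token i
    is computed from position i - 1, and entries start right after
    `logprob_start_len`.
    """
    return tokens[logprob_start_len + 1:]


def drop_prompt_entries(entries, logprob_start_len, prompt_len):
    """Drop entries that still belong to the prompt (hypothetical helper)."""
    # Number of prompt tokens that leaked into the slice.
    leaked = max(prompt_len - logprob_start_len - 1, 0)
    return entries[leaked:]


# The first three tokens are the prompt; the candidate choice is "all".
tokens = ["<bos>", "...", ">>>", "all"]
prompt_len = 3

print(scored_tokens(tokens, prompt_len - 2))  # ['>>>', 'all']
print(scored_tokens(tokens, prompt_len - 1))  # ['all']
leaky = scored_tokens(tokens, prompt_len - 2)
print(drop_prompt_entries(leaky, prompt_len - 2, prompt_len))  # ['all']
```

Under this model, `prompt_len - 2` lets exactly one prompt token (`>>>`) into the slice, which matches what is described above.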

Additionally, this line in both `TokenizerManager._handle_single_request` and `TokenizerManager._handle_batch_request` always adds a `bos_token` to the prompt. This can result in duplicate BOS tokens, because some models, such as llama3.1 and MeetKai's Functionary models, already add a `bos_token` in their Jinja chat template.
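A minimal sketch of one way to guard against the duplicate, assuming the prompt was rendered by the chat template (which already contains the BOS) and then tokenized with the tokenizer's usual special-token insertion. `strip_duplicate_bos` is a hypothetical helper, not SGLang's API; 128000 is used only as an example BOS id (Llama 3's `<|begin_of_text|>`), and the other ids are arbitrary.

```python
def strip_duplicate_bos(token_ids, bos_token_id):
    """Collapse a run of leading BOS tokens into a single BOS (hypothetical helper)."""
    i = 0
    # Count how many BOS tokens appear at the start of the sequence.
    while i < len(token_ids) and token_ids[i] == bos_token_id:
        i += 1
    if i <= 1:
        # Zero or one leading BOS: nothing to fix.
        return list(token_ids)
    # Keep exactly one BOS followed by the rest of the prompt.
    return [bos_token_id] + list(token_ids[i:])


BOS = 128000  # example BOS id; 882 and 271 below are arbitrary token ids
print(strip_duplicate_bos([BOS, BOS, 882, 271], BOS))  # [128000, 882, 271]
print(strip_duplicate_bos([BOS, 882, 271], BOS))       # [128000, 882, 271]
```

With Hugging Face tokenizers, the duplicate can also be avoided at the source by encoding already-templated text with `tokenizer(text, add_special_tokens=False)`.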

[Screenshots of the chat templates: llama3.1, functionary-small-v3.2]

I am currently working on a fix for both bugs in my fork and can raise a PR if you think this should be fixed.

### Reproduction

  1. Run the SGLang server:
    python -m sglang.launch_server --model-path meetkai/functionary-small-v3.2 --port 8000 --host 0.0.0.0 --context-length 8192
  2. I wrote this script, which uses the SGLang frontend runtime to generate the raw response containing tool calls and then parses it into tool calls.

```python
import json
import random

import rich
import sglang as sgl
from outlines.fsm.json_schema import build_regex_from_schema
from sglang.lang import choices
from sglang.lang.interpreter import ProgramState
from transformers import AutoTokenizer

sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:8000"))
tokenizer = AutoTokenizer.from_pretrained("meetkai/functionary-small-v3.2")


@sgl.function
def generate_answer(s: ProgramState, messages: list, tools: list) -> None:
    """
    Generates an answer using a language model.

    This function:
    - Takes messages and tools as input
    - Applies a chat template to the messages
    - Generates a series of tool calls or a free-form response

    The function can produce three types of outputs:
    1. Free-form content (recipient is "all"):
       'all\nGlad to help! How can I assist you today?'
    2. Single tool call (recipient is a tool name):
       'calculator\n{"query": "2+2"}'
    3. Multiple tool calls (series of tool names and arguments):
       'calculator\n{"query": "2+2"}>>>weather\n{"location": "Los Angeles"}'

    The function will continue generating tool calls until it decides to stop
    or switches to "all" for a free-form response.
    """

    # Fill the conversation
    s += tokenizer.apply_chat_template(
        conversation=messages, tools=tools, add_generation_prompt=True, tokenize=False
    )

    recipient_index = 0
    while True:

        # Select the recipient
        # Generate the next recipient
        recipient_var = f"recipient_{recipient_index}"
        s += sgl.select(
            name=recipient_var,
            choices=["all"] + [tool["function"]["name"] for tool in tools],
            choices_method=choices.greedy_token_selection,
            # choices_method=choices.unconditional_likelihood_normalized,  # greedy not working properly?
        )
        # s += sgl.gen(name=recipient_var, stop="\n")

        # Recipient name and argument separator
        s += "\n"

        # Generate the content
        NEXT_RECIPIENT = "FINISH_MATCHED_STR: >>>"
        content_var = f"content_{recipient_index}"
        if s[recipient_var] != "all":
            # Generate arguments for the selected tool
            tool = next(t for t in tools if t["function"]["name"] == s[recipient_var])
            regex = build_regex_from_schema(json.dumps(tool["function"]["parameters"]))
            regex += r"(>>>)?"  # Next function call
            s += sgl.gen(name=content_var, regex=regex, stop=">>>")

            # Another function call exists
            print(s.get_meta_info(content_var))
            if s.get_meta_info(content_var)["finish_reason"] == NEXT_RECIPIENT:
                recipient_index += 1
                s += ">>>"
                continue
            else:
                break
        else:
            # Generate the content for "all" recipient
            s += sgl.gen(name=content_var, stop=">>>")
            if s.get_meta_info(content_var)["finish_reason"] == NEXT_RECIPIENT:
                recipient_index += 1
                s += ">>>"
                continue
            else:
                break


def non_stream_response(state: ProgramState, model):
    tool_calls = []
    tool_index = 0
    content = None  # avoids an unbound name if no recipient was generated
    while True:
        if f"recipient_{tool_index}" not in state:
            break
        recipient = state[f"recipient_{tool_index}"]
        content = state[f"content_{tool_index}"]
        if recipient != "all":
            id_characters = (
                "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
            )
            tool_calls.append(
                {
                    "id": "call_" + "".join(random.choices(id_characters, k=29)),
                    "function": {
                        "name": recipient,
                        "arguments": content,
                    },
                    "type": "function",
                }
            )
            content = None
        tool_index += 1

    message = {
        "role": "assistant",
        "content": content,
        "function_call": None,
        "tool_calls": tool_calls,
    }

    return {
        "id": "chatcmpl-"
        + "".join(
            random.choices(
                "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", k=29
            )
        ),
        "choices": [
            {
                "finish_reason": "stop",
                "index": 0,
                "message": message,
            }
        ],
        "model": model,
        "object": "chat.completion",
        "usage": {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        },
    }


def main():
    messages = [
        {
            "role": "user",
            "content": "What is the respective weather for these 20 cities? Istanbul, Singapore, Beijing, Shanghai, London, Los Angeles, Hanoi, Moscow, Kuala Lumpur, Taipei, Berlin, Anchorage, Paris, Jakarta, Queenstown, Sydney, Auckland, Seoul, Tokyo, Toronto.",
        },
    ]
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA",
                        }
                    },
                    "required": ["location"],
                },
            },
        }
    ]

    # Run inference
    state = generate_answer.run(
        messages=messages,
        tools=tools,
        max_new_tokens=1024,
        temperature=0.0,
        stream=False,
    )
    rich.print(non_stream_response(state, "meetkai/functionary-small-v3.2"))


if __name__ == "__main__":
    main()
```


When using `GreedyTokenSelection`, not all 20 cities are generated. This differs from normal generation without the SGLang frontend runtime; this model is capable of making 20 function calls, one for each of the 20 cities. If you print the token IDs of `input_token_logprobs` [here](https://github.com/sgl-project/sglang/blob/main/python/sglang/lang/backend/runtime_endpoint.py#L247), you will see that they are `>>>all` and `>>>get_current_weather`, with an extra `>>>` that should not be there.
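The following simplified re-implementation is hypothetical and only modeled on the description in this issue (greedy selection decides at the first position where the options' logprobs differ); it is not SGLang's `GreedyTokenSelection` code, and the logprob values are made up. It illustrates how, if the leaked `>>>` entry differs even slightly between the requests for each choice (an assumption; the thread does not say why the values would differ), the decision is made on a prompt token before any choice token is compared.

```python
from statistics import mean


def greedy_select(choices, token_logprobs):
    """Simplified greedy token selection (hypothetical sketch).

    Compares options position by position and keeps only those with the
    highest logprob at that position, stopping once one option remains.
    """
    width = max(len(lp) for lp in token_logprobs)
    # Pad shorter options with their own mean logprob so every row has `width` entries.
    rows = [lp + [mean(lp)] * (width - len(lp)) for lp in token_logprobs]
    remaining = list(range(len(choices)))
    for j in range(width):
        best = max(rows[i][j] for i in remaining)
        remaining = [i for i in remaining if rows[i][j] == best]
        if len(remaining) == 1:
            break
    return choices[remaining[0]]


choices = ["all", "get_current_weather"]
# Made-up logprobs: the first entry is the leaked ">>>" prompt token, whose value
# differs only by a tiny amount between the two requests.
with_prompt_token = [[-0.69310, -3.0], [-0.69315, -0.05, -0.01]]
print(greedy_select(choices, with_prompt_token))  # all
print(greedy_select(choices, [lp[1:] for lp in with_prompt_token]))  # get_current_weather
```

Once the leaked entry is dropped, the first comparison is between the choices' own first tokens, which is what greedy token selection is meant to compare.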

### Environment

Python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
CUDA available: True
GPU 0: NVIDIA RTX A6000
GPU 0 Compute Capability: 8.6
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.1, V12.1.105
CUDA Driver Version: 550.54.15
PyTorch: 2.4.0+cu121
sglang: 0.2.13
flashinfer: 0.1.5+cu121torch2.4
triton: 3.0.0
transformers: 4.44.2
requests: 2.32.3
tqdm: 4.66.5
numpy: 1.26.3
aiohttp: 3.10.5
fastapi: 0.112.2
hf_transfer: 0.1.8
huggingface_hub: 0.24.6
interegular: 0.3.3
packaging: 23.2
PIL: 10.2.0
psutil: 5.9.8
pydantic: 2.8.2
uvicorn: 0.30.6
uvloop: 0.20.0
zmq: 24.0.1
vllm: 0.5.4
multipart: 0.0.9
openai: 1.42.0
anthropic: 0.34.1
NVIDIA Topology: 
        GPU0    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      16-31,48-63     1               N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

ulimit soft: 1048576
zhyncs commented 2 weeks ago

@yichuan520030910320 may help take a look

yichuan520030910320 commented 2 weeks ago

For your first question ("May I know the reason for using prompt_len - 2 instead of prompt_len?"): it is because we calculate it conservatively. First, -1 is needed because the previous token's information is required to calculate the logprob of the next token. Second, -2 guards against two pieces of text sometimes merging into a single token after tokenization. -2 is always correct because, in the worst case, it only recomputes a few extra logprobs, and it ensures that certain corner cases are handled correctly.
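The merging case described in this reply can be illustrated with a toy tokenizer. This is a sketch under stated assumptions: the vocabulary and the greedy longest-match `tokenize` function are hypothetical stand-ins for a real BPE or SentencePiece tokenizer, and the prompt and choice strings are invented. In case 1, the prompt's last token `>>` is absorbed into `>>>` once the choice is appended, so only `prompt_len - 1` prompt tokens survive and scoring has to begin one token earlier, which is what the extra `- 1` allows for. In case 2 the boundary does not merge, so `- 2` brings in one extra prompt token, matching the report.

```python
# Toy vocabulary (hypothetical): both ">>" and ">>>" exist as single tokens.
VOCAB = {"x", ">", ">>", ">>>", "all"}


def tokenize(text):
    """Greedy longest-match tokenizer over VOCAB (stand-in for a real tokenizer)."""
    tokens, i = [], 0
    while i < len(text):
        for end in range(len(text), i, -1):
            if text[i:end] in VOCAB:
                tokens.append(text[i:end])
                i = end
                break
        else:
            raise ValueError(f"cannot tokenize {text[i:]!r}")
    return tokens


def shared_prefix_len(a, b):
    """Number of leading tokens two tokenizations have in common."""
    n = 0
    while n < min(len(a), len(b)) and a[n] == b[n]:
        n += 1
    return n


# Case 1: the prompt's last token merges with the start of the choice text.
prompt, full = tokenize("x>>"), tokenize("x>>" + ">all")
print(prompt, full, shared_prefix_len(prompt, full))
# ['x', '>>'] ['x', '>>>', 'all'] 1  -> only prompt_len - 1 tokens survive

# Case 2: no merge at the boundary (the situation in this issue).
prompt, full = tokenize("x>>>"), tokenize("x>>>" + "all")
print(prompt, full, shared_prefix_len(prompt, full))
# ['x', '>>>'] ['x', '>>>', 'all'] 2  -> the whole prompt survives
```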

For your second question, I think you might be right. If the chat template has already been applied to the text before it is encoded, we need to avoid adding the BOS/EOS token a second time.

Finally, you're welcome to contribute and submit a PR. If there are any issues, we can continue the discussion here or in your PR.

jeffrey-fong commented 2 weeks ago

Thanks for the answers. For the first question, may I know whether -2 is meant to address SentencePiece-based tokenizers? I'm still getting the extra preceding token in input_token_logprobs here. The preceding token belongs to the end of the prompt, and its logprob value is taken into account regardless of which ChoicesSamplingMethod I use. The problem goes away once I change it to prompt_len - 1. I am using a Tiktoken-based tokenizer.

jeffrey-fong commented 1 week ago

Hi, any updates on this? I also tried the llama3.1-instruct-8b model, and its input_token_logprobs also contain the extra preceding token that should not be there.