sgl-project / sglang

SGLang is a fast serving framework for large language models and vision language models.
https://sgl-project.github.io/

[Feature] Using frontend APIs but passing a list of prompts in `run` rather than `run_batch` #1624

Open pengye91 opened 3 weeks ago

pengye91 commented 3 weeks ago


Motivation

Hi team,

I'm using the run_batch API and found that it sends batched requests to the backend, which makes sense.

But I'm wondering if there is a way to use the run API with a list of prompts as input, which would reduce the total latency caused by HTTP request-response round trips. I know I can accomplish this by posting directly to the /generate URL with a list of prompts, but I don't see how to do it through a frontend API like run.

I tested run_batch against posting directly to /generate with a list of prompts, and the direct call reduced latency by around 25%.
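
For reference, the direct call I mean looks roughly like this (a minimal sketch; the endpoint, payload shape, and sampling parameters follow my local setup):

import requests

prompts = ["prompt 1", "prompt 2", "prompt 3"]

# One POST carrying a list of prompts; the server returns one result dict per prompt.
resp = requests.post(
    "http://localhost:30000/generate",
    json={"text": prompts, "sampling_params": {"max_new_tokens": 256}},
)
for item in resp.json():
    print(item["text"])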

Related resources

No response

merrymercy commented 3 weeks ago

Sorry, currently there is no such interface. Could you share your benchmark? We can take a look and optimize the latency.

pengye91 commented 3 weeks ago

Hi @merrymercy,

Sure, I've tested with a simple benchmark script.

The script batches requests asking the LLM to summarize each chapter of the classical Chinese novel "红楼梦" (Dream of the Red Chamber),

and each approach is run 5 times. The results (average seconds per run) look like this:

Time with generate api: 61.879004574997815
Time with run_batch: 77.86660224159714

Since the script has already been run for several rounds with the same prompts, I think we can ignore the impact of the cache.

The script looks like this:

import time
import requests
from sglang import function, system, user, assistant, gen, set_default_backend, RuntimeEndpoint
import os
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-72B-Instruct", padding_side="left")

os.environ["OPENAI_API_KEY"] = "test_api_key"

set_default_backend(RuntimeEndpoint(
    "http://localhost:30000", 
    api_key=os.environ["OPENAI_API_KEY"],
    chat_template_name="chatml"
    ))

def get_prompts(path: str, max_tokens=32 * 1024):
    contents = []
    with open(path, "r") as f:
        text = f.read()
        contents = text.split("本章完")
    prompts = []    
    for content in contents:
        content = content.strip()
        if content:
            prompt = f"""
            Content:
            \n        {content}
            Instructions:
            请把Content中的内容给出总结,内容字数大于 200 个字,但不超过300个字
            """
            prompts.append(prompt)
    return prompts

@function
def summarize(s, summary):
    s += system("You are a helpful assistant.")
    s += user(summary)
    s += assistant(gen("summary", max_tokens=32768))

def run_batch(prompts):
    start = time.time()
    states = summarize.run_batch(
        [{"summary": p} for p in prompts],
        progress_bar=False
    )

    for state in states:
        for m in state.messages():
            if m["role"] == "assistant":
                # print(m["content"])
                # print("#####")
                pass

    end = time.time()
    # print("with run_batch Time: ", end - start)

def run_generate(prompts):
    url = "http://localhost:30000"
    messages = [[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": text}
    ] for text in prompts]

    prompts = [tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) for message in messages]
    start = time.time()
    # A single POST carrying a list of texts lets the server process them as one batch.
    json_data = {
            "text": prompts,
            "sampling_params": None,
            "return_logprob": False,
            "logprob_start_len": None,
            "top_logprobs_num": None,
            "lora_path": None,
        }
    responses = requests.post(f"{url}/generate", json=json_data, headers={"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"})
    for response in responses.json():
        # print(response)
        # print("#####")
        pass
    end = time.time()
    # print("with generate api Time: ", end - start)

if __name__ == "__main__":
    import timeit
    number = 5
    prompts = get_prompts("红楼梦.txt")
    time_with_generate = timeit.timeit(lambda: run_generate(prompts), number=number)
    print(f"Time with generate api: {time_with_generate / number}")
    time_with_run_batch = timeit.timeit(lambda: run_batch(prompts), number=number)
    print(f"Time with run_batch: {time_with_run_batch / number}")
    # run_with_vllm is not defined in this snippet, so the vLLM comparison is commented out.
    # time_with_run_with_vllm = timeit.timeit(lambda: run_with_vllm(prompts), number=number)
    # print(f"Time with run_with_vllm: {time_with_run_with_vllm / number}")

The input file is attached:

红楼梦.txt

merrymercy commented 2 weeks ago

@pengye91 Thanks for providing the benchmark. I think HTTP does introduce some overhead, and it is not easy to eliminate that in Python. We are considering implementing a C++ HTTP server, which should help with some of it.

In the meantime, you will probably have to use /generate or the Engine API directly for maximum performance.
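
A minimal sketch of the offline Engine path, roughly (model path and sampling parameters here are placeholders; see the SGLang docs for the exact interface):

import sglang as sgl

# Offline engine: no HTTP server in the loop, requests are batched in-process.
# Model path and sampling parameters below are placeholders.
llm = sgl.Engine(model_path="Qwen/Qwen2.5-72B-Instruct")

prompts = ["Summarize chapter 1 ...", "Summarize chapter 2 ..."]
sampling_params = {"temperature": 0, "max_new_tokens": 512}

outputs = llm.generate(prompts, sampling_params)
for out in outputs:
    print(out["text"][:80])

llm.shutdown()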