Open pengye91 opened 3 weeks ago
Sorry, currently there is no such interface. Could you share your benchmark? We can take a look and optimize the latency.
Hi @merrymercy,
Sure, I've tested with a simple benchmark script:
The script batch-requests the LLM to summarize the content of each chapter of the classical Chinese novel "红楼梦". The `run_batch` method uses the SGLang FE DSL, and the `run_generate` method posts the list of prompts to the `/generate` URL. Each method is run 5 times, and the results look like this:
Time with generate api: 61.879004574997815
Time with run_batch: 77.86660224159714
Since the script has already been run for several rounds with the same prompts, I think we can ignore the impact of the cache.
The script looks like this:
```python
import os
import time

import requests
from transformers import AutoTokenizer
from sglang import function, system, user, assistant, gen, set_default_backend, RuntimeEndpoint

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-72B-Instruct", padding_side="left")

os.environ["OPENAI_API_KEY"] = "test_api_key"
set_default_backend(RuntimeEndpoint(
    "http://localhost:30000",
    api_key=os.environ["OPENAI_API_KEY"],
    chat_template_name="chatml",
))


def get_prompts(path: str, max_tokens=32 * 1024):  # max_tokens is currently unused
    with open(path, "r") as f:
        text = f.read()
    # Split the novel on the chapter-end marker "本章完" ("end of chapter").
    contents = text.split("本章完")

    prompts = []
    for content in contents:
        content = content.strip()
        if content:
            # The instruction asks for a summary of 200-300 Chinese characters.
            prompt = f"""
Content:
\n {content}
Instructions:
请把Content中的内容给出总结,内容字数大于 200 个字,但不超过300个字
"""
            prompts.append(prompt)
    return prompts


@function
def summarize(s, summary):
    s += system("You are a helpful assistant.")
    s += user(summary)
    s += assistant(gen("summary", max_tokens=32768))


def run_batch(prompts):
    start = time.time()
    states = summarize.run_batch(
        [{"summary": p} for p in prompts],
        progress_bar=False,
    )
    for state in states:
        for m in state.messages():
            if m["role"] == "assistant":
                # print(m["content"])
                # print("#####")
                pass
    end = time.time()
    # print("with run_batch Time: ", end - start)


def run_generate(prompts):
    url = "http://localhost:30000"
    messages = [[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": text},
    ] for text in prompts]
    prompts = [
        tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
        for message in messages
    ]
    start = time.time()
    json_data = {
        "text": prompts,
        "sampling_params": None,
        "return_logprob": False,
        "logprob_start_len": None,
        "top_logprobs_num": None,
        "lora_path": None,
    }
    responses = requests.post(
        f"{url}/generate",
        json=json_data,
        headers={"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"},
    )
    for response in responses.json():
        # print(response)
        # print("#####")
        pass
    end = time.time()
    # print("with generate api Time: ", end - start)


if __name__ == "__main__":
    import timeit

    number = 5
    prompts = get_prompts("红楼梦.txt")

    time_with_generate = timeit.timeit(lambda: run_generate(prompts), number=number)
    print(f"Time with generate api: {time_with_generate / number}")

    time_with_run_batch = timeit.timeit(lambda: run_batch(prompts), number=number)
    print(f"Time with run_batch: {time_with_run_batch / number}")

    # run_with_vllm is not defined in this script; define it (or remove these lines) before running.
    # time_with_run_with_vllm = timeit.timeit(lambda: run_with_vllm(prompts), number=number)
    # print(f"Time with run_with_vllm: {time_with_run_with_vllm / number}")
```
The file is attached:
@pengye91 Thanks for providing the benchmark. I think HTTP does introduce some overhead, and it is not easy to fix that in Python. We are considering implementing a C++ HTTP server, which can help remove some of it.
In the meantime, you probably have to switch to `/generate` or the Engine API for maximum performance.
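For reference, here is a minimal sketch of what the Engine route could look like for this benchmark. It assumes SGLang's offline `sgl.Engine` entry point with a `generate()` call that takes a list of prompts and a `sampling_params` dict; the exact constructor arguments (e.g. tensor parallelism for a 72B model) and the output format can vary between SGLang versions, so treat it as illustrative rather than a drop-in replacement:

```python
import sglang as sgl

# Run the model in-process, skipping the HTTP layer entirely.
# model_path and sampling_params are illustrative; add tp_size etc. as needed for a 72B model.
llm = sgl.Engine(model_path="Qwen/Qwen2.5-72B-Instruct")

prompts = get_prompts("红楼梦.txt")  # reuse the helper from the benchmark script
# As in run_generate, you may want to apply the chat template to each prompt first.
sampling_params = {"temperature": 0.7, "max_new_tokens": 1024}

# generate() takes the whole batch at once and returns one result per prompt.
outputs = llm.generate(prompts, sampling_params)
for out in outputs:
    print(out["text"][:80])

llm.shutdown()
```

Since everything stays in one process, this path avoids the per-request serialization and transport cost that the benchmark is measuring, which is why it is the way to go when you need the lowest latency.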
Checklist
Motivation
Hi team,
I’m using the `run_batch` API and found that it sends batched requests to the backend, which makes sense. But I’m wondering if there is a way to use the `run` API with a list of states as input, which would reduce the total latency caused by the HTTP request/response round trips. I know I can accomplish that by directly posting to the `/generate` URL with a list of prompts, but I am not aware of how to accomplish it through a frontend API like `run`. I also compared `run_batch` against directly posting to `/generate` with a list of prompts, and latency was reduced by around 25%.
Related resources
No response