vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Usage]: How does vllm server handle concurrency? #9540

Closed Jimmy-L99 closed 1 month ago

Jimmy-L99 commented 1 month ago

Your current environment

I am learning FastAPI and vLLM, and I am trying to build my own LLM API server. But when I benchmark my code against vllm serve, vLLM shows much better inference efficiency than my server does.

The API server and test code are shown below. Which part of my FastAPI code should I improve? Or how can I see how vllm serve handles concurrent requests?

api_server.py

# Imports implied by the code below (the original snippet omits them)
import gc
import os
import random
import string
import time
import uuid
from contextlib import asynccontextmanager
from typing import List, Literal, Optional, Union

import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Response
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams

EventSourceResponse.DEFAULT_PING_INTERVAL = 1000

@asynccontextmanager
async def lifespan(app: FastAPI):
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

def generate_id(prefix: str, k=29) -> str:
    suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=k))
    return f"{prefix}{suffix}"

class ModelCard(BaseModel):
    id: str = ""
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None

class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)  # a plain string is not a valid ModelCard default

class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0

class ChatMessage(BaseModel):

    role: Literal["user", "assistant", "system"]
    content: Optional[str] = None

class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None

class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length"]

class ChatCompletionResponseStreamChoice(BaseModel):
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length"]]
    index: int

class ChatCompletionResponse(BaseModel):
    model: str
    id: Optional[str] = Field(default_factory=lambda: generate_id('chatcmpl-', 29))
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    system_fingerprint: Optional[str] = Field(default_factory=lambda: generate_id('fp_', 9))
    usage: Optional[UsageInfo] = None

class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.1
    top_p: Optional[float] = 0.8
    max_tokens: Optional[int] = None
    stream: Optional[bool] = False
    tools: Optional[Union[dict, List[dict]]] = None
    tool_choice: Optional[Union[str, dict]] = None
    repetition_penalty: Optional[float] = 1.05

def process_messages(messages: List[ChatMessage]):
    _messages = messages
    processed_messages = []

    for m in _messages:
        if m.role == "system":
            processed_messages.append({"role": "system", "content": m.content})
        elif m.role == "user":
            processed_messages.append({"role": "user", "content": m.content})
    return processed_messages

@torch.inference_mode()
async def generate_qwen(params: dict):
    messages = params["messages"]
    temperature = float(params.get("temperature", 0.1))
    top_p = float(params.get("top_p", 0.8))
    max_tokens = int(params.get("max_tokens", 1024))
    repetition_penalty = float(params.get("repetition_penalty", 1.05))

    messages = process_messages(messages)
    inputs = tokenizer.apply_chat_template(
        messages, 
        add_generation_prompt=True, 
        tokenize=False
    )
    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p, 
        repetition_penalty=repetition_penalty, 
        max_tokens=max_tokens,
        stop_token_ids=tokenizer.all_special_ids,
        ignore_eos=False,
    )

    async for output in engine.generate(
        prompt=inputs, 
        sampling_params=sampling_params,
        request_id=uuid.uuid4().hex,  # time.time() alone can collide when requests arrive concurrently
        # lora_request=lora_request,
    ):
        input_len = len(output.prompt_token_ids)
        output_len = len(output.outputs[0].token_ids)
        ret = {
            "text": output.outputs[0].text,
            "usage": {
                "prompt_tokens": input_len,
                "completion_tokens": output_len,
                "total_tokens": output_len + input_len
            },
            "finish_reason": output.outputs[0].finish_reason,
        }
        yield ret
    # Note: forcing a GC pass and emptying the CUDA cache after every request is
    # expensive and can stall other in-flight requests; vLLM manages GPU memory itself.
    gc.collect()
    torch.cuda.empty_cache()

@app.get("/health")
async def health() -> Response:
    """Health check."""
    return Response(status_code=200)

@app.get("/v1/models", response_model=ModelList)
async def list_models():
    model_card = ModelCard(id="glm-4")
    return ModelList(data=[model_card])

@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    if len(request.messages) < 1 or request.messages[-1].role == "assistant":
        raise HTTPException(status_code=400, detail="Invalid request")

    gen_params = dict(
        messages=request.messages,
        temperature=request.temperature or 0.1,
        top_p=request.top_p or 0.8,
        max_tokens=request.max_tokens or 1024,
        echo=False,
        stream=request.stream or False,
        repetition_penalty=request.repetition_penalty or 1.05,
    )
    response = ""
    async for response in generate_qwen(gen_params):
        pass

    # strip() already removes a leading newline, so no separate startswith check is needed
    response["text"] = response["text"].strip()

    usage = UsageInfo()
    finish_reason = "stop"

    message = ChatMessage(
        role="assistant",
        content=response["text"],
    )

    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=message,
        finish_reason=finish_reason,
    )
    task_usage = UsageInfo.model_validate(response["usage"])
    for usage_key, usage_value in task_usage.model_dump().items():
        setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

    return ChatCompletionResponse(
        model=request.model,
        choices=[choice_data],
        object="chat.completion",
        usage=usage
    )

if __name__ == "__main__":
    MODEL_PATH = os.environ.get("MODEL_PATH", "/root/ljm/LoRA/LoRA_litchi1/qwen_model/keyword_finetune_qwen2.5-1.5b-lora_lr1e-4_r16_alpha32_ld0.05_merge")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    engine_args = AsyncEngineArgs(
        model=MODEL_PATH,
        tokenizer=MODEL_PATH,
        tensor_parallel_size=2,
        dtype="bfloat16",
        trust_remote_code=True,
        gpu_memory_utilization=0.2,
        # enforce_eager=True,
        disable_log_requests=True,
        # enable_lora=True,
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    uvicorn.run(app, host="host", port="port", workers=1)  # placeholders: supply a real host string and an integer port

test.py

# Imports and placeholders implied by the code below (the original snippet omits them)
import asyncio
import json
import random
import time

import httpx

host = "host"  # placeholder: address of the server under test
port = "port"  # placeholder: its port

def load_queries_from_file(file_path):
    queries = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            data = json.loads(line)
            messages = data.get("messages", [])
            user_message = messages[0].get("content") if messages else None
            if user_message:
                queries.append(user_message)
    return queries

queries = load_queries_from_file('./dataset.jsonl')

async def send_query(query, semaphore):

    async with semaphore:
        async with httpx.AsyncClient() as client:
            try:
                response = await client.post(
                    f"http://{host}{port}/v1/chat/completions",
                    json={
                        "model": "model",
                        "messages": [{"role": "user", "content": query}],
                        "temperature": 0.7,
                    },
                    timeout=120
                )
            except Exception as e:
                print(f"Error: {e}")
            else:
                print(f"Query: {query}, Response: {response.text}")

async def test_concurrent_users(num_users, max_concurrent=100):

    tasks = []
    semaphore = asyncio.Semaphore(max_concurrent)

    for _ in range(num_users):
        query = random.choice(queries)
        task = asyncio.create_task(send_query(query, semaphore))
        tasks.append(task)

    test_start_time = time.perf_counter()
    await asyncio.gather(*tasks)
    test_end_time = time.perf_counter()

    total_time = test_end_time - test_start_time
    throughput = num_users / total_time if total_time > 0 else float('inf')

    print(f"\n--- total_time: {total_time:.2f}s ---")
    print(f"--- throughput: {throughput:.2f} requests/sec ---")

async def main():

    user_counts = [100]

    for count in user_counts:
        print(f"\n--- test {count} user ---")
        await test_concurrent_users(num_users=count, max_concurrent=100)

if __name__ == "__main__":
    asyncio.run(main())
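
For an apples-to-apples comparison, the same client can be pointed at vLLM's built-in OpenAI-compatible server. Below is a minimal sketch, assuming vllm serve is running on its default port 8000 and that the model name in the payload matches the name the server was started with (both are assumptions, not taken from the issue):

# Sketch of a client for vLLM's OpenAI-compatible server (started with `vllm serve <model>`).
# Assumes the default port 8000; the model name below is a placeholder and must
# match whatever name the server registers.
import asyncio

import httpx

async def query_vllm_serve(prompt: str) -> str:
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            "http://localhost:8000/v1/chat/completions",
            json={
                "model": "Qwen2.5-1.5b",  # placeholder model name
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.7,
            },
            timeout=120,
        )
    resp.raise_for_status()
    return resp.json()["choices"][0]["message"]["content"]

if __name__ == "__main__":
    print(asyncio.run(query_vllm_serve("Hello")))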

How would you like to use vllm

No response


robertgshaw2-neuralmagic commented 1 month ago

Check out the examples in entrypoints/api_server.py for a simple demo of this use case
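
For reference, here is a minimal sketch in the spirit of that demo (not a copy of the actual file): a single shared AsyncLLMEngine, a fresh request_id per request, and no per-request cache clearing. The FastAPI layer does nothing special for concurrency; the engine's scheduler continuously batches all in-flight requests, which is why a single uvicorn worker is enough.

# Minimal sketch in the spirit of vllm/entrypoints/api_server.py (not the actual file).
# Every request simply registers a generator with the one shared AsyncLLMEngine;
# the engine's continuous batching schedules all in-flight requests together.
import uuid

import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams

app = FastAPI()
engine = None  # initialized in __main__


@app.post("/generate")
async def generate(request: Request) -> JSONResponse:
    body = await request.json()
    prompt = body.pop("prompt")
    sampling_params = SamplingParams(**body)
    request_id = uuid.uuid4().hex  # unique per request, safe under concurrency

    final_output = None
    async for request_output in engine.generate(
        prompt=prompt,
        sampling_params=sampling_params,
        request_id=request_id,
    ):
        final_output = request_output

    return JSONResponse({"text": [o.text for o in final_output.outputs]})


if __name__ == "__main__":
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="Qwen2.5-1.5b")  # placeholder model path/name
    )
    uvicorn.run(app, host="0.0.0.0", port=8000)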