microsoft / onnxruntime-genai

Generative AI extensions for onnxruntime
MIT License
423 stars 99 forks source link

Inference with batching is significantly slower than without batching. #714

Status: Open. Jester6136 opened this issue 2 months ago.

Jester6136 commented 2 months ago

I have implemented an inference API with ONNX Runtime GenAI (`onnxruntime_genai`) and FastAPI that processes multiple prompts in batches, with the goal of improving throughput. However, batched inference is significantly slower than processing each prompt individually: when I set `batch_size` back to 1, the API performs best.

Here is my code:

import time
from typing import Optional, List

import onnxruntime_genai as og
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

# Application object; CORS is opened to every origin.
# NOTE(review): allow_origins=["*"] is wide open — confirm this is intended.
app = FastAPI()
app.add_middleware(CORSMiddleware, allow_origins=["*"])

# The model and its tokenizer are loaded once, at import time, and then reused
# by every call to model_run() in this process.
# NOTE(review): hard-coded absolute model path; consider making it configurable.
model = og.Model("/home/rad/bags/models/cuda/cuda-int4-awq-block-128")
tokenizer = og.Tokenizer(model)

def model_run(prompts: List[str], search_options):
    """Generate completions for a batch of prompts in one ``generate`` call.

    Args:
        prompts: Fully formatted prompt strings, one per batch row.
        search_options: Keyword arguments forwarded verbatim to
            ``GeneratorParams.set_search_options``.

    Returns:
        Whatever ``tokenizer.decode_batch`` produces for the generated token
        sequences (one decoded string per row, presumably including the
        echoed prompt — callers currently rely on that).
    """
    batch_ids = tokenizer.encode_batch(prompts)

    generator_params = og.GeneratorParams(model)
    generator_params.set_search_options(**search_options)
    generator_params.input_ids = batch_ids

    generated = model.generate(generator_params)
    return tokenizer.decode_batch(generated)

def infer(list_prompt_input: List[str], max_length=2000):
    """Run chat-formatted generation for a batch of raw user inputs.

    Each input is wrapped in the ``<|user|> ... <|end|> <|assistant|>`` chat
    template, the whole batch is generated with ``model_run``, and the echoed
    prompt is removed from each decoded output.

    Args:
        list_prompt_input: Raw user inputs, one per batch row.
        max_length: Forwarded as the ``max_length`` search option.

    Returns:
        One response string per input, in input order. An empty input list
        yields an empty list without calling the model.
    """
    if not list_prompt_input:
        # Nothing to generate; don't hand the model an empty batch.
        return []

    search_options = {
        'max_length': max_length,
        'temperature': 0.0,
        'top_p': 0.95,
        # top_k is a COUNT of candidate tokens and must be an integer; the
        # previous value (0.95) looked like a copy of top_p. 50 is a
        # conventional choice.
        # NOTE(review): with temperature 0.0 these sampling knobs may have no
        # effect unless do_sample is enabled — confirm for the installed
        # onnxruntime-genai version.
        'top_k': 50,
        'repetition_penalty': 1.05,
    }
    chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'
    prompts = [chat_template.format(input=user_input) for user_input in list_prompt_input]
    outputs = model_run(prompts, search_options)

    results = []
    for output, user_input in zip(outputs, list_prompt_input):
        # The decoded text is expected to echo the prompt, so drop everything
        # up to the end of the FIRST occurrence of the user input. Splitting on
        # the last occurrence truncated replies that quote the input, and
        # str.split('') raised ValueError for an empty input. If the input is
        # not found, the whole output is kept (same as before).
        start = output.find(user_input)
        reply = output[start + len(user_input):] if start >= 0 else output
        results.append(reply.strip())
    return results

class BatchInferenceRequest(BaseModel):
    """Request body for ``POST /llm_infer``.

    Attributes:
        prompts: Raw user inputs to answer; processed in order.
        max_length: Forwarded to generation as the ``max_length`` search option.
        batch_size: Number of prompts sent to the model per generation call.
    """

    prompts: List[str]
    # Reject non-positive lengths up front; None is still accepted by Optional.
    max_length: Optional[int] = Field(2000, ge=1)
    # Must be >= 1: the endpoint slices prompts with range(0, n, batch_size),
    # which raises for a step of 0 and silently yields nothing for a negative
    # step (dropping every prompt).
    batch_size: int = Field(2, ge=1)

@app.post("/llm_infer")
async def llm_infer(request: BatchInferenceRequest):
    """Answer every prompt in ``request``, ``request.batch_size`` at a time.

    Returns:
        A dict with ``results`` (one response per prompt, in request order)
        and ``time`` (elapsed seconds spent across all batches).

    Raises:
        HTTPException: 422 when ``batch_size`` is less than 1.
    """
    max_batch_size = request.batch_size
    if max_batch_size < 1:
        # range() with step 0 raises (surfacing as a 500) and a negative step
        # yields nothing, silently dropping every prompt; reject instead.
        raise HTTPException(status_code=422, detail="batch_size must be >= 1")

    # NOTE(review): infer() runs the model synchronously inside an async
    # endpoint, so it blocks the event loop for the whole request. Consider
    # offloading it (e.g. fastapi.concurrency.run_in_threadpool) once it is
    # confirmed the model can safely be called from concurrent threads.
    # perf_counter() is monotonic, unlike time.time(), so wall-clock
    # adjustments cannot distort the measured duration.
    start_time = time.perf_counter()
    result = []
    for i in range(0, len(request.prompts), max_batch_size):
        batch_prompts = request.prompts[i:i + max_batch_size]
        result.extend(infer(batch_prompts, request.max_length))
    execution_time = time.perf_counter() - start_time
    return {"results": result, "time": execution_time}

if __name__ == "__main__":
    # Serve the FastAPI app on every network interface, port 5555.
    uvicorn.run(app, port=5555, host="0.0.0.0")
aciddelgado commented 1 month ago

Hello @Jester6136, have you made any progress on this issue? Which model are you using?

As a side note, `top_k` should be an integer, since it is the number of highest-probability tokens to sample from.