triton-inference-server / pytriton

PyTriton is a Flask/FastAPI-like interface that simplifies Triton's deployment in Python environments.
https://triton-inference-server.github.io/pytriton/
Apache License 2.0

[Question] Working with custom request schemas #82

Open kheyer opened 3 weeks ago

kheyer commented 3 weeks ago

I'm looking to use pytriton to run inference for a model. We have existing systems that send requests with the format:

{
    'embedding' : [...]
}

However triton/pytriton requires the format:

{
    "inputs" : [
        {
            "name" : "embedding",
            "shape" : [1, embedding_size],
            "datatype" : "FP32",
            "data" : [...]
        }
    ]
}

Is there a way to change the input schema of pytriton or parse a custom request in the inference function?

Certainly one workaround would be to set up a FastAPI server that receives requests of the first format, transforms them to pytriton format, sends the request to pytriton and returns the response. However this requires adding another server, another hop, deserializing/serializing the request data, etc.
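
For concreteness, that proxy would look roughly like this (a rough sketch assuming pytriton's ModelClient; the route, model name, and output tensor name are just placeholders):

import numpy as np
from typing import List
from fastapi import FastAPI
from pydantic import BaseModel
from pytriton.client import ModelClient

app = FastAPI()

class EmbeddingRequest(BaseModel):
    embedding: List[float]

@app.post("/predict")  # placeholder route
def predict(request: EmbeddingRequest):
    # Repack the simple payload into the tensor format pytriton expects.
    batch = np.array([request.embedding], dtype=np.float32)
    with ModelClient("localhost:8000", "my_model") as client:  # placeholder model name
        result = client.infer_batch(embedding=batch)
    # "output" is a placeholder for whatever the model's output tensor is named.
    return {"output": result["output"].tolist()}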

Is there a solution to this within pytriton that avoids this complexity? For example processing the raw request in the infer function.

For instance the vllm example appears to process a different request schema from the triton generate documentation using an infer function that takes a Request as input, but I'm not sure if that's unique to the generate endpoint.

piotrm-nvidia commented 3 weeks ago

Both Triton and FastAPI are valid options for serving machine learning models, but they serve different purposes. Triton excels in scalability and performance optimization for inference, while FastAPI shines in flexibility and ease of customization. Choose Triton if you need performance at scale, or FastAPI if you require fine-grained control over API design.

Triton Inference Server Solution

Here’s the Triton code sample that uses the /generate endpoint to handle embedding generation:

import logging
from typing import Dict, List, AsyncGenerator
import numpy as np
from transformers import AutoTokenizer, AutoModel
import torch
from pytriton.model_config import ModelConfig, Tensor
from pytriton.triton import Triton, TritonConfig
from pytriton.proxy.types import Request

# Load the Hugging Face model and tokenizer
model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Configure logger
logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger("pytriton.embedding_generate_server")

def encode_embeddings(input_texts: List[str]) -> np.ndarray:
    """Generate embeddings using Hugging Face model."""
    inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()  # Mean pooling
    return embeddings

async def generate_fn(requests: List[Request]) -> AsyncGenerator[List[Dict[str, np.ndarray]], None]:
    """Function to process generate requests and return embeddings."""
    responses = []

    for request in requests:
        raw_texts = request.data.get("text_input")

        if raw_texts is None:
            raise ValueError("Missing 'text_input' in request.")

        # Triton delivers string inputs as bytes; decode them before tokenization.
        texts = [item.decode("utf-8") for item in raw_texts.tolist()]

        embeddings = encode_embeddings(texts)

        response = {"embedding": embeddings}
        responses.append(response)

    yield responses

if __name__ == "__main__":
    triton_config = TritonConfig(http_port=8000)

    with Triton(config=triton_config) as triton:
        triton.bind(
            model_name="embedding_model",
            infer_func=generate_fn,
            inputs=[
                Tensor(name="text_input", dtype=np.bytes_, shape=(1,)),
            ],
            outputs=[
                Tensor(name="embedding", dtype=np.float32, shape=(-1, 384)),
            ],
            config=ModelConfig(max_batch_size=1, batching=False)
        )
        triton.serve()

There is no need to do anything GPU-related at the input level here, because tokenization runs on the CPU, so Triton's GPU input handling brings no benefit. If you want the embeddings computed on the GPU, you have to move the model and tensors there yourself in Python. The embedding computation for this model is lightweight, so it should be fast enough on the CPU.
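
If you do want the embedding computation on the GPU, a minimal change to encode_embeddings could look like this (a sketch, assuming a CUDA device is available):

# Pick a device once at startup and move the model to it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

def encode_embeddings(input_texts: List[str]) -> np.ndarray:
    """Generate embeddings on the selected device."""
    inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True)
    # Move the tokenized tensors to the same device as the model.
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    # Mean pooling, then back to CPU so the result can be returned as a NumPy array.
    return outputs.last_hidden_state.mean(dim=1).cpu().numpy()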

Curl Example for Triton /generate Endpoint:

To send a request to Triton’s /generate endpoint, you can use the following curl command:

curl -X POST http://localhost:8000/v2/models/embedding_model/generate \
    -H "Content-Type: application/json" \
    -d '{
        "text_input": "This is an example sentence.",
        "parameters": {}
    }'
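
The same call from Python with the requests library looks like this (a quick sketch; only the payload fields shown in the curl example are assumed):

import requests

response = requests.post(
    "http://localhost:8000/v2/models/embedding_model/generate",
    json={"text_input": "This is an example sentence.", "parameters": {}},
)
print(response.json()["embedding"][:5])  # first few embedding values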

You can't remove the parameters field from the request, as it is required by the Triton endpoint implementation. If a future version adds other fields, you will have to handle them as well, so you are dependent on Triton's C++ implementation, which you can't easily change.

Expected Triton Response:

{"embedding":[0.46799784898757937,0.3234274387359619,0.29819953441619875,0.45349982380867007,0.1747879683971405,-0.019004419445991517,..., -0.13932937383651734],"model_name":"embedding_model","model_version":"1"}

As you can see, the response contains the embedding values along with metadata about the model used for inference, which you can't change.

Limitations of Triton Inference Server:

While Triton Inference Server is highly optimized for large-scale inference, it has limitations regarding custom endpoints:

- The request and response schemas of its HTTP endpoints are fixed; you cannot define your own payload format.
- The /generate endpoint requires the parameters field, and the response always carries model metadata (model_name, model_version) that you cannot remove.
- The endpoint logic lives in Triton's C++ implementation, so you cannot easily change or extend it.
- You cannot add custom routes or integrate other services (databases, payment gateways, etc.) directly in the server process.

FastAPI Solution

Here’s the FastAPI implementation that mirrors the behavior of Triton’s /generate endpoint but with more flexibility in handling requests and defining custom logic:

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional, Dict
import numpy as np
from transformers import AutoTokenizer, AutoModel
import torch

# Initialize FastAPI
app = FastAPI()

# Load the Hugging Face model and tokenizer
model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

class GenerateRequest(BaseModel):
    text_input: str
    parameters: Optional[Dict] = None

class GenerateResponse(BaseModel):
    embedding: List[List[float]]

def encode_embeddings(input_texts: List[str]) -> np.ndarray:
    """Generate embeddings using Hugging Face model."""
    inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()  # Mean pooling
    return embeddings

@app.post("/v2/models/embedding_model/generate", response_model=GenerateResponse)
async def generate_embedding(request: GenerateRequest):
    if not request.text_input:
        raise HTTPException(status_code=400, detail="Missing 'text_input' in the request.")

    embeddings = encode_embeddings([request.text_input])

    return GenerateResponse(embedding=embeddings.tolist())

This code is much more flexible than Triton: you can define custom endpoints, request structures, and response formats. You can easily modify the logic to handle different request formats, include additional parameters, or extend the functionality to meet specific use cases. It is also much easier to work with and debug than Triton, as it is plain Python. You can directly integrate payment gateways, databases, etc. in FastAPI; none of this is possible inside Triton, so you would need an additional server for it. You can also use PyTorch directly and leverage GPU capabilities inside FastAPI without any problems.
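
For example, accepting the question's original {"embedding": [...]} payload is just another pydantic model and route added to the same app (a sketch; run_model is a placeholder for whatever you do with the embedding):

class EmbeddingRequest(BaseModel):
    embedding: List[float]

@app.post("/v1/embeddings")  # any custom route you like
async def embedding_inference(request: EmbeddingRequest):
    # The payload arrives exactly in the existing systems' format.
    batch = np.array([request.embedding], dtype=np.float32)
    result = run_model(batch)  # placeholder for the actual model call
    return {"result": result.tolist()}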

Curl Example for FastAPI Endpoint:

To send a request to the FastAPI /generate endpoint, use this curl command:

curl -X POST http://localhost:8000/v2/models/embedding_model/generate \
    -H "Content-Type: application/json" \
    -d '{
        "text_input": "This is an example sentence.",
        "parameters": {}
    }'

Expected FastAPI Response:

{
  "embedding": [
    [
      0.12345, -0.56789, ..., 0.4321  # Embedding values (384-dim array)
    ]
  ]
}

Advantages of FastAPI:

- Full control over endpoints, request structures, and response formats.
- Plain Python, so it is easier to work with and debug.
- Direct integration with databases, payment gateways, and other services.
- Direct use of PyTorch and the GPU inside the request handlers.

github-actions[bot] commented 21 hours ago

This issue is stale because it has been open 21 days with no activity. Remove stale label or comment or this will be closed in 7 days.