kheyer opened this issue 3 weeks ago
Both Triton and FastAPI are valid options for serving machine learning models, but they serve different purposes. Triton excels at scalable, performance-optimized inference, while FastAPI shines in flexibility and ease of customization. Choose Triton if you need performance at scale, or FastAPI if you need fine-grained control over the API design.
Here is a PyTriton code sample that uses Triton's `/generate` endpoint to handle embedding generation:
```python
import logging
from typing import Dict, List

import numpy as np
import torch
from pytriton.model_config import ModelConfig, Tensor
from pytriton.proxy.types import Request
from pytriton.triton import Triton, TritonConfig
from transformers import AutoModel, AutoTokenizer

# Configure logger
logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger("pytriton.embedding_generate_server")

# Load the Hugging Face model and tokenizer once at startup
model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)


def encode_embeddings(input_texts: List[str]) -> np.ndarray:
    """Generate embeddings using the Hugging Face model."""
    inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # Mean pooling over the token dimension -> shape (len(input_texts), 384)
    embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
    return embeddings


def generate_fn(requests: List[Request]) -> List[Dict[str, np.ndarray]]:
    """Process generate requests and return one embedding response per request."""
    # A plain function returning a list of dicts; generator (streaming)
    # infer functions are meant for decoupled models.
    responses = []
    for request in requests:
        raw_texts = request.data.get("text_input")
        if raw_texts is None:
            raise ValueError("Missing 'text_input' in request.")
        # Triton delivers strings as bytes; decode them before tokenizing
        texts = [item.decode("utf-8") for item in raw_texts.tolist()]
        responses.append({"embedding": encode_embeddings(texts)})
    return responses


if __name__ == "__main__":
    triton_config = TritonConfig(http_port=8000)
    with Triton(config=triton_config) as triton:
        triton.bind(
            model_name="embedding_model",
            infer_func=generate_fn,
            inputs=[
                Tensor(name="text_input", dtype=np.bytes_, shape=(1,)),
            ],
            outputs=[
                Tensor(name="embedding", dtype=np.float32, shape=(-1, 384)),
            ],
            config=ModelConfig(max_batch_size=1, batching=False),
        )
        triton.serve()
```
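Note that `last_hidden_state.mean(dim=1)` averages over every token position, including padding tokens when several texts of different lengths are tokenized together with `padding=True`. With one text per request, as above, there is no padding and this does not matter, but for batched inputs a mean weighted by the attention mask is more accurate. The sketch below shows that arithmetic in plain Python on toy nested lists; `masked_mean_pool` is a hypothetical helper name, not part of either library.

```python
from typing import List


def masked_mean_pool(
    token_embeddings: List[List[List[float]]],  # [batch][token][dim]
    attention_mask: List[List[int]],            # [batch][token], 1 = real token, 0 = padding
) -> List[List[float]]:
    """Average token vectors per text, skipping positions where the mask is 0."""
    pooled = []
    for tokens, mask in zip(token_embeddings, attention_mask):
        dim = len(tokens[0])
        sums = [0.0] * dim
        count = 0
        for vector, keep in zip(tokens, mask):
            if keep:
                count += 1
                for i, value in enumerate(vector):
                    sums[i] += value
        # Guard against an all-padding row to avoid division by zero
        count = max(count, 1)
        pooled.append([s / count for s in sums])
    return pooled


# Two "texts": the second has one padding position whose values are ignored.
print(masked_mean_pool(
    [[[1.0, 2.0], [3.0, 4.0]], [[2.0, 2.0], [100.0, 100.0]]],
    [[1, 1], [1, 0]],
))  # [[2.0, 3.0], [2.0, 2.0]]
```

In PyTorch the same idea is to multiply `last_hidden_state` by `attention_mask.unsqueeze(-1)`, sum over dimension 1, and divide by the per-row sum of the mask.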
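There is no need to do anything GPU-related at the input level here: you tokenize on the CPU, so Triton's GPU capabilities bring no benefit for the inputs. If you want the embeddings computed on the GPU, you have to arrange that yourself in the Python code by moving the model and the tokenized inputs to the GPU. Computing an embedding is a full forward pass through the encoder rather than a simple lookup, but for a small model like all-MiniLM-L6-v2 CPU inference may well be fast enough.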
To send a request to Triton's `/generate` endpoint, you can use the following `curl` command:
```shell
curl -X POST http://localhost:8000/v2/models/embedding_model/generate \
  -H "Content-Type: application/json" \
  -d '{
    "text_input": "This is an example sentence.",
    "parameters": {}
  }'
```
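The same request can be built from Python with only the standard library. This is a minimal sketch assuming the server from the example above is reachable at `http://localhost:8000`; `build_generate_request` and `send` are hypothetical helper names.

```python
import json
import urllib.request
from typing import Any, Dict, Optional

# Assumed server location, matching the curl example
BASE_URL = "http://localhost:8000"


def build_generate_request(
    text: str,
    model: str = "embedding_model",
    parameters: Optional[Dict[str, Any]] = None,
) -> urllib.request.Request:
    """Build (but do not send) a POST request for Triton's /generate endpoint."""
    body = {"text_input": text, "parameters": parameters or {}}
    return urllib.request.Request(
        url=f"{BASE_URL}/v2/models/{model}/generate",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send(request: urllib.request.Request) -> Dict[str, Any]:
    """Send the request and decode the JSON response (needs a running server)."""
    with urllib.request.urlopen(request, timeout=10) as response:
        return json.loads(response.read().decode("utf-8"))
```

With the server running, `send(build_generate_request("This is an example sentence."))` returns the decoded JSON response.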
You can't remove the `parameters` field from the request, as it is required by the Triton endpoint implementation. If a future version adds other fields, you will also have to handle them, so you depend on Triton's C++ implementation, which you can't easily change.
```
{"embedding":[0.46799784898757937,0.3234274387359619,0.29819953441619875,0.45349982380867007,0.1747879683971405,-0.019004419445991517,..., -0.13932937383651734],"model_name":"embedding_model","model_version":"1"}
```
As you can see, the response contains the embedding values together with metadata about the model used for inference (`model_name`, `model_version`), and you can't change this response format.
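Because the response format is fixed by Triton, the client has to adapt to it. A minimal parsing sketch, assuming the flat `embedding` list shown above and the 384-dimensional output declared in the server code; `parse_generate_response` is a hypothetical helper name.

```python
import json
from typing import List, Tuple

EXPECTED_DIM = 384  # matches the (-1, 384) output declared in the server


def parse_generate_response(raw: str) -> Tuple[List[float], str, str]:
    """Split a /generate response into (embedding, model_name, model_version)."""
    payload = json.loads(raw)
    embedding = [float(x) for x in payload["embedding"]]
    if len(embedding) != EXPECTED_DIM:
        raise ValueError(f"expected {EXPECTED_DIM} values, got {len(embedding)}")
    return embedding, payload["model_name"], payload["model_version"]
```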
While Triton Inference Server is highly optimized for large-scale inference, it is limited when it comes to custom endpoints: the available endpoints (`/generate`, `/infer`, etc.) are baked into its C++ backend.

Here is a FastAPI implementation that mirrors the behavior of Triton's `/generate` endpoint but gives you more flexibility in handling requests and defining custom logic:
```python
from typing import Dict, List, Optional

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer

# Initialize FastAPI
app = FastAPI()

# Load the Hugging Face model and tokenizer once at startup
model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)


class GenerateRequest(BaseModel):
    text_input: str
    parameters: Optional[Dict] = None


class GenerateResponse(BaseModel):
    embedding: List[List[float]]


def encode_embeddings(input_texts: List[str]) -> np.ndarray:
    """Generate embeddings using the Hugging Face model."""
    inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # Mean pooling over the token dimension -> shape (len(input_texts), 384)
    embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
    return embeddings


# A plain `def` (not `async def`) lets FastAPI run the blocking model call
# in its thread pool instead of stalling the event loop.
@app.post("/v2/models/embedding_model/generate", response_model=GenerateResponse)
def generate_embedding(request: GenerateRequest) -> GenerateResponse:
    if not request.text_input:
        raise HTTPException(status_code=400, detail="Missing 'text_input' in the request.")
    embeddings = encode_embeddings([request.text_input])
    return GenerateResponse(embedding=embeddings.tolist())


if __name__ == "__main__":
    # Same port as the Triton example, so the same curl command works
    uvicorn.run(app, host="0.0.0.0", port=8000)
```
This code is much more flexible than the Triton version: you can define custom endpoints, request structures, and response formats, and you can easily modify the logic to handle different request formats, accept additional parameters, or extend the functionality for specific use cases. Because it is plain Python, it is also much easier to work with and debug. You can directly integrate payment gateways, databases, and similar services in FastAPI; none of this is possible inside Triton, so there you would need an additional server for it. You can also use PyTorch directly here and run the model on the GPU inside FastAPI.
To send a request to the FastAPI `/generate` endpoint, use this `curl` command:
```shell
curl -X POST http://localhost:8000/v2/models/embedding_model/generate \
  -H "Content-Type: application/json" \
  -d '{
    "text_input": "This is an example sentence.",
    "parameters": {}
  }'
```
```
{
  "embedding": [
    [
      0.12345, -0.56789, ..., 0.4321  # Embedding values (384-dim array)
    ]
  ]
}
```
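Unlike Triton's `/generate` response, this `GenerateResponse` returns `embedding` as a list of lists (one row per input text) and carries no model metadata. If clients must talk to both servers, for example during a migration, a small normalizing helper can hide the difference. A sketch; `first_embedding` is a hypothetical name.

```python
from typing import Any, Dict, List


def first_embedding(payload: Dict[str, Any]) -> List[float]:
    """Return the first embedding from either a flat or a nested 'embedding' field."""
    values = payload["embedding"]
    if not values:
        raise ValueError("empty 'embedding' field")
    if isinstance(values[0], list):
        # FastAPI version: [[...]] with one row per input text
        return [float(x) for x in values[0]]
    # Triton /generate version: a flat list of floats
    return [float(x) for x in values]
```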
This issue is stale because it has been open 21 days with no activity. Remove stale label or comment or this will be closed in 7 days.
I'm looking to use pytriton to run inference for a model. We have existing systems that send requests with the format:
However, Triton/PyTriton requires the format:
Is there a way to change the input schema of pytriton or parse a custom request in the inference function?
Certainly one workaround would be to set up a FastAPI server that receives requests of the first format, transforms them to pytriton format, sends the request to pytriton and returns the response. However this requires adding another server, another hop, deserializing/serializing the request data, etc.
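The core of such a proxy is a pair of translation functions. The sketch below targets Triton's standard `/v2/models/<model>/infer` JSON body (KServe v2 protocol: `inputs` entries with `name`, `shape`, `datatype`, `data`) for the `text_input`/`embedding` model defined earlier. The legacy schema is not shown in this issue, so `{"sentence": ...}` in and `{"vector": ...}` out are purely hypothetical placeholders, as are the function names.

```python
import json
from typing import Any, Dict

INPUT_NAME = "text_input"  # must match the Tensor name bound in PyTriton


def legacy_to_triton_infer(legacy: Dict[str, Any]) -> Dict[str, Any]:
    """Translate a hypothetical legacy body {"sentence": ...} into a v2 infer body."""
    text = legacy.get("sentence")
    if not isinstance(text, str):
        raise ValueError("legacy request needs a string 'sentence' field")
    return {
        "inputs": [
            {
                "name": INPUT_NAME,
                "shape": [1],         # one string, matching shape=(1,) with batching disabled
                "datatype": "BYTES",  # Triton's datatype for strings
                "data": [text],
            }
        ]
    }


def triton_infer_to_legacy(response: Dict[str, Any]) -> Dict[str, Any]:
    """Pick the 'embedding' output out of a v2 infer response (hypothetical legacy shape)."""
    for output in response.get("outputs", []):
        if output.get("name") == "embedding":
            return {"vector": output["data"]}
    raise KeyError("no 'embedding' output in response")
```

This is exactly the extra serialize/deserialize hop the question wants to avoid, which is why it only makes sense if the schema cannot be handled inside the server itself.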
Is there a solution to this within pytriton that avoids this complexity? For example processing the raw request in the infer function.
For instance, the vLLM example appears to process a request schema different from the one in the Triton generate documentation, using an infer function that takes a `Request` as input, but I'm not sure whether that is unique to the generate endpoint.
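In PyTriton, an infer function without the `@batch` decorator receives a list of `Request` objects. Assuming, as the vLLM example suggests, that each one exposes the input tensors in `request.data` and request parameters in `request.parameters`, some per-request customization can live inside the infer function. The sketch below exercises that routing logic offline: `StubRequest`, `fake_encode`, and the `normalize` parameter are hypothetical stand-ins, and whether a given endpoint forwards a particular JSON field into `parameters` should be checked against the Triton protocol documentation.

```python
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class StubRequest:
    """Stand-in for pytriton's Request, assumed to expose `data` and `parameters`."""
    data: Dict[str, Any]
    parameters: Optional[Dict[str, Any]] = None


def fake_encode(texts: List[str]) -> List[List[float]]:
    """Hypothetical stand-in for encode_embeddings: a 2-dim 'embedding' per text."""
    return [[float(len(t)), 1.0] for t in texts]


def l2_normalize(vec: List[float]) -> List[float]:
    norm = math.sqrt(sum(x * x for x in vec)) or 1.0
    return [x / norm for x in vec]


def infer_fn(requests: List[Any]) -> List[Dict[str, Any]]:
    responses = []
    for request in requests:
        # Iterating works for a list of bytes or a NumPy array of bytes
        texts = [item.decode("utf-8") for item in request.data["text_input"]]
        params = request.parameters or {}
        vectors = fake_encode(texts)
        if params.get("normalize"):
            # Behavior switched by a request parameter instead of a new input tensor
            vectors = [l2_normalize(v) for v in vectors]
        responses.append({"embedding": vectors})
    return responses
```

In a real PyTriton model, each value in the response dicts must be a NumPy array, e.g. `np.asarray(vectors, dtype=np.float32)`.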