Lightning-AI / LitServe

Lightning-fast serving engine for any AI model of any size. Flexible. Easy. Enterprise-scale.
https://lightning.ai/docs/litserve
Apache License 2.0

Prometheus logger is not pickable + monitoring metrics set via self.log are not tracked #339

Open miguelalba96 opened 3 days ago

miguelalba96 commented 3 days ago

🐛 Bug

I get the following warning when using Prometheus inside an ls.Logger:

WARNING:litserve.loggers:Logger PrometheusLogger is not picklable and might not work properly.

In addition, the metrics that I am "observing" via self.log on the ls.LitAPI are not being tracked under the /metrics endpoint. These are the metrics I get:

# HELP python_gc_objects_collected_total Objects collected during gc
# TYPE python_gc_objects_collected_total counter
python_gc_objects_collected_total{generation="0"} 1933.0
python_gc_objects_collected_total{generation="1"} 638.0
python_gc_objects_collected_total{generation="2"} 100.0
# HELP python_gc_objects_uncollectable_total Uncollectable objects found during GC
# TYPE python_gc_objects_uncollectable_total counter
python_gc_objects_uncollectable_total{generation="0"} 0.0
python_gc_objects_uncollectable_total{generation="1"} 0.0
python_gc_objects_uncollectable_total{generation="2"} 0.0
# HELP python_gc_collections_total Number of times this generation was collected
# TYPE python_gc_collections_total counter
python_gc_collections_total{generation="0"} 679.0
python_gc_collections_total{generation="1"} 61.0
python_gc_collections_total{generation="2"} 5.0
# HELP python_info Python platform information
# TYPE python_info gauge
python_info{implementation="CPython",major="3",minor="11",patchlevel="4",version="3.11.4"} 1.0
# HELP http_server_requests_duration_seconds_total HTTP request latency in seconds
# TYPE http_server_requests_duration_seconds_total histogram
http_server_requests_duration_seconds_total_bucket{endpoint="/predict",le="0.005",method="POST",status_code="200"} 6.0
http_server_requests_duration_seconds_total_bucket{endpoint="/predict",le="0.01",method="POST",status_code="200"} 7.0
http_server_requests_duration_seconds_total_bucket{endpoint="/predict",le="0.025",method="POST",status_code="200"} 7.0
http_server_requests_duration_seconds_total_bucket{endpoint="/predict",le="0.05",method="POST",status_code="200"} 7.0
http_server_requests_duration_seconds_total_bucket{endpoint="/predict",le="0.075",method="POST",status_code="200"} 7.0
http_server_requests_duration_seconds_total_bucket{endpoint="/predict",le="0.1",method="POST",status_code="200"} 7.0
http_server_requests_duration_seconds_total_bucket{endpoint="/predict",le="0.25",method="POST",status_code="200"} 7.0
http_server_requests_duration_seconds_total_bucket{endpoint="/predict",le="0.5",method="POST",status_code="200"} 7.0
http_server_requests_duration_seconds_total_bucket{endpoint="/predict",le="0.75",method="POST",status_code="200"} 7.0
http_server_requests_duration_seconds_total_bucket{endpoint="/predict",le="1.0",method="POST",status_code="200"} 7.0
http_server_requests_duration_seconds_total_bucket{endpoint="/predict",le="2.5",method="POST",status_code="200"} 7.0
http_server_requests_duration_seconds_total_bucket{endpoint="/predict",le="5.0",method="POST",status_code="200"} 7.0
http_server_requests_duration_seconds_total_bucket{endpoint="/predict",le="7.5",method="POST",status_code="200"} 7.0
http_server_requests_duration_seconds_total_bucket{endpoint="/predict",le="10.0",method="POST",status_code="200"} 7.0
http_server_requests_duration_seconds_total_bucket{endpoint="/predict",le="+Inf",method="POST",status_code="200"} 7.0
http_server_requests_duration_seconds_total_count{endpoint="/predict",method="POST",status_code="200"} 7.0
http_server_requests_duration_seconds_total_sum{endpoint="/predict",method="POST",status_code="200"} 0.01563624886330217
http_server_requests_duration_seconds_total_bucket{endpoint="/metrics",le="0.005",method="GET",status_code="200"} 2.0
http_server_requests_duration_seconds_total_bucket{endpoint="/metrics",le="0.01",method="GET",status_code="200"} 2.0
http_server_requests_duration_seconds_total_bucket{endpoint="/metrics",le="0.025",method="GET",status_code="200"} 2.0
http_server_requests_duration_seconds_total_bucket{endpoint="/metrics",le="0.05",method="GET",status_code="200"} 2.0
http_server_requests_duration_seconds_total_bucket{endpoint="/metrics",le="0.075",method="GET",status_code="200"} 2.0
http_server_requests_duration_seconds_total_bucket{endpoint="/metrics",le="0.1",method="GET",status_code="200"} 2.0
http_server_requests_duration_seconds_total_bucket{endpoint="/metrics",le="0.25",method="GET",status_code="200"} 2.0
http_server_requests_duration_seconds_total_bucket{endpoint="/metrics",le="0.5",method="GET",status_code="200"} 2.0
http_server_requests_duration_seconds_total_bucket{endpoint="/metrics",le="0.75",method="GET",status_code="200"} 2.0
http_server_requests_duration_seconds_total_bucket{endpoint="/metrics",le="1.0",method="GET",status_code="200"} 2.0
http_server_requests_duration_seconds_total_bucket{endpoint="/metrics",le="2.5",method="GET",status_code="200"} 2.0
http_server_requests_duration_seconds_total_bucket{endpoint="/metrics",le="5.0",method="GET",status_code="200"} 2.0
http_server_requests_duration_seconds_total_bucket{endpoint="/metrics",le="7.5",method="GET",status_code="200"} 2.0
http_server_requests_duration_seconds_total_bucket{endpoint="/metrics",le="10.0",method="GET",status_code="200"} 2.0
http_server_requests_duration_seconds_total_bucket{endpoint="/metrics",le="+Inf",method="GET",status_code="200"} 2.0
http_server_requests_duration_seconds_total_count{endpoint="/metrics",method="GET",status_code="200"} 2.0
http_server_requests_duration_seconds_total_sum{endpoint="/metrics",method="GET",status_code="200"} 0.006937500089406967
# HELP http_server_requests_duration_seconds_total_created HTTP request latency in seconds
# TYPE http_server_requests_duration_seconds_total_created gauge
http_server_requests_duration_seconds_total_created{endpoint="/predict",method="POST",status_code="200"} 1.729620868354843e+09
http_server_requests_duration_seconds_total_created{endpoint="/metrics",method="GET",status_code="200"} 1.72962087676791e+09
# HELP request_processing_seconds Time spent processing request
# TYPE request_processing_seconds histogram

This is my implementation for the Logger:

import litserve as ls
from prometheus_client import Histogram

class PrometheusLogger(ls.Logger):
    def __init__(self):
        super().__init__()
        self.function_duration = Histogram(
            "request_processing_seconds",
            "Time spent processing request",
            ["function_name"],
        )

    def process(self, key, value):
        self.function_duration.labels(function_name=key).observe(value)

To Reproduce

I use the following running configuration:

import time

import torch

import litserve as ls
from prometheus_client import make_asgi_app

# monitoring (defines PrometheusLogger and HTTPLatencyMiddleware) and api_config are local project modules
import api_config
import monitoring

class CLIPModel(ls.LitAPI):
    # ... all the other methods, including predict, setup, etc. ...

    def get_image_embeddings(
            self,
            images,
            normalize_embedding: bool = True,
            model_version: str = "latest"
    ):
        start_time = time.perf_counter()
        with torch.no_grad():
            images = self.processor(images=images, return_tensors="pt").to(self.latest_model.device)
            if model_version == "latest":
                embedding = self.latest_model.get_image_features(**images)
            else:
                embedding = self.prev_model.get_image_features(**images)

            if normalize_embedding:
                embedding = self.normalize_embedding(embedding)
        self.log("get_image_embedding", time.perf_counter() - start_time)    # <- here I am using the PrometheusLogger
        return embedding

if __name__ == "__main__":
    prometheus_logger = monitoring.PrometheusLogger()
    prometheus_logger.mount(
        path="/metrics",
        app=make_asgi_app()
    )
    server = ls.LitServer(
        CLIPModel(),
        workers_per_device=1,
        middlewares=[monitoring.HTTPLatencyMiddleware],
        loggers=prometheus_logger,
        stream=True,
    )
    server.run(
        port=api_config.PORT,
        num_api_servers=1,
    )

where my monitoring.HTTPLatencyMiddleware is defined like this:

import os
import time

from fastapi import Request
from prometheus_client import Histogram
from starlette.middleware.base import BaseHTTPMiddleware

ENDPOINT_LABEL = "endpoint"
STATUSCODE_LABEL = "status_code"
METHOD_LABEL = "method"

HTTP_REQUEST_LATENCY = Histogram(
    "http_server_requests_duration_seconds_total",
    "HTTP request latency in seconds",
    [ENDPOINT_LABEL, STATUSCODE_LABEL, METHOD_LABEL],
    # using default buckets
)

class HTTPLatencyMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        method = request.method
        endpoint = os.path.normpath(request.url.path)
        status_code = 200

        start_time = time.perf_counter()

        try:
            # Process the request
            response = await call_next(request)
            status_code = response.status_code
        except Exception as e:
            raise e
        finally:
            # Record metrics
            duration = time.perf_counter() - start_time
            HTTP_REQUEST_LATENCY.labels(
                method=method, endpoint=endpoint, status_code=status_code
            ).observe(duration)

        return response

Am I setting up the ls.LitServer logger and mounts correctly, or is something wrong? I am following the docstrings from ls.Logger.

I use prometheus-client==0.21.0 and litserve==0.2.3.

aniketmaurya commented 2 days ago

hi @miguelalba96, to use Prometheus with LitServe you will need to create a multiprocess registry so that it can collect metrics from all the inference processes. I have created the following example based on your code:

import os, time
from prometheus_client import CollectorRegistry, Histogram, make_asgi_app, multiprocess
import litserve as ls

# Set the directory for multiprocess mode
os.environ["PROMETHEUS_MULTIPROC_DIR"] = "/tmp/prometheus_multiproc_dir"

# Ensure the directory exists
os.makedirs("/tmp/prometheus_multiproc_dir", exist_ok=True)

# Use a multiprocess registry
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)

class PrometheusLogger(ls.Logger):
    def __init__(self):
        super().__init__()
        self.function_duration = Histogram(
            "request_processing_seconds",
            "Time spent processing request",
            ["function_name"],
            registry=registry,
        )

    def process(self, key, value):
        print("processing", key, value)
        self.function_duration.labels(function_name=key).observe(value)

class SimpleLitAPI(ls.LitAPI):
    def setup(self, device):
        self.model1 = lambda x: x**2
        self.model2 = lambda x: x**3

    def decode_request(self, request):
        return request["input"]

    def predict(self, x):
        start_time = time.perf_counter()
        squared = self.model1(x)
        cubed = self.model2(x)
        output = squared + cubed
        self.log("get_image_embedding", time.perf_counter() - start_time)
        return {"output": output}

    def encode_response(self, output):
        return {"output": output}

if __name__ == "__main__":
    prometheus_logger = PrometheusLogger()
    prometheus_logger.mount(path="/metrics", app=make_asgi_app(registry=registry))
    api = SimpleLitAPI()
    server = ls.LitServer(api, loggers=prometheus_logger)
    server.run(port=8000)

After running this code, you should see output like the following from the /metrics endpoint:

# HELP request_processing_seconds Multiprocess metric
# TYPE request_processing_seconds histogram
request_processing_seconds_sum{function_name="get_image_embedding"} 4.124827682971954e-06
request_processing_seconds_bucket{function_name="get_image_embedding",le="0.005"} 2.0
request_processing_seconds_bucket{function_name="get_image_embedding",le="0.01"} 2.0
request_processing_seconds_bucket{function_name="get_image_embedding",le="0.025"} 2.0
request_processing_seconds_bucket{function_name="get_image_embedding",le="0.05"} 2.0
request_processing_seconds_bucket{function_name="get_image_embedding",le="0.075"} 2.0
request_processing_seconds_bucket{function_name="get_image_embedding",le="0.1"} 2.0
request_processing_seconds_bucket{function_name="get_image_embedding",le="0.25"} 2.0
request_processing_seconds_bucket{function_name="get_image_embedding",le="0.5"} 2.0
request_processing_seconds_bucket{function_name="get_image_embedding",le="0.75"} 2.0
request_processing_seconds_bucket{function_name="get_image_embedding",le="1.0"} 2.0
request_processing_seconds_bucket{function_name="get_image_embedding",le="2.5"} 2.0
request_processing_seconds_bucket{function_name="get_image_embedding",le="5.0"} 2.0
request_processing_seconds_bucket{function_name="get_image_embedding",le="7.5"} 2.0
request_processing_seconds_bucket{function_name="get_image_embedding",le="10.0"} 2.0
request_processing_seconds_bucket{function_name="get_image_embedding",le="+Inf"} 2.0
request_processing_seconds_count{function_name="get_image_embedding"} 2.0
# HELP http_server_requests_duration_seconds_total HTTP request latency in seconds
# TYPE http_server_requests_duration_seconds_total histogram
# HELP request_processing_seconds Time spent processing request
# TYPE request_processing_seconds histogram

Also, please feel free to safely ignore the picklable warning, since we reconstruct the objects that are not pickleable. It is just a warning in case something goes wrong during that reconstruction.
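
A quick way to confirm the multiprocess setup is collecting the self.log metrics is to send a request and then read /metrics back. This is just a minimal sketch assuming the example server above is running locally on port 8000:

import requests

# Send a request so the PrometheusLogger records a sample via self.log
resp = requests.post("http://localhost:8000/predict", json={"input": 4})
print(resp.json())

# The histogram logged with self.log should now appear in the scrape output
metrics = requests.get("http://localhost:8000/metrics").text
print([line for line in metrics.splitlines() if "request_processing_seconds" in line])

Note that prometheus_client's multiprocess mode persists .db files in PROMETHEUS_MULTIPROC_DIR across runs, so it is usually worth clearing that directory when the server starts to avoid stale samples.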