BerriAI / litellm

Python SDK, Proxy Server (LLM Gateway) to call 100+ LLM APIs in OpenAI format - [Bedrock, Azure, OpenAI, VertexAI, Cohere, Anthropic, Sagemaker, HuggingFace, Replicate, Groq]
https://docs.litellm.ai/docs/

[Bug]: async_log_success_event triggered twice when stream_options: {"include_usage": true} #5118

Closed. mohittalele closed this issue 3 months ago.

mohittalele commented 3 months ago

What happened?

I am trying to test the functionality of callbacks.

Here is my simple FastAPI server setup:

custom_callback.py


from litellm.integrations.custom_logger import CustomLogger
import litellm

# This file includes the custom callbacks for LiteLLM Proxy
# Once defined, these can be passed in proxy_config.yaml
class MyCustomHandler(CustomLogger):
    def log_pre_api_call(self, model, messages, kwargs): 
        print(f"Pre-API Call")

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time): 
        print(f"Post-API Call")

    def log_stream_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Stream")

    def log_success_event(self, kwargs, response_obj, start_time, end_time): 
        print("On Success")

    def log_failure_event(self, kwargs, response_obj, start_time, end_time): 
        print(f"On Failure")

    async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
        if "complete_streaming_response" in kwargs :
            print(kwargs["acomplete_streaming_response"])
        print(f"On Async Streaming")

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        if kwargs["call_type"] == "completion":
            return
        print(f"On Async Success!")
        # log: key, user, model, prompt, response, tokens, cost
        # Access kwargs passed to litellm.completion()
        model = kwargs.get("model", None)
        messages = kwargs.get("messages", None)
        user = kwargs.get("user", None)

        # Access litellm_params passed to litellm.completion(), example access `metadata`
        litellm_params = kwargs.get("litellm_params", {})
        metadata = litellm_params.get("metadata", {})   # headers passed to LiteLLM proxy, can be found here

        # Calculate cost using  litellm.completion_cost()
        # cost = litellm.completion_cost(completion_response=response_obj)
        response = response_obj
        # tokens used in response 
        usage = response_obj["usage"]

        print(
            f"""
                Model: {model},
                Messages: {messages},
                User: {user},
                Usage: {usage},
                Response: {response}
                Proxy Metadata: {metadata}
                --------------------------------------------- DONE ----------------------------
            """
        )

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): 
        try:
            print(f"On Async Failure !")
            print("\nkwargs", kwargs)
            # Access kwargs passed to litellm.completion()
            model = kwargs.get("model", None)
            messages = kwargs.get("messages", None)
            user = kwargs.get("user", None)

            # Access litellm_params passed to litellm.completion(), example access `metadata`
            litellm_params = kwargs.get("litellm_params", {})
            metadata = litellm_params.get("metadata", {})   # headers passed to LiteLLM proxy, can be found here

            # Acess Exceptions & Traceback
            exception_event = kwargs.get("exception", None)
            traceback_event = kwargs.get("traceback_exception", None)

            # Calculate cost using  litellm.completion_cost()
            cost = litellm.completion_cost(completion_response=response_obj)
            print("now checking response obj")

            print(
                f"""
                    Model: {model},
                    Messages: {messages},
                    User: {user},
                    Cost: {cost},
                    Response: {response_obj}
                    Proxy Metadata: {metadata}
                    Exception: {exception_event}
                    Traceback: {traceback_event}
                """
            )
        except Exception as e:
            print(f"Exception: {e}")

proxy_handler_instance = MyCustomHandler()

# Set litellm.callbacks = [proxy_handler_instance] on the proxy

FastAPI server:

import os
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from litellm import Router
from dotenv import load_dotenv
import json 
import litellm
from custom_callback import proxy_handler_instance
from litellm.integrations.custom_logger import CustomLogger
# Load environment variables
load_dotenv()

# Initialize the FastAPI app
app = FastAPI()

# Define the model list for the Router
model_list = [
    {
        "model_name": "some_model",
        "litellm_params": {
            "model": "openai/some_model",
            "api_key": os.getenv("OPENAI_API_KEY"),
            "api_base": "http://localhost:8000/v1",
        }
    }
]

# Initialize the Router
router = litellm.Router(model_list=model_list)
litellm.callbacks = [proxy_handler_instance]

# Define a request 
class QueryRequest(BaseModel):
    prompt: str
    model: str
    max_tokens: int = 100
    stream: bool = False
    stream_options: dict | None = None

# Define a response model
class QueryResponse(BaseModel):
    response: str

async def stream_response(response):
    async for chunk in response:
        yield f"{json.dumps(chunk.to_dict())}\n"
    yield " [DONE]\n"

@app.post("/query")
async def query_litellm(request: QueryRequest):
    try:
        # Generate a response using LiteLLM Router
        response = await router.acompletion(
            model=request.model,
            messages=[{"role": "user", "content": request.prompt}],
            max_tokens=request.max_tokens,
            stream=request.stream,
            stream_options=request.stream_options if request.stream else None
        )

        if request.stream:
            return StreamingResponse(stream_response(response), media_type="text/event-stream")
        else:
            generated_text = f"{json.dumps(response.to_dict())}\n"
            return QueryResponse(response=generated_text)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Run the server
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=9090)

Here is a sample curl request:

curl -X 'POST' \
  'http://localhost:9090/query' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "prompt": "Explain global warming?",
  "model": "some_model",
  "max_tokens": 100,
  "stream": true,
  "stream_options": {
    "include_usage": true
  }
}'
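
The same request can also be issued from Python. Below is a small sketch using the httpx client (httpx and the 60-second timeout are my additions, not part of the original repro) that prints each streamed chunk as it arrives:

import httpx
import json

payload = {
    "prompt": "Explain global warming?",
    "model": "some_model",
    "max_tokens": 100,
    "stream": True,
    "stream_options": {"include_usage": True},
}

# Stream the response line by line; each non-empty line is one JSON-encoded
# chunk emitted by the stream_response generator in the FastAPI server above.
with httpx.stream("POST", "http://localhost:9090/query", json=payload, timeout=60.0) as resp:
    for line in resp.iter_lines():
        line = line.strip()
        if line and line != "[DONE]":
            print(json.loads(line))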

The async_log_success_event function is called twice when the request includes

  "stream": true,
  "stream_options": {
    "include_usage": true
  }

I would expect it to be called once. If I set stream_options = None, it is only called once. Am I missing something?
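
For what it's worth, a stripped-down handler like the sketch below (a hypothetical CallCounter, not part of the setup above) makes the duplicate invocation easy to observe, since it just counts async success callbacks per response id:

from litellm.integrations.custom_logger import CustomLogger

class CallCounter(CustomLogger):
    # Hypothetical helper: counts how often the async success hook fires
    # for each response id.
    def __init__(self):
        super().__init__()
        self.counts = {}

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # Fall back to "unknown" in case the response object has no id attribute.
        response_id = getattr(response_obj, "id", "unknown")
        self.counts[response_id] = self.counts.get(response_id, 0) + 1
        print(f"async_log_success_event call #{self.counts[response_id]} for {response_id}")

With include_usage enabled, this should report a second call for the same request, matching the behavior described above.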

Relevant log output

No response

Twitter / LinkedIn details

No response

krrishdholakia commented 3 months ago

able to repro

krrishdholakia commented 3 months ago

This is caused by finish_reason being set while additional chunks are still being yielded, which triggers the success callback multiple times:

if ( "async_complete_streaming_response" in self.model_call_details ): await callback.async_log_success_event( kwargs=self.model_call_details, response_obj=self.model_call_details[ "async_complete_streaming_response" ], start_time=start_time, end_time=end_time, )

complete_streaming_response = None if self.stream: if result.choices[0].finish_reason is not None: # if it's the last chunk self.streaming_chunks.append(result)

verbose_logger.debug(f"final set of received chunks: {self.streaming_chunks}")

try: complete_streaming_response = litellm.stream_chunk_builder(
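
Until the duplicate invocation is fixed in the streaming logger, one possible stopgap on the caller side (a sketch only, not the library's fix) is to de-duplicate inside the custom handler from custom_callback.py above, e.g. by remembering which response ids have already been logged:

class DedupedHandler(MyCustomHandler):
    # Sketch of a workaround: skip repeated success callbacks for the same
    # response id so downstream logging only happens once per request.
    def __init__(self):
        super().__init__()
        self._seen_ids = set()

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        response_id = getattr(response_obj, "id", None)
        if response_id is not None:
            if response_id in self._seen_ids:
                return  # already handled this streamed response
            self._seen_ids.add(response_id)
        await super().async_log_success_event(kwargs, response_obj, start_time, end_time)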