BerriAI / litellm

Python SDK, Proxy Server to call 100+ LLM APIs using the OpenAI format - [Bedrock, Azure, OpenAI, VertexAI, Cohere, Anthropic, Sagemaker, HuggingFace, Replicate, Groq]
https://docs.litellm.ai/docs/

[Feature]: Custom handlers that work with pass through endpoints #4675

Closed codexceed closed 1 month ago

codexceed commented 1 month ago

The Feature

As far as I can tell, there is no way to perform pre-processing on a request made via a pass-through endpoint. For instance, the async_pre_call_hook only seems to be triggered when hitting one of the standard model types configured in the config, not when hitting a pass-through endpoint.

Motivation, pitch

I've been trying to use LiteLLM as a proxy server that performs pre- and post-processing on inference requests made to the standard OpenAI API endpoints, which are then routed to a custom TorchServe service endpoint. This requires the following:

When testing the pre-processing case, where I route a request for /chat/completions to a custom service endpoint like /predictions/model-name, there appears to be no way to trigger pre-processing of the request payload via the async_pre_call_hook function.

Custom Logger Sample

from typing import Literal, Optional

from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.proxy_server import DualCache, UserAPIKeyAuth
from litellm.utils import get_formatted_prompt

class MyCustomHandler(CustomLogger):
    def __init__(self):
        pass

    #### CALL HOOKS - proxy only ####

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "audio_transcription",
            "moderation",
            "text_completion",
        ],
    ) -> Optional[dict | str | Exception]:
        # Inspect the formatted prompt before the request is forwarded
        formatted_prompt = get_formatted_prompt(data=data, call_type=call_type)

        # Returning a string instead of the payload rejects the request
        if "Hello world" in formatted_prompt:
            return "This is an invalid response"

        return data

# Instance referenced from config.yaml (litellm_settings.callbacks)
proxy_handler_instance = MyCustomHandler()

config.yaml

model_list:
  - model_name: gpt2
    litellm_params:
      model: huggingface/models/gpt2
      api_key: os.environ/HUGGINGFACE_API_KEY
      api_base: "https://api-inference.huggingface.co/models/gpt2"

litellm_settings:
  callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
  set_verbose: True
  log_raw_request_response: True

general_settings:
  pass_through_endpoints:
    - path: "/chat/completions"                                  # route you want to add to LiteLLM Proxy Server
      target: "http://localhost:8080/predictions/mistral-7b-instruct"          # URL this route should forward requests to
      headers:                                            # headers to forward to this URL
        Authorization: "Bearer os.environ/HUGGINGFACE_API_KEY" # (Optional) Auth Header to forward to your Endpoint
        content-type: application/json                    # (Optional) Extra Headers to pass to this endpoint 
        accept: application/json
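
For illustration, a request through this pass-through route might look like the following sketch (the proxy address and API key are placeholders for my local setup); as described above, async_pre_call_hook is not invoked for such a request.

# Sketch: a raw request to the proxy's pass-through /chat/completions route.
# The proxy address and API key are placeholders.
import requests

resp = requests.post(
    "http://localhost:4000/chat/completions",
    headers={
        "Authorization": "Bearer sk-1234",
        "content-type": "application/json",
    },
    # Body is forwarded as-is to http://localhost:8080/predictions/mistral-7b-instruct;
    # async_pre_call_hook does not fire for this pass-through request.
    json={"messages": [{"role": "user", "content": "Hello world"}]},
    timeout=30,
)
print(resp.status_code, resp.json())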


krrishdholakia commented 1 month ago

hey @codexceed if your TorchServe server is already OpenAI-compatible (it seems to be, from your example)

then you can just add it to the model_list

model_list:
  - model_name: my-model
    litellm_params:
      model: openai/<your-model-name>  # add openai/ prefix to route as OpenAI provider
      api_base: <model-api-base>       # add api base for OpenAI compatible provider
      api_key: api-key                 # api key to send to your model

doc: https://docs.litellm.ai/docs/providers/openai_compatible#usage-with-litellm-proxy-server
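
With the model registered this way, requests go through the standard /chat/completions route, where (per the description above) callbacks like async_pre_call_hook are triggered. A rough usage sketch, with placeholder proxy address, key, and model name:

# Sketch: calling the OpenAI-compatible model registered in model_list via the proxy.
# base_url, api_key, and the model name are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

resp = client.chat.completions.create(
    model="my-model",  # the model_name from model_list above
    messages=[{"role": "user", "content": "Hello world"}],
)
print(resp.choices[0].message.content)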

krrishdholakia commented 1 month ago

Just saw the /prediction

Would it be easier if we just allowed you to translate from openai schema to your server's format?

Your current setup would require you to choose between all the models in model_list and your /prediction endpoint since you're replacing /chat/completions

codexceed commented 1 month ago

> Just saw the /prediction
>
> Would it be easier if we just allowed you to translate from openai schema to your server's format?
>
> Your current setup would require you to choose between all the models in model_list and your /prediction endpoint since you're replacing /chat/completions

Hi @krrishdholakia. Yes, that's precisely the feature I'm looking for. My use case involves:

I need a proxy server that:

krrishdholakia commented 1 month ago

Great - will work on it today.

Can we set up a 1:1 support channel? It would help to get feedback on the implementation.

DM'ed you on LinkedIn. Let me know if Slack/Discord works better!

srail commented 1 month ago

+1 to the ask above - I'm in the same situation. We have a pass-through server which handles all of our auth, and I would need it to support non-OpenAI formats. I suspect just being able to provide a different api_base for non-OpenAI models, like in the example above, would do the trick.

krrishdholakia commented 1 month ago

> I suspect just being able to provide a different api_base for non-OpenAI models,

@srail I don't think I understood your scenario - could you explain?

krrishdholakia commented 1 month ago

Hi @srail @codexceed this is now live - https://docs.litellm.ai/docs/providers/custom_llm_server

Here's how to try it with the SDK (the docs include a proxy example too):

import litellm
from litellm import CustomLLM, completion, get_llm_provider

class MyCustomLLM(CustomLLM):
    def completion(self, *args, **kwargs) -> litellm.ModelResponse:
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello world"}],
            mock_response="Hi!",
        )  # type: ignore

my_custom_llm = MyCustomLLM()

litellm.custom_provider_map = [  # 👈 KEY STEP - REGISTER HANDLER
    {"provider": "my-custom-llm", "custom_handler": my_custom_llm}
]

resp = completion(
    model="my-custom-llm/my-fake-model",
    messages=[{"role": "user", "content": "Hello world!"}],
)

assert resp.choices[0].message.content == "Hi!"

guleng commented 1 month ago

@krrishdholakia My image generation model is deployed locally and does not support the OpenAI API specification. I configured it according to the documentation, but it did not work.

Reference documents: https://docs.litellm.ai/docs/providers/custom_llm_server and https://docs.litellm.ai/docs/image_generation

custom_handler.py

import requests

import litellm
from litellm import CustomLLM, completion, get_llm_provider, image_generation

class MyCustomLLM(CustomLLM):
    def image_generation(self, *args, **kwargs) -> litellm.ModelResponse:
        url = "http://zbe09-7860.proxy.x-gpu.com/sdapi/v1/txt2img"
        data = {
            "prompt": kwargs.get("prompt", ""),
            "other_parameters": kwargs.get("other_parameters", {})
        }
        response = requests.post(url, json=data)
        response_data = response.json()
        return litellm.image_generation(
            choices=[litellm.Choice(text=response_data.get("result", ""))],
            usage=litellm.Usage(
                total_tokens=response_data.get("usage", {}).get("total_tokens", 0)
            )
        )

    async def aimage_generation(self, *args, **kwargs) -> litellm.ModelResponse:
        url = "http://zbe09-7860.proxy.x-gpu.com/sdapi/v1/txt2img"
        data = {
            "prompt": kwargs.get("prompt", ""),
            "other_parameters": kwargs.get("other_parameters", {})
        }
        response = requests.post(url, json=data)
        response_data = response.json()
        return litellm.image_generation(
            choices=[litellm.Choice(text=response_data.get("result", ""))],
            usage=litellm.Usage(
                total_tokens=response_data.get("usage", {}).get("total_tokens", 0)
            )
        )

my_custom_llm = MyCustomLLM()

config.yaml

model_list:
  - model_name: "my-custom-model"
    litellm_params:
      model: "my-custom-llm/my-model"
litellm_settings:
  custom_provider_map:
  - {"provider": "my-custom-llm", "custom_handler": custom_handler.my_custom_llm}
curl 'http://100.29.6.215:4000/v1/images/generations' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer sk-7TKke_a-J5RFKZ8oSJF6Sg' \
--data '{
  "model": "my-custom-model",
  "prompt": "A cute baby sea otter"
}'

Error log

litellm-1  | 07:06:50 - LiteLLM:WARNING: cost_calculator.py:738 - litellm.cost_calculator.py::response_cost_calculator - Returning None. Exception occurred - Model=1024-x-1024/my-model not found in completion cost model map
litellm-1  | Traceback (most recent call last):
litellm-1  |   File "/usr/local/lib/python3.11/site-packages/litellm/cost_calculator.py", line 712, in response_cost_calculator
litellm-1  |     response_cost = completion_cost(
litellm-1  |                     ^^^^^^^^^^^^^^^^
litellm-1  |   File "/usr/local/lib/python3.11/site-packages/litellm/cost_calculator.py", line 665, in completion_cost
litellm-1  |     raise e
litellm-1  |   File "/usr/local/lib/python3.11/site-packages/litellm/cost_calculator.py", line 595, in completion_cost
litellm-1  |     raise Exception(
litellm-1  | Exception: Model=1024-x-1024/my-model not found in completion cost model map
litellm-1  | 
litellm-1  | INFO:     124.205.197.30:61835 - "POST /v1/images/generations HTTP/1.1" 200 OK

What I want is image generation: I want to call my local image generation model through LiteLLM. Can you help me write an example of a custom handler for image generation? I am not a developer but an operations specialist, so I am not very good at this. Thank you.