BerriAI / litellm

Python SDK, Proxy Server (LLM Gateway) to call 100+ LLM APIs in OpenAI format - [Bedrock, Azure, OpenAI, VertexAI, Cohere, Anthropic, Sagemaker, HuggingFace, Replicate, Groq]
https://docs.litellm.ai/docs/

[Feature]: convenience `Enum` for `tool_choice` #6091

Open · jamesbraza opened this issue 1 month ago

The Feature

From https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice, tool_choice can be "none", "auto", "required", or an object naming a specific function (e.g. {"type": "function", "function": {"name": "my_function"}}).

From https://platform.openai.com/docs/guides/function-calling/configuring-function-calling-behavior-using-the-tool_choice-parameter, we see the response's finish_reason is a function of tool_choice.

It would be nice if LiteLLM provided an Enum that handled this logic.
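Something like this rough sketch (the ToolChoice name and the exact finish-reason mapping are just illustrative here, not an existing LiteLLM API):

from enum import StrEnum  # Python 3.11+


class ToolChoice(StrEnum):  # hypothetical name, not currently exported by LiteLLM
    NONE = "none"
    AUTO = "auto"
    REQUIRED = "required"

    @property
    def expected_finish_reasons(self) -> set[str]:
        # Forced tool calls are documented to finish with "stop", but in
        # practice "tool_calls" shows up too, so accept both
        if self is ToolChoice.REQUIRED:
            return {"tool_calls", "stop"}
        if self is ToolChoice.NONE:
            return {"stop"}  # no tool may be called
        return {"tool_calls", "stop"}  # AUTO: the model decides either way

Then client code could just do ToolChoice(tool_choice).expected_finish_reasons instead of hand-rolling the mapping.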

Alternatively, perhaps LiteLLM could add an opt-in flag to acompletion that validates that the finish_reason matches the input tool_choice and tools.
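For example (the validate_finish_reason keyword here is purely hypothetical):

model_response = await acompletion(
    "gpt-4o",
    messages=messages,
    tools=tools,
    tool_choice="required",
    validate_finish_reason=True,  # hypothetical opt-in flag, not a real acompletion kwarg
)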

Motivation, pitch

Enabling clients to not have to work out the expected finish_reason themselves, while still having validation confirming it's correct.

Twitter / LinkedIn details

No response

krrishdholakia commented 1 month ago

Hey @jamesbraza can you share a code example of what you expect?

jamesbraza commented 1 month ago

Sure, I recently wrote something like this:

from typing import Any

from litellm import acompletion

TOOL_CHOICE_REQUIRED = "required"

# Tool and MalformedMessageError below are types from our own codebase, not LiteLLM
tool_choice: Tool | str | None = TOOL_CHOICE_REQUIRED

completion_kwargs: dict[str, Any] = {}
# SEE: https://platform.openai.com/docs/guides/function-calling/configuring-function-calling-behavior-using-the-tool_choice-parameter
expected_finish_reason: set[str] = {"tool_calls"}
if isinstance(tool_choice, Tool):
    completion_kwargs["tool_choice"] = {
        "type": "function",
        "function": {"name": tool_choice.info.name},
    }
    expected_finish_reason = {"stop"}  # TODO: should this be .add("stop") too?
elif tool_choice is not None:
    completion_kwargs["tool_choice"] = tool_choice
    if tool_choice == TOOL_CHOICE_REQUIRED:
        # Even though docs say it should be just 'stop',
        # in practice 'tool_calls' shows up too
        expected_finish_reason.add("stop")

model_response = await acompletion(
    "gpt-4o",
    messages=...,
    tools=...,
    **completion_kwargs,
)

if (num_choices := len(model_response.choices)) != 1:
    raise MalformedMessageError(
        f"Expected one choice in LiteLLM model response, got {num_choices}"
        f" choices, full response was {model_response}."
    )
choice = model_response.choices[0]
if choice.finish_reason not in expected_finish_reason:
    raise MalformedMessageError(
        f"Expected a finish reason in {expected_finish_reason} in LiteLLM"
        f" model response, got finish reason {choice.finish_reason!r}, full"
        f" response was {model_response} and tool choice was {tool_choice}."
    )

# Process choice ...

Note what the client code has to do here:

  1. Pass tool_choice as a plain str. Ideally this could be a StrEnum provided by LiteLLM.
  2. Work out the expected_finish_reason based on the value of tool_choice.
  3. Validate the number of choices and the finish_reason in the response.

I would like to upstream at least items 1 and 2 into LiteLLM, mainly because LiteLLM already handles almost all of our LLM logic apart from this.
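With items 1 and 2 upstreamed (reusing the hypothetical ToolChoice enum sketched above), the snippet would collapse to roughly:

tool_choice = ToolChoice.REQUIRED
expected_finish_reason = tool_choice.expected_finish_reasons

model_response = await acompletion(
    "gpt-4o", messages=..., tools=..., tool_choice=tool_choice
)
choice = model_response.choices[0]
if choice.finish_reason not in expected_finish_reason:
    raise MalformedMessageError(...)  # item 3 would stay client-side for now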