huggingface / text-generation-inference

Large Language Model Text Generation Inference
http://hf.co/docs/text-generation-inference
Apache License 2.0

OpenAI-compatible API has a discrepancy with original OpenAI API when using tool calls #2136

Closed DaMagus26 closed 2 weeks ago

DaMagus26 commented 2 months ago

System Info

Recently I've been trying to use Qwen2 72B through the TGI 2.0.4 Docker image from a Jupyter Notebook. I used the LangChain langchain_openai.chat_models.ChatOpenAI client with custom tools. The completion request succeeded, but when LangChain parsed the response I got an error. After a little digging I found that TGI and OpenAI send responses in different formats when tool calls are involved.

When I sent a POST request to OpenAI with tools, I got json in response:

...
"tool_calls": [
          {
            "id": "0",
            "type": "function",
            "function": {
              "description": null,
              "name": "retrieve_payment_status",
              "arguments": "{\"transaction_id\": \"T1001\"}"
            }
...

and this is what I got from my TGI service:

...
"tool_calls": [
          {
            "id": "0",
            "type": "function",
            "function": {
              "description": null,
              "name": "retrieve_payment_status",
              "arguments": {
                "transaction_id": "T1001"
              }
            }
...

Apparently OpenAI formats the tool_calls["function"]["arguments"] field as a JSON string, while TGI parses it into a JSON object. I understand this is probably OpenAI's own mistake, but the discrepancy makes the LangChain client (which is also OpenAI-compatible) raise an error when parsing the TGI response.
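
For anyone hitting this in the meantime, a minimal sketch of a client-side normalization (the helper name and the raw-dict shape are hypothetical, for illustration only):

import json

def normalize_tool_call_arguments(response: dict) -> dict:
    # TGI returns tool-call arguments as a parsed JSON object; OpenAI-compatible
    # clients such as LangChain expect a JSON string, so serialize it back.
    for choice in response.get('choices', []):
        for tool_call in choice.get('message', {}).get('tool_calls') or []:
            args = tool_call['function']['arguments']
            if not isinstance(args, str):
                tool_call['function']['arguments'] = json.dumps(args)
    return response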

Reproduction

Source code that raised an error:

from typing import Type

from langchain.tools import BaseTool
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI

class RetrievePaymentStatusSchema(BaseModel):
    transaction_id: str = Field(description='The transaction id.')

class RetrievePaymentStatus(BaseTool):
    name: str = 'retrieve_payment_status'
    description: str = 'Get payment status of a transaction'
    args_schema: Type[BaseModel] = RetrievePaymentStatusSchema

    def _run(self, transaction_id: str) -> str:
        return 'PENDING'

retrieve_payment_status_tool = RetrievePaymentStatus()

# Point the OpenAI-compatible client at the TGI endpoint.
llm = ChatOpenAI(
    model='MIXTRAL',
    base_url='http://10.244.3.28:8080/v1',
    api_key='aboba',
    max_tokens=500,
)

llm_with_tools = llm.bind_tools([retrieve_payment_status_tool], tool_choice="auto")
llm_with_tools.invoke(input="What's the status of my transaction T1001?")

Error:

ValidationError: 1 validation error for AIMessage
invalid_tool_calls -> 0 -> args
  str type expected (type=type_error.str)

Expected behavior

I would expect TGI's tool_calls to match the OpenAI format:

...
"tool_calls": [
          {
            "id": "0",
            "type": "function",
            "function": {
              "description": null,
              "name": "retrieve_payment_status",
              "arguments": "{\"transaction_id\": \"T1001\"}"
            }
...

UPD: I suspect the same goes for streaming. Where OpenAI gradually (token by token) fills response["tool_calls"]["function"]["arguments"] with the actual arguments of the tool being called, TGI puts its entire model response into this field. Oddly enough, the original problem stated above does not appear when streaming.

To illustrate, here is what I get if I concatenate response["tool_calls"]["function"]["arguments"] across all collected response chunks for the code example above:

{
  "function": {
    "_name": "retrieve_payment_status",
    "transaction_id": "T1001"
  }
}<|im_end|>

And this is what I get when using the actual OpenAI API:

{"transaction_id":"T1001"}

ishelaputov commented 1 month ago

Hello! I can confirm the same problem with Llama 3 8B. I temporarily solved it with a hack that converts the parsed JSON arguments back into a JSON string, but I would like to adhere to the OpenAI protocol as-is.

My dirty hack, an override of the _generate method of the OpenAI chat model:

from typing import Any, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import generate_from_stream
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult

def _generate(
    self,
    messages: List[BaseMessage],
    stop: Optional[List[str]] = None,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> ChatResult:
    if self.streaming:
        stream_iter = self._stream(
            messages, stop=stop, run_manager=run_manager, **kwargs
        )
        return generate_from_stream(stream_iter)
    message_dicts, params = self._create_message_dicts(messages, stop)
    params = {**params, **kwargs}

    # Drop tool_choice once a tool message has been sent, otherwise the
    # model keeps invoking the tool in a loop (see note below).
    last_message = message_dicts[-1]
    if last_message["role"] == 'tool':
        if 'tool_choice' in params:
            removed_value = params.pop('tool_choice', None)
            print(f"Removed 'tool_choice' parameter with value '{removed_value}'")

    response = normalize_chat_completion(
        self.client.create(messages=message_dicts, **params)
    )
    return self._create_chat_result(response)

My transformation (please don't beat me up for bad code :) ):

import json
from typing import Any

from openai.types.chat import ChatCompletion

def normalize_chat_completion(response: ChatCompletion) -> Any:
    """
    Normalize a HuggingFace Text Generation Inference response to the OpenAI response format:
        - tool_call.function.arguments is converted to a JSON-formatted string,
        - choice.message.content is set to " " if it came back as 'null'.

    Args:
        response (ChatCompletion): The response from HuggingFace Text Generation Inference.

    Return (ChatCompletion):
        The normalized response from HuggingFace Text Generation Inference.
    """
    choices = []
    for choice in response.choices:
        if choice.message.tool_calls:
            tool_calls = []
            for tool_call in choice.message.tool_calls:
                # TGI returns a parsed object here; OpenAI clients expect a JSON string.
                tool_call.function.arguments = json.dumps(tool_call.function.arguments)
                tool_calls.append(tool_call)
            choice.message.tool_calls = tool_calls
        if not choice.message.content:
            choice.message.content = " "
        choices.append(choice)
    return ChatCompletion(
        id=response.id,
        choices=choices,
        created=response.created,
        model=response.model,
        object="chat.completion",
        system_fingerprint=response.system_fingerprint,
        usage=response.usage,
    )

This part of the code eliminates the loop of repeated tool calls:

last_message = message_dicts[-1]
if last_message["role"] == 'tool':
    if 'tool_choice' in params:
        removed_value = params.pop('tool_choice', None)
        print(f"Removed 'tool_choice' parameter with value '{removed_value}'")

DaMagus26 commented 1 month ago

I temporarily solved it by creating a proxy service that captures the response from the TGI model and reformats it to match the OpenAI API. But your solution also looks neat :)
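
A minimal sketch of such a proxy, assuming FastAPI and httpx (the names, route, and timeout are illustrative; the core fix is the same json.dumps normalization as above, non-streaming only):

import json

import httpx
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

app = FastAPI()
TGI_URL = 'http://10.244.3.28:8080'  # the TGI endpoint from the report

@app.post('/v1/chat/completions')
async def chat_completions(request: Request):
    payload = await request.json()
    # Forward the request unchanged to TGI.
    async with httpx.AsyncClient() as client:
        upstream = await client.post(
            f'{TGI_URL}/v1/chat/completions', json=payload, timeout=60.0
        )
    body = upstream.json()
    # Re-serialize TGI's parsed argument objects into OpenAI-style JSON strings.
    for choice in body.get('choices', []):
        for tool_call in choice.get('message', {}).get('tool_calls') or []:
            args = tool_call['function']['arguments']
            if not isinstance(args, str):
                tool_call['function']['arguments'] = json.dumps(args)
    return JSONResponse(content=body, status_code=upstream.status_code)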

github-actions[bot] commented 3 weeks ago

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.