huggingface / huggingface_hub

The official Python client for the Huggingface Hub.
https://huggingface.co/docs/huggingface_hub

Passing tool results to the LM #2606

Open anakin87 opened 5 days ago

anakin87 commented 5 days ago

Describe the bug

Let's start by thanking you for this great resource :blue_heart:

The InferenceClient supports tool calling, as explained here.

In many use cases, it is useful to pass the tool call back to the Language Model, together with the tool result in a message with the tool role. This way, the LM can, for example, respond in a human-readable way. This is supported in HF Transformers.

When using the InferenceClient (with the Serverless Inference API or TGI), I'm struggling to find a way to reproduce this behavior. (I mostly experimented with Mistral and Llama models that support tool/function calling, with similar results.)

@Wauplin @hanouticelina Is this supported or planned? Is there any workaround you would suggest? So far, I've only tried wrapping the tool result in a message with the user role, and this somehow works (rough sketch below)...
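For reference, the user-message workaround looks roughly like this (a minimal sketch reusing the client and messages from the reproduction below; the wording of the wrapper message is just an illustration, not an official pattern):

# Sketch of the workaround: wrap the tool result in a plain `user` message
# instead of a `tool` message. Reuses `client` and `messages` from the
# reproduction below; `tool_result` is whatever the tool returned.
tool_result = "22.0"

workaround_messages = messages + [
    {
        "role": "user",
        "content": f"The get_current_weather tool returned: {tool_result}. "
        "Please answer my previous question using this result.",
    },
]

output = client.chat_completion(messages=workaround_messages, max_tokens=500, temperature=0.3)
print(output.choices[0].message.content)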

Probably related issue (in TGI): https://github.com/huggingface/text-generation-inference/issues/2461

Reproduction

from huggingface_hub import InferenceClient

client = InferenceClient("mistralai/Mistral-Nemo-Instruct-2407")

messages = [
    {
        "role": "system",
        "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
    },
    {
        "role": "user",
        "content": "What's the weather like in San Giustino (Italy) in Celsius?",
    },
]
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location.",
                    },
                },
                "required": ["location", "format"],
            },
        },
    }]

client.chat_completion(messages=messages, tools=tools, max_tokens=500, temperature=0.3)

# this works and produces output similar to the following:
# ChatCompletionOutput(choices=[ChatCompletionOutputComplete(finish_reason='stop', index=0, message=ChatCompletionOutputMessage(role='assistant', content=None, tool_calls=[ChatCompletionOutputToolCall(function=ChatCompletionOutputFunctionDefinition(arguments={'format': 'celsius', 'location': 'San Giustino, Italy'}, name='get_current_weather', description=None), id='0', type='function')]), logprobs=None)], ...)

# TRYING TO PASS BACK TOOL CALLS AND TOOL RESULT
new_messages = [el for el in messages]
id_ = "9Ae3bDc2F"  # fake ID needed to use Mistral models

tool_call = {"name": "get_current_weather", "arguments": {"location": "San Giustino, Italy", "format": "celsius"}}
new_messages.append({"role": "assistant", "content":"", "tool_calls": [{"type": "function", "function": tool_call, "id": id_}]})
new_messages.append({"role": "tool", "name": "get_current_temperature", "content": "22.0", "tool_call_id": id_})

client.chat_completion(messages=new_messages, tools=tools, max_tokens=500, temperature=0.3)

# HfHubHTTPError: 422 Client Error: Unprocessable Entity for url: https://api-inference.huggingface.co/models/mistralai/Mistral-Nemo-Instruct-2407/v1/chat/completions (Request ID: ...)
# Template error: unknown filter: filter string is unknown (in <string>:79)

System info

- huggingface_hub version: 0.25.2
hanouticelina commented 4 days ago

Hi @anakin87, thanks a lot for reporting this issue! I managed to reproduce the bug with Mistral models. However, I tried with meta-llama/Llama-3.1-8B-Instruct and HuggingFaceH4/zephyr-7b-beta (both served with TGI) and it works fine; here is the script:

from huggingface_hub import InferenceClient

client = InferenceClient("meta-llama/Llama-3.1-8B-Instruct") # or "HuggingFaceH4/zephyr-7b-beta"

messages = [
    {
        "role": "system",
        "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.",
    },
    {
        "role": "user",
        "content": "What's the weather like in San Giustino (Italy) in Celsius?",
    },
]
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location.",
                    },
                },
                "required": ["location", "format"],
            },
        },
    }
]

output = client.chat_completion(messages=messages, tools=tools, max_tokens=500, temperature=0.3)
print(output)
tool_call = {"name": "get_current_weather", "arguments": {"location": "San Giustino, Italy", "format": "celsius"}}
messages.append({"role": "assistant", "tool_calls": [{"type": "function", "function": tool_call}], "content": "22.0"})
messages.append({"role": "tool", "name": "get_current_weather", "content": "22.0"})

output = client.chat_completion(messages=messages, tools=tools, max_tokens=500, temperature=0.3)
print(output)

I suspect an issue with the Mistral chat templates. I think you can open an issue on mistralai/Mistral-Nemo-Instruct-2407; I will also report this internally and get back to you if there is a better workaround.
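If you want to check locally whether the chat template itself is the culprit, you can render it with transformers (a rough sketch, assuming a recent transformers with tools support in apply_chat_template and access to the model repo; note that TGI renders templates with minijinja rather than Python jinja2, so a local render may not reproduce the exact error):

from transformers import AutoTokenizer

# Rough sketch: render the chat template locally to check whether it handles
# an assistant tool call followed by a `tool` message. `new_messages` and
# `tools` are the objects from the original reproduction above.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-Nemo-Instruct-2407")
prompt = tokenizer.apply_chat_template(
    new_messages,
    tools=tools,
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)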