langchain-ai / langchain-aws

Build LangChain Applications on AWS
MIT License

Issue using tools with Mistral Large model #265

Open supreetkt opened 1 week ago

supreetkt commented 1 week ago

While trying to use tools with the Mistral Large model, here is sample boto3 code that works fine:

import json
import boto3

session = boto3.session.Session()
bedrock_client = session.client("bedrock-runtime", region_name="us-west-2")

accept = "application/json"
contentType = "application/json"
body = json.dumps(
    {
        "messages": [{"role": "user", "content": "Tell me a joke about bears"}],
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "retrieve_animal_details",
                    "description": "Get forest animal details",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "animal_name": {
                                "type": "string",
                                "description": "Name of the animal",
                            }
                        },
                        "required": ["animal_name"],
                    },
                },
            }
        ],
        "tool_choice": "auto",
        "temperature": 0.5,
    }
)
response = bedrock_client.invoke_model(
    modelId="mistral.mistral-large-2402-v1:0", contentType=contentType, accept=accept, body=body
)

response_body = json.loads(response.get("body").read())
print("Model response using boto3:: ", response_body["choices"][0]["message"]["content"])

The request body using boto3 is:

{
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "retrieve_animal_details",
                "description": "Get forest animal details",
                "parameters": {
                    "type": "object",
                    "properties": {"animal_name": {"type": "string", "description": "Name of the animal"}},
                    "required": ["animal_name"],
                },
            },
        }
    ],
    "tool_choice": "auto",
    "messages": [{"role": "user", "content": "Tell me a joke about bears"}],
    "temperature": 0.5,
}

However, while using ChatBedrock:

import boto3
from langchain_aws import ChatBedrock
from langchain_core.prompts import ChatPromptTemplate

session = boto3.session.Session()
bedrock_client = session.client("bedrock-runtime", region_name="us-west-2")

prompt = ChatPromptTemplate.from_messages(["Tell me a joke about {animal}"])
model = ChatBedrock(
    client=bedrock_client,
    model_id="mistral.mistral-large-2402-v1:0",
    model_kwargs={
        "temperature": 0.5,
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "retrieve_animal_details",
                    "description": "Get forest animal details",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "animal_name": {
                                "type": "string",
                                "description": "Name of the animal",
                            }
                        },
                        "required": ["animal_name"],
                    },
                },
            }
        ],
        "tool_choice": "auto",
    },
)

chain = prompt | model

response = chain.invoke({"animal": "bears"})
print("Model response using ChatBedrock:: ", response)

The request body is (not including accept, contentType and modelId):

{
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "retrieve_animal_details",
                "description": "Get forest animal details",
                "parameters": {
                    "type": "object",
                    "properties": {"animal_name": {"type": "string", "description": "Name of the animal"}},
                    "required": ["animal_name"],
                },
            },
        }
    ],
    "tool_choice": "auto",
    "prompt": "[INST] Tell me a joke about bears [/INST]",
    "temperature": 0.5,
}

And I get the following ValidationError while using ChatBedrock:

response = chain.invoke({"animal": "bears"})
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "python3.12/site-packages/langchain_core/runnables/base.py", line 3024, in invoke
    input = context.run(step.invoke, input, config)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "python3.12/site-packages/langchain_core/language_models/chat_models.py", line 286, in invoke
    self.generate_prompt(
  File "python3.12/site-packages/langchain_core/language_models/chat_models.py", line 786, in generate_prompt
    return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "python3.12/site-packages/langchain_core/language_models/chat_models.py", line 643, in generate
    raise e
  File "python3.12/site-packages/langchain_core/language_models/chat_models.py", line 633, in generate
    self._generate_with_cache(
  File "python3.12/site-packages/langchain_core/language_models/chat_models.py", line 851, in _generate_with_cache
    result = self._generate(
             ^^^^^^^^^^^^^^^
  File "python3.12/site-packages/langchain_aws/chat_models/bedrock.py", line 561, in _generate
    completion, tool_calls, llm_output = self._prepare_input_and_invoke(
                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "python3.12/site-packages/langchain_aws/llms/bedrock.py", line 842, in _prepare_input_and_invoke
    raise e
  File "python3.12/site-packages/langchain_aws/llms/bedrock.py", line 828, in _prepare_input_and_invoke
    response = self.client.invoke_model(**request_options)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "python3.12/site-packages/botocore/client.py", line 569, in _api_call
    return self._make_api_call(operation_name, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "python3.12/site-packages/botocore/client.py", line 1023, in _make_api_call
    raise error_class(parsed_response, operation_name)
botocore.errorfactory.ValidationException: An error occurred (ValidationException) when calling the InvokeModel operation: Validation Error                                                                 

Unfortunately, the error from Bedrock is not very descriptive, so it's also difficult to debug. But the request bodies give a fair indication of the difference between a boto3 and a ChatBedrock request.
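To make the difference between the two logged request bodies concrete, here is a minimal sketch that compares their top-level keys (the keys are copied from the bodies above; tool definitions are trimmed for brevity):

```python
# Top-level keys of the working boto3 request body (tools trimmed).
boto3_body = {
    "tools": ["..."],
    "tool_choice": "auto",
    "messages": [{"role": "user", "content": "Tell me a joke about bears"}],
    "temperature": 0.5,
}

# Top-level keys of the failing ChatBedrock request body (tools trimmed).
chatbedrock_body = {
    "tools": ["..."],
    "tool_choice": "auto",
    "prompt": "[INST] Tell me a joke about bears [/INST]",
    "temperature": 0.5,
}

# Symmetric difference of the key sets isolates the mismatch:
diff = boto3_body.keys() ^ chatbedrock_body.keys()
print(diff)  # → {'messages', 'prompt'}
```

So the only structural difference is that ChatBedrock sends a `prompt` string where the working request sends a `messages` array.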

HyphenHook commented 5 days ago

Hello! It appears that the AWS Mistral Large models have started using Messages, unlike the older Mistral models, which use Prompt strings. The current implementation of ChatBedrock only passes Prompt strings to Mistral provider models, causing a ValidationError because the Mistral Large models expect Messages instead. If possible, I'd like to work on this issue and create a PR to add compatibility for ChatBedrock to use the Mistral Large chat models (along with the ability to tool call with that model).

I have a plan to address this by following the structure of the request and converting Mistral responses accordingly.

If there are additional details I should consider, please let me know. Thanks!

supreetkt commented 4 days ago

Hi @HyphenHook - I think what you're saying might also apply to this other issue I opened regarding AI21 models. Please check.

HyphenHook commented 4 days ago

Oh, I see. If BedrockLLM is going to be deprecated in favor of ChatBedrockConverse, then I suppose ChatBedrock would probably be deprecated as well, since ChatBedrock currently relies on the BedrockLLM code. If ChatBedrock is eventually going to be deprecated, would there be any need to introduce support for the "new" models (like Mistral Large)?

3coins commented 2 days ago

@HyphenHook I would recommend using ChatBedrockConverse for now, until we add support in ChatBedrock. Both AI21 and Mistral Large seem to be supported by the Converse API. https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-supported-models-features.html