langchain-ai / langchain-aws

Build LangChain Applications on AWS
MIT License

Adds support for function calling with Anthropic models on Bedrock #37

Closed: bigbernnn closed this pull request 2 months ago

bigbernnn commented 2 months ago

This is a workaround to support function calling with Anthropic models on Bedrock. The change adds a bind_tools method to ChatBedrock. Example usage:

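from langchain_aws import ChatBedrock
from langchain_core.messages import HumanMessage
from pydantic import BaseModel, Field

# Any Anthropic Claude 3 model available in your Bedrock region
model_id = "anthropic.claude-3-sonnet-20240229-v1:0"

chat = ChatBedrock(
    model_id=model_id,
    model_kwargs={"temperature": 0.1},
)

class GetWeather(BaseModel):
    """Get the current weather in a given location"""

    location: str = Field(..., description="The city and state, e.g. San Francisco, CA")

# Attach the tool schema so the model can request a GetWeather call
llm_with_tools = chat.bind_tools([GetWeather])

messages = [
    HumanMessage(
        content="what is the weather like in San Francisco"
    )
]
ai_msg = llm_with_tools.invoke(messages)
print(ai_msg)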

The workaround is implemented similarly to the equivalent approach used directly with Anthropic before the tool-use feature (currently in beta) became available.
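For context on the pre-beta approach: the model is told about the tools in its system prompt and asked to reply with a structured block that the client then parses. The traceback later in this thread shows ChatBedrock consulting a `system_prompt_with_tools` attribute, which is consistent with that. The sketch below covers only the prompt-rendering half. The helper `render_tools_prompt`, the plain-dict tool spec, the preamble sentence, and the XML tag names (modeled loosely on Anthropic's legacy function-calling format) are all assumptions, not the PR's actual code.

```python
from xml.sax.saxutils import escape


def render_tools_prompt(tools: list[dict]) -> str:
    """Render tool specs into an XML block for a Claude system prompt.

    Hypothetical helper. Each spec is a dict with "name", "description",
    and "parameters" (a JSON-Schema-style {"properties": {...}} mapping).
    """
    parts = ["<tools>"]
    for tool in tools:
        parts.append("<tool_description>")
        parts.append(f"<tool_name>{escape(tool['name'])}</tool_name>")
        parts.append(f"<description>{escape(tool['description'])}</description>")
        parts.append("<parameters>")
        for name, schema in tool["parameters"].get("properties", {}).items():
            parts.append("<parameter>")
            parts.append(f"<name>{escape(name)}</name>")
            parts.append(f"<type>{escape(schema.get('type', 'string'))}</type>")
            parts.append(f"<description>{escape(schema.get('description', ''))}</description>")
            parts.append("</parameter>")
        parts.append("</parameters>")
        parts.append("</tool_description>")
    parts.append("</tools>")
    return "\n".join(parts)


# Hand-written spec equivalent to the GetWeather model above (assumed shape).
get_weather = {
    "name": "GetWeather",
    "description": "Get the current weather in a given location",
    "parameters": {
        "properties": {
            "location": {
                "type": "string",
                "description": "The city and state, e.g. San Francisco, CA",
            }
        }
    },
}

# Hypothetical preamble; the real workaround's wording may differ.
system_prompt = "You have access to the following tools:\n" + render_tools_prompt([get_weather])
print(system_prompt)
```

Escaping each value keeps characters such as `&` or `<` in descriptions from breaking the XML the model is asked to mirror.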

bigbernnn commented 2 months ago

@bigbernnn Thanks for working on this change. Could you add integration tests or examples to cover it?

Integration tests added for both generate and stream responses.

zhongyu09 commented 1 month ago

Hi @bigbernnn, @3coins, on my side tool calling doesn't seem to work with LangGraph, which relies on the tool_calls field of AIMessage. At line 399 of bedrock.py, the AIMessage is created without passing any tool_calls in additional_kwargs.
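LangGraph relies on `AIMessage.tool_calls`, a list of dicts with `name`, `args`, and `id` keys. A minimal sketch of the missing mapping, assuming the model's reply arrives as Claude Messages API content blocks in which a `tool_use` block carries `id`, `name`, and `input`. The helper `to_tool_calls` is hypothetical; its result is the kind of list one would pass as `tool_calls=` when constructing the `AIMessage`, in langchain-core versions that support that field.

```python
from typing import Any


def to_tool_calls(content: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Map Claude `tool_use` content blocks to LangChain-style tool calls.

    Hypothetical helper: each result has the "name", "args", and "id"
    keys that LangChain's tool-call dicts use. Text blocks are skipped.
    """
    return [
        {"name": block["name"], "args": block.get("input", {}), "id": block["id"]}
        for block in content
        if block.get("type") == "tool_use"
    ]


# Example content in the Claude Messages API shape (values made up).
content = [
    {"type": "text", "text": "Let me check the weather."},
    {
        "type": "tool_use",
        "id": "toolu_example_01",
        "name": "GetWeather",
        "input": {"location": "San Francisco, CA"},
    },
]
tool_calls = to_tool_calls(content)
print(tool_calls)
```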

thiagotps commented 1 month ago

Hi @3coins @bigbernnn. The current implementation also seems to ignore the stop_reason and stop_sequence fields returned by the Bedrock API when using the Claude 3 models.
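Claude 3 responses on Bedrock carry `stop_reason` (for example `end_turn`, `max_tokens`, or `stop_sequence`) and `stop_sequence` at the top level of the JSON body. A sketch of keeping them instead of dropping them, assuming the response body has already been read into a string; `parse_claude3_body` is a hypothetical helper and the sample body's values are made up.

```python
import json
from typing import Any


def parse_claude3_body(raw: str) -> tuple[str, dict[str, Any]]:
    """Split a Claude 3 Messages API response body into text and metadata.

    Hypothetical helper: keeps stop_reason and stop_sequence (plus usage)
    so they could be surfaced as response metadata on the returned message.
    """
    body = json.loads(raw)
    # Concatenate only the text blocks; other block types are ignored here.
    text = "".join(
        block.get("text", "")
        for block in body.get("content", [])
        if block.get("type") == "text"
    )
    metadata = {
        "stop_reason": body.get("stop_reason"),
        "stop_sequence": body.get("stop_sequence"),
        "usage": body.get("usage", {}),
    }
    return text, metadata


raw_body = json.dumps({
    "type": "message",
    "role": "assistant",
    "content": [{"type": "text", "text": "It is sunny."}],
    "stop_reason": "stop_sequence",
    "stop_sequence": "\n\nHuman:",
    "usage": {"input_tokens": 12, "output_tokens": 5},
})
text, metadata = parse_claude3_body(raw_body)
print(text, metadata["stop_reason"], repr(metadata["stop_sequence"]))
```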

mhussar commented 1 month ago

Has any progress been made on this? There doesn't seem to be a standard result that can be parsed to get the required response; it seems we would have to rely on a hack using regular expressions.
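Until structured tool calls are exposed, a sturdier alternative to regular expressions is to parse the XML block that a prompt-based workaround asks the model to emit. The sketch below assumes the legacy `<function_calls>`/`<invoke>`/`<tool_name>`/`<parameters>` layout and a single well-formed block; the helper name is hypothetical, and all argument values come back as strings.

```python
import xml.etree.ElementTree as ET


def parse_function_calls(completion: str) -> list[dict]:
    """Extract tool invocations from a legacy <function_calls> XML block.

    Hypothetical helper: assumes the model emitted one well-formed
    <function_calls> element containing <invoke> children.
    """
    start = completion.find("<function_calls>")
    end = completion.find("</function_calls>")
    if start == -1 or end == -1:
        return []  # no tool call in this completion
    block = completion[start : end + len("</function_calls>")]
    root = ET.fromstring(block)
    calls = []
    for invoke in root.findall("invoke"):
        name = invoke.findtext("tool_name", default="").strip()
        params = invoke.find("parameters")
        args = {}
        if params is not None:
            # Each child tag is a parameter name; its text is the value.
            args = {child.tag: (child.text or "").strip() for child in params}
        calls.append({"name": name, "args": args})
    return calls


completion = """I'll look that up.
<function_calls>
<invoke>
<tool_name>GetWeather</tool_name>
<parameters>
<location>San Francisco, CA</location>
</parameters>
</invoke>
</function_calls>"""
print(parse_function_calls(completion))
```

An XML parser rejects malformed output outright instead of silently matching part of it, which makes failures easier to detect than with a regex.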

ROZBEH commented 2 weeks ago

Hi, thanks for working on this. Has any progress been made? The examples on this page (tool calling) don't work with the model below:

from langchain_aws import ChatBedrock

llm = ChatBedrock(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    model_kwargs={"temperature": 0.0},
)

The error is something like:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[2], line 41
     33 few_shot_prompt = ChatPromptTemplate.from_messages(
     34     [
     35         ("system", system),
   (...)
     38     ]
     39 )
     40 chain = {"query": RunnablePassthrough()} | few_shot_prompt | llm_with_tools
---> 41 ai_msg = chain.invoke("Whats 119 times 8 minus 20")
     42 msgs = []
     43 msgs.append(HumanMessage("Whats 119 times 8 minus 20"))

File ~/.pyenv/versions/3.11.7/envs/3.11/lib/python3.11/site-packages/langchain_core/runnables/base.py:2504, in RunnableSequence.invoke(self, input, config, **kwargs)
   2502             input = step.invoke(input, config, **kwargs)
   2503         else:
-> 2504             input = step.invoke(input, config)
   2505 # finish the root run
   2506 except BaseException as e:

File ~/.pyenv/versions/3.11.7/envs/3.11/lib/python3.11/site-packages/langchain_core/language_models/chat_models.py:170, in BaseChatModel.invoke(self, input, config, stop, **kwargs)
    159 def invoke(
    160     self,
    161     input: LanguageModelInput,
   (...)
    165     **kwargs: Any,
    166 ) -> BaseMessage:
    167     config = ensure_config(config)
    168     return cast(
    169         ChatGeneration,
--> 170         self.generate_prompt(
    171             [self._convert_input(input)],
    172             stop=stop,
    173             callbacks=config.get("callbacks"),
    174             tags=config.get("tags"),
    175             metadata=config.get("metadata"),
    176             run_name=config.get("run_name"),
    177             run_id=config.pop("run_id", None),
    178             **kwargs,
    179         ).generations[0][0],
    180     ).message

File ~/.pyenv/versions/3.11.7/envs/3.11/lib/python3.11/site-packages/langchain_core/language_models/chat_models.py:599, in BaseChatModel.generate_prompt(self, prompts, stop, callbacks, **kwargs)
    591 def generate_prompt(
    592     self,
    593     prompts: List[PromptValue],
   (...)
    596     **kwargs: Any,
    597 ) -> LLMResult:
    598     prompt_messages = [p.to_messages() for p in prompts]
--> 599     return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)

File ~/.pyenv/versions/3.11.7/envs/3.11/lib/python3.11/site-packages/langchain_core/language_models/chat_models.py:456, in BaseChatModel.generate(self, messages, stop, callbacks, tags, metadata, run_name, run_id, **kwargs)
    454         if run_managers:
    455             run_managers[i].on_llm_error(e, response=LLMResult(generations=[]))
--> 456         raise e
    457 flattened_outputs = [
    458     LLMResult(generations=[res.generations], llm_output=res.llm_output)  # type: ignore[list-item]
    459     for res in results
    460 ]
    461 llm_output = self._combine_llm_outputs([res.llm_output for res in results])

File ~/.pyenv/versions/3.11.7/envs/3.11/lib/python3.11/site-packages/langchain_core/language_models/chat_models.py:446, in BaseChatModel.generate(self, messages, stop, callbacks, tags, metadata, run_name, run_id, **kwargs)
    443 for i, m in enumerate(messages):
    444     try:
    445         results.append(
--> 446             self._generate_with_cache(
    447                 m,
    448                 stop=stop,
    449                 run_manager=run_managers[i] if run_managers else None,
    450                 **kwargs,
    451             )
    452         )
    453     except BaseException as e:
    454         if run_managers:

File ~/.pyenv/versions/3.11.7/envs/3.11/lib/python3.11/site-packages/langchain_core/language_models/chat_models.py:671, in BaseChatModel._generate_with_cache(self, messages, stop, run_manager, **kwargs)
    669 else:
    670     if inspect.signature(self._generate).parameters.get("run_manager"):
--> 671         result = self._generate(
    672             messages, stop=stop, run_manager=run_manager, **kwargs
    673         )
    674     else:
    675         result = self._generate(messages, stop=stop, **kwargs)

File ~/.pyenv/versions/3.11.7/envs/3.11/lib/python3.11/site-packages/langchain_aws/chat_models/bedrock.py:423, in ChatBedrock._generate(self, messages, stop, run_manager, **kwargs)
    420 params: Dict[str, Any] = {**kwargs}
    422 if provider == "anthropic":
--> 423     system, formatted_messages = ChatPromptAdapter.format_messages(
    424         provider, messages
    425     )
    426     if self.system_prompt_with_tools:
    427         if system:

File ~/.pyenv/versions/3.11.7/envs/3.11/lib/python3.11/site-packages/langchain_aws/chat_models/bedrock.py:312, in ChatPromptAdapter.format_messages(cls, provider, messages)
    307 @classmethod
    308 def format_messages(
    309     cls, provider: str, messages: List[BaseMessage]
    310 ) -> Tuple[Optional[str], List[Dict]]:
    311     if provider == "anthropic":
--> 312         return _format_anthropic_messages(messages)
    314     raise NotImplementedError(
    315         f"Provider {provider} not supported for format_messages"
    316     )

File ~/.pyenv/versions/3.11.7/envs/3.11/lib/python3.11/site-packages/langchain_aws/chat_models/bedrock.py:228, in _format_anthropic_messages(messages)
    225     system = message.content
    226     continue
--> 228 role = _message_type_lookups[message.type]
    229 content: Union[str, List[Dict]]
    231 if not isinstance(message.content, str):
    232     # parse as dict

KeyError: 'tool'
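The last frame points at the cause: `_format_anthropic_messages` maps `message.type` to a role via `_message_type_lookups`, and a `ToolMessage` has type `"tool"`, which that mapping evidently lacks. A simplified sketch of handling it, turning tool results into user-role messages with a `tool_result` content block as the Anthropic Messages API expects. `Msg`, `ROLE_LOOKUP`, and `format_for_claude` are hypothetical stand-ins, not langchain-aws code.

```python
from dataclasses import dataclass
from typing import Optional


# Hypothetical stand-in for LangChain messages: only the fields used here.
@dataclass
class Msg:
    type: str  # "human", "ai", or "tool" (the type a ToolMessage reports)
    content: str
    tool_call_id: Optional[str] = None


# Hypothetical role mapping; note it has no "tool" entry.
ROLE_LOOKUP = {"human": "user", "ai": "assistant"}


def format_for_claude(messages: list[Msg]) -> list[dict]:
    """Convert messages to Anthropic format, handling tool results.

    Indexing ROLE_LOOKUP[m.type] directly would raise KeyError: 'tool'
    for tool results; here they become user-role messages carrying a
    `tool_result` content block instead.
    """
    formatted = []
    for m in messages:
        if m.type == "tool":
            formatted.append({
                "role": "user",
                "content": [{
                    "type": "tool_result",
                    "tool_use_id": m.tool_call_id,
                    "content": m.content,
                }],
            })
        else:
            formatted.append({"role": ROLE_LOOKUP[m.type], "content": m.content})
    return formatted


history = [
    Msg("human", "What's 119 times 8 minus 20"),
    Msg("ai", "I'll use the calculator tools."),
    Msg("tool", "952", tool_call_id="toolu_example_01"),
]
for entry in format_for_claude(history):
    print(entry)
```

A complete fix would also need the preceding assistant message to contain the matching `tool_use` block, and would merge consecutive user-role messages, since the Messages API expects roles to alternate.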
vb-rob commented 1 week ago

Try the new ChatBedrockConverse model instead (still in beta). It supports tool calling, streaming, structured output, and more.