langchain-ai / langchain-aws

Build LangChain Applications on AWS

ChatBedrockConverse#stream not streaming response for model ids with cross region inference when bind_tools is used #239

Closed (renjiexu-amzn closed this issue 1 month ago)

renjiexu-amzn commented 1 month ago

To reproduce: the following code results in the same behavior as invoke (no incremental streaming); if the .bind_tools line is commented out, the response is streamed properly.

from langchain_aws import ChatBedrockConverse
from langchain_core.tools import tool

@tool(response_format="content_and_artifact")
def simple_calculator(a: int, b: int):
    """Use this tool to calcuate the sum of two integers.

    Args:
        a (int): The first integer.
        b (int): The second integer.

    Returns:
        int: The sum of the two integers.
    """
    return a + b

llm = ChatBedrockConverse(
    model="us.anthropic.claude-3-sonnet-20240229-v1:0",
    temperature=0,
    top_p=1,
    max_tokens=4096,
    region_name="us-west-2"
).bind_tools(tools=[simple_calculator])

a = llm.stream(
    input=[
        ("human", "Hello"),
    ],
)

full = next(a)

for x in a:
    print(x)
    full += x

print(full)
renjiexu-amzn commented 1 month ago

The root cause is that the logic that infers the provider from the model / model ID does not handle cross-region inference profile IDs, where the provider is the second element after splitting on ".", not the first.
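For illustration, here is a minimal sketch of the mismatch (not the library's actual code), using the inference profile ID from the reproduction above:

# A cross-region inference profile ID prefixes the model ID with a region group.
model_id = "us.anthropic.claude-3-sonnet-20240229-v1:0"
parts = model_id.split(".")
print(parts[0])  # "us" - what a naive first-element lookup infers as the provider
print(parts[1])  # "anthropic" - the actual provider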

The workaround is to explicitly pass the provider value when constructing ChatBedrockConverse:

from langchain_aws import ChatBedrockConverse
from langchain_core.tools import tool

@tool(response_format="content_and_artifact")
def simple_calculator(a: int, b: int):
    """Use this tool to calcuate the sum of two integers.

    Args:
        a (int): The first integer.
        b (int): The second integer.

    Returns:
        int: The sum of the two integers.
    """
    return a + b

llm = ChatBedrockConverse(
    model="us.anthropic.claude-3-sonnet-20240229-v1:0",
    temperature=0,
    top_p=1,
    max_tokens=4096,
    region_name="us-west-2",
    provider="anthropic"
).bind_tools(tools=[simple_calculator])

a = llm.stream(
    input=[
        ("human", "Hello"),
    ],
)

full = next(a)

for x in a:
    print(x)
    full += x

print(full)
3coins commented 1 month ago

@renjiexu-amzn Thanks for reporting this issue. The Converse API has many different ways to specify a model: a mix of ARNs, foundation model IDs, inference profile IDs, and model IDs. While we can look at a long-term solution that supports and identifies each of these formats, a short-term fix to support inference profile IDs (without hard-coding regions) is to look at how many parts the model ID has. Here is a quick attempt at that formula.

def get_provider(model_id: str) -> str:
    parts = model_id.split(".")
    return parts[1] if len(parts) == 3 else parts[0]

assert "meta" == get_provider("meta.llama3-2-3b-instruct-v1:0") # mode id
assert "meta" == get_provider("us.meta.llama3-2-3b-instruct-v1:0") # inference profile id

Let me know if the above works for you, and if you want to open a PR to make the change.

An alternative solution could be to use the Bedrock API to get more information about the model. I am not sure whether the Bedrock API returns provider info for all models, so we would have to verify that. This approach would also need some more consideration to call the API only once, during initialization of the chat class.
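For reference, here is a rough sketch of that alternative, assuming boto3's Bedrock control-plane client and its get_foundation_model call; whether it accepts inference profile IDs directly (or needs the profile resolved to a foundation model ID first) and the exact response fields would need to be verified:

import boto3

def get_provider_from_api(model_id: str, region_name: str) -> str:
    """Look up the provider via the Bedrock control-plane API.

    Intended to be called once during chat-class initialization and cached;
    assumes the response exposes the provider under modelDetails.providerName.
    """
    client = boto3.client("bedrock", region_name=region_name)
    details = client.get_foundation_model(modelIdentifier=model_id)["modelDetails"]
    return details["providerName"].lower()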

sidatcd commented 1 week ago

These inference profile model IDs are validated by their region prefix. There is one available for APAC now; can we include that in the validation list?
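If the part-count approach above is adopted, a region prefix such as apac should already fall out of the same formula without maintaining a per-region validation list; a quick check, reusing get_provider from the earlier comment (the APAC profile ID below is only an example and should be verified against the Bedrock docs):

assert "anthropic" == get_provider("apac.anthropic.claude-3-5-sonnet-20240620-v1:0")  # example APAC inference profile id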