langchain-ai / langchain

🦜🔗 Build context-aware reasoning applications
https://python.langchain.com
MIT License

BedrockChat is not using Messages API for Anthropic v3 models #18514

Closed: miroslavtushev closed this issue 6 months ago

miroslavtushev commented 6 months ago


Example Code

from langchain_community.chat_models import BedrockChat
from langchain_core.messages import HumanMessage, SystemMessage

chat = BedrockChat(model_id="anthropic.claude-3-sonnet-20240229-v1:0", model_kwargs={"temperature": 0.1}, verbose=True)
messages = [
    SystemMessage(content="You are a helpful assistant that translates English to French."),
    HumanMessage(content="I love programming.")
]
chat.invoke(messages)

Error Message and Stack Trace (if applicable)

ValueError: Error raised by bedrock service: An error occurred (ValidationException) when calling the InvokeModel operation: "claude-3-sonnet-20240229" is not supported on this API. Please use the Messages API instead.

Description

Currently, the request body prepared for model invocation uses the Text Completions API instead of the Messages API, even though you create an instance of BedrockChat. This can be seen in the source code here:

input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)

def prepare_input(
        cls, provider: str, prompt: str, model_kwargs: Dict[str, Any]
    ) -> Dict[str, Any]:
        input_body = {**model_kwargs}
        if provider == "anthropic":
            input_body["prompt"] = _human_assistant_format(prompt) # here the Completions API is used instead of Messages API
        elif provider in ("ai21", "cohere", "meta"):
            input_body["prompt"] = prompt
        elif provider == "amazon":
            input_body = dict()
            input_body["inputText"] = prompt
            input_body["textGenerationConfig"] = {**model_kwargs}
        else:
            input_body["inputText"] = prompt

        if provider == "anthropic" and "max_tokens_to_sample" not in input_body:
            input_body["max_tokens_to_sample"] = 256

        return input_body
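By contrast, Claude 3 models expect a Messages-style request body. As a minimal sketch (field names taken from the Bedrock Anthropic Messages examples later in this thread), prepare_input would need to produce something like:

input_body = {
    "anthropic_version": "bedrock-2023-05-31",
    "max_tokens": 256,
    "system": "You are a helpful assistant that translates English to French.",  # optional system prompt
    "messages": [
        {"role": "user", "content": "I love programming."},
    ],
}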

Unwinding the call stack, this function is ultimately called; it simply combines all the chat messages into a single prompt string:

def convert_messages_to_prompt_anthropic(
    messages: List[BaseMessage],
    *,
    human_prompt: str = "\n\nHuman:",
    ai_prompt: str = "\n\nAssistant:",
) -> str:
    """Format a list of messages into a full prompt for the Anthropic model
    Args:
        messages (List[BaseMessage]): List of BaseMessage to combine.
        human_prompt (str, optional): Human prompt tag. Defaults to "\n\nHuman:".
        ai_prompt (str, optional): AI prompt tag. Defaults to "\n\nAssistant:".
    Returns:
        str: Combined string with necessary human_prompt and ai_prompt tags.
    """

    messages = messages.copy()  # don't mutate the original list
    if not isinstance(messages[-1], AIMessage):
        messages.append(AIMessage(content=""))

    text = "".join(
        _convert_one_message_to_text(message, human_prompt, ai_prompt)
        for message in messages
    )

    # trim off the trailing ' ' that might come from the "Assistant: "
    return text.rstrip()
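For the example messages above, this helper produces a single Text Completions prompt string roughly like the following (an illustration, assuming the system message is prepended verbatim):

prompt = (
    "You are a helpful assistant that translates English to French."
    "\n\nHuman: I love programming."
    "\n\nAssistant:"
)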

The new Claude 3 family of models only supports the Messages API, so none of them will work with the current version of langchain.

System Info

langchain==0.1.10
langchain-community==0.0.25
langchain-core==0.1.28
langchain-text-splitters==0.0.1

platform: AL2
python: 3.12.0

3coins commented 6 months ago

Looking into this issue for Bedrock.

stephenVertex commented 6 months ago

Seeing the exact same issue. @3coins looking forward to seeing what you find.

KoStard commented 6 months ago

+1

kobibarhanin commented 6 months ago

+1

NotSoShaby commented 6 months ago

This is not an issue with langchain; this is not supported by boto3 in the regular way. They will probably add it at some point. So, weirdly, this doesn't work:

import boto3
import json
brt = boto3.client(service_name='bedrock-runtime')

body = json.dumps({
    "prompt": "\n\nHuman: explain black holes to 8th graders\n\nAssistant:",
    "max_tokens_to_sample": 300,
    "temperature": 0.1,
    "top_p": 0.9,
})

modelId = 'anthropic.claude-3-sonnet-20240229-v1:0'  # the legacy "prompt" body above is rejected for Claude 3 models (it works with anthropic.claude-v2)
accept = 'application/json'
contentType = 'application/json'

response = brt.invoke_model(body=body, modelId=modelId, accept=accept, contentType=contentType)

response_body = json.loads(response.get('body').read())

# text
print(response_body.get('completion'))

but this does:

import boto3
import json

message = "Hey man"

bedrock = boto3.client(service_name="bedrock-runtime")
body = json.dumps({
    "max_tokens": 256,
    "messages": [{"role": "user", "content": message}],
    "anthropic_version": "bedrock-2023-05-31"
})

response = bedrock.invoke_model(body=body, modelId="anthropic.claude-3-sonnet-20240229-v1:0")

response_body = json.loads(response.get("body").read())
print(response_body.get("content"))
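Note that with the Messages API, the content field of the response is a list of content blocks rather than a plain string, so the generated text is extracted like this (as the patches below also do):

text = response_body["content"][0]["text"]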

meowsick commented 6 months ago
body = json.dumps({
    "max_tokens": 256,
    "messages": [{"role": "user", "content": message}],
    "anthropic_version": "bedrock-2023-05-31"
})

The above code uses the Anthropic Messages API, hence it works. According to Anthropic, the Text Completions API is legacy, so the Claude 3 family won't support it. Changes are needed in langchain to migrate to the Messages API for Claude 3 models.

miroslavtushev commented 6 months ago

I monkey patch for now until this is fixed.

from langchain_community.chat_models import BedrockChat
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, BaseMessage
from typing import List, Dict, Optional, Any
from langchain_core.callbacks import (
    CallbackManagerForLLMRun,
)
from langchain_community.llms.utils import enforce_stop_tokens
from langchain.schema.output import ChatResult, ChatGeneration
import json

class BedrockChatV3(BedrockChat):
    """A chat model that uses the Bedrock API."""

    def _format_messages(
            self,
            messages: List[BaseMessage]
    ) -> List[Dict[str, str]]:
        list_of_messages = []
        for i,message in enumerate(messages):
            if i%2==0 and not isinstance(message, HumanMessage):
                raise Exception(f"Expected to see a HumanMessage at the position {i}, but found {message.__class__}")
            elif i%2==1 and not isinstance(message, AIMessage):
                raise Exception(f"Expected to see a AIMessage at the position {i}, but found {message.__class__}")

            list_of_messages.append({ "role": "user" if isinstance(message, HumanMessage) else "assistant", 
                "content" : message.content })
        return list_of_messages

    def _prepare_input_and_invoke(
        self,
        prompt: List[BaseMessage],  # changed from the base class: takes the message list instead of a prompt string
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        _model_kwargs = self.model_kwargs or {}

        messages = prompt
        params = {**_model_kwargs, **kwargs}
        params["anthropic_version"] = "bedrock-2023-05-31" ########
        if "max_tokens" not in params:
            params["max_tokens"] = 256
        if self._guardrails_enabled:
            params.update(self._get_guardrails_canonical())
        # assuming the first message (if present) contains the system instructions
        system = None
        if isinstance(messages[0], SystemMessage):
            system = messages[0].content
            messages = messages[1:]
        messages = self._format_messages(messages)
        input_body = params
        if system is not None:  # only include "system" when a SystemMessage was provided
            input_body["system"] = system
        input_body["messages"] = messages
        body = json.dumps(input_body)
        accept = "application/json"
        contentType = "application/json"

        request_options = {
            "modelId": self.model_id,
            "accept": accept,
            "contentType": contentType,
            "body" : body
        }

        if self._guardrails_enabled:
            request_options["guardrail"] = "ENABLED"
            if self.guardrails.get("trace"):  # type: ignore[union-attr]
                request_options["trace"] = "ENABLED"

        try:
            response = self.client.invoke_model(**request_options)
            body = json.loads(response.get("body").read().decode())
            text = body['content'][0]['text']

        except Exception as e:
            raise ValueError(f"Error raised by bedrock service: {e}")

        if stop is not None:
            text = enforce_stop_tokens(text, stop)

        # Verify and raise a callback error if any intervention occurs or a signal is
        # sent from a Bedrock service,
        # such as when guardrails are triggered.
        services_trace = self._get_bedrock_services_signal(body)  # type: ignore[arg-type]

        if services_trace.get("signal") and run_manager is not None:
            run_manager.on_llm_error(
                Exception(
                    f"Error raised by bedrock service: {services_trace.get('reason')}"
                ),
                **services_trace,
            )

        return text

    # def _stream(
    #     self,
    #     messages: List[BaseMessage],
    #     stop: Optional[List[str]] = None,
    #     run_manager: Optional[CallbackManagerForLLMRun] = None,
    #     **kwargs: Any,
    # ) -> Iterator[ChatGenerationChunk]:
    #     provider = self._get_provider()
    #     prompt = ChatPromptAdapter.convert_messages_to_prompt(
    #         provider=provider, messages=messages
    #     )

    #     for chunk in self._prepare_input_and_invoke_stream(
    #         prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
    #     ):
    #         delta = chunk.text
    #         yield ChatGenerationChunk(message=AIMessageChunk(content=delta))

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        completion = ""

        params: Dict[str, Any] = {**kwargs}
        if stop:
            params["stop_sequences"] = stop

        completion = self._prepare_input_and_invoke(
            prompt=messages, stop=stop, run_manager=run_manager, **params  # pass the messages straight through
        )

        message = AIMessage(content=completion)
        return ChatResult(generations=[ChatGeneration(message=message)])

chat = BedrockChatV3(model_id="anthropic.claude-3-sonnet-20240229-v1:0", model_kwargs={"temperature": 0.1}, verbose=True)

messages = [
    SystemMessage(content="You are a helpful assistant that translates English to French."),
    HumanMessage(content="I love programming.")
]
chat.invoke(input=messages)

AIMessage(content="J'aime la programmation.")

Streaming can be implemented similarly for those who need it.

Daan-Grashoff commented 6 months ago

And if you need streaming:

from langchain_community.chat_models import BedrockChat
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, BaseMessage, AIMessageChunk
from typing import Any, Dict, Iterator, List, Mapping, Optional
from langchain_core.callbacks import (
    CallbackManagerForLLMRun,
)
from langchain_community.llms.utils import enforce_stop_tokens
from langchain.schema.output import ChatResult, ChatGeneration, ChatGenerationChunk, GenerationChunk
from langchain.llms.bedrock import Bedrock, BaseModel, BedrockBase, LLMInputOutputAdapter
import json

def prepare_output_stream(
    provider: str, response: Any, stop: Optional[List[str]] = None
) -> Iterator[GenerationChunk]:
    print(provider, response, stop)
    stream = response.get("body")

    if not stream:
        return

    for event in stream:
        chunk = event.get("chunk")
        if chunk:
            chunk_obj = json.loads(chunk.get('bytes').decode())
            if chunk_obj['type'] == 'content_block_delta':
                text = chunk_obj['delta']['text']
                yield GenerationChunk(
                    text=text
                )

class BedrockChatV3(BedrockChat):
    """A chat model that uses the Bedrock API."""

    def _format_messages(
            self,
            messages: List[BaseMessage]
    ) -> List[Dict[str, str]]:
        list_of_messages = []
        for i,message in enumerate(messages):
            if i%2==0 and not isinstance(message, HumanMessage):
                raise Exception(f"Expected to see a HumanMessage at the position {i}, but found {message.__class__}")
            elif i%2==1 and not isinstance(message, AIMessage):
                raise Exception(f"Expected to see a AIMessage at the position {i}, but found {message.__class__}")

            list_of_messages.append({ "role": "user" if isinstance(message, HumanMessage) else "assistant", 
                "content" : message.content })
        return list_of_messages

    def _prepare_input_and_invoke(
        self,
        prompt: List[BaseMessage],  # changed from the base class: takes the message list instead of a prompt string
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        _model_kwargs = self.model_kwargs or {}

        messages = prompt
        params = {**_model_kwargs, **kwargs}
        params["anthropic_version"] = "bedrock-2023-05-31" ########
        if "max_tokens" not in params:
            params["max_tokens"] = 256
        if self._guardrails_enabled:
            params.update(self._get_guardrails_canonical())
        # assuming the first message (if present) contains the system instructions
        system = None
        if isinstance(messages[0], SystemMessage):
            system = messages[0].content
            messages = messages[1:]
        messages = self._format_messages(messages)
        input_body = params
        if system is not None:  # only include "system" when a SystemMessage was provided
            input_body["system"] = system
        input_body["messages"] = messages
        body = json.dumps(input_body)
        accept = "application/json"
        contentType = "application/json"

        request_options = {
            "modelId": self.model_id,
            "accept": accept,
            "contentType": contentType,
            "body" : body
        }

        if self._guardrails_enabled:
            request_options["guardrail"] = "ENABLED"
            if self.guardrails.get("trace"):  # type: ignore[union-attr]
                request_options["trace"] = "ENABLED"

        try:
            response = self.client.invoke_model(**request_options)
            body = json.loads(response.get("body").read().decode())
            text = body['content'][0]['text']

        except Exception as e:
            raise ValueError(f"Error raised by bedrock service: {e}")

        if stop is not None:
            text = enforce_stop_tokens(text, stop)

        # Verify and raise a callback error if any intervention occurs or a signal is
        # sent from a Bedrock service,
        # such as when guardrails are triggered.
        services_trace = self._get_bedrock_services_signal(body)  # type: ignore[arg-type]

        if services_trace.get("signal") and run_manager is not None:
            run_manager.on_llm_error(
                Exception(
                    f"Error raised by bedrock service: {services_trace.get('reason')}"
                ),
                **services_trace,
            )

        return text

    def _prepare_input_and_invoke_stream(
        self,
        prompt: List[BaseMessage],  # changed from the base class: takes the message list instead of a prompt string
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        _model_kwargs = self.model_kwargs or {}
        provider = self._get_provider()

        if stop:
            if provider not in self.provider_stop_sequence_key_name_map:
                raise ValueError(
                    f"Stop sequence key name for {provider} is not supported."
                )

            # stop sequence from _generate() overrides
            # stop sequences in the class attribute
            _model_kwargs[self.provider_stop_sequence_key_name_map.get(provider)] = stop

        if provider == "cohere":
            _model_kwargs["stream"] = True


        messages = prompt
        params = {**_model_kwargs, **kwargs}
        params["anthropic_version"] = "bedrock-2023-05-31" ########
        if "max_tokens" not in params:
            params["max_tokens"] = 256
        if self._guardrails_enabled:
            params.update(self._get_guardrails_canonical())
        # assuming the first message (if present) contains the system instructions
        system = None
        if isinstance(messages[0], SystemMessage):
            system = messages[0].content
            messages = messages[1:]
        messages = self._format_messages(messages)
        input_body = params
        if system is not None:  # only include "system" when a SystemMessage was provided
            input_body["system"] = system
        input_body["messages"] = messages
        body = json.dumps(input_body)
        accept = "application/json"
        contentType = "application/json"

        request_options = {
            "modelId": self.model_id,
            "accept": accept,
            "contentType": contentType,
            "body" : body
        }

        if self._guardrails_enabled:
            request_options["guardrail"] = "ENABLED"
            if self.guardrails.get("trace"):  # type: ignore[union-attr]
                request_options["trace"] = "ENABLED"

        try:
            response = self.client.invoke_model_with_response_stream(**request_options)

        except Exception as e:
            raise ValueError(f"Error raised by bedrock service: {e}")

        for chunk in prepare_output_stream(
            provider, response, stop
        ):
            yield chunk
            # verify and raise callback error if any middleware intervened
            self._get_bedrock_services_signal(chunk.generation_info)  # type: ignore[arg-type]

            if run_manager is not None:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        provider = self._get_provider()

        for chunk in self._prepare_input_and_invoke_stream(
            prompt=messages, stop=stop, run_manager=run_manager, **kwargs
        ):
            delta = chunk.text
            yield ChatGenerationChunk(message=AIMessageChunk(content=delta))

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        completion = ""

        params: Dict[str, Any] = {**kwargs}
        if stop:
            params["stop_sequences"] = stop

        completion = self._prepare_input_and_invoke(
            prompt=messages, stop=stop, run_manager=run_manager, **params  # pass the messages straight through
        )

        message = AIMessage(content=completion)
        return ChatResult(generations=[ChatGeneration(message=message)])

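In the usage below, bedrock_client is assumed to be a boto3 Bedrock runtime client created beforehand, for example:

import boto3

bedrock_client = boto3.client(service_name="bedrock-runtime")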
chat = BedrockChatV3(model_id="anthropic.claude-3-sonnet-20240229-v1:0", model_kwargs={"temperature": 0.1}, verbose=True, client=bedrock_client)
messages = [
    SystemMessage(content="You are a helpful assistant that translates English to French."),
    HumanMessage(content="I love programming.")
]

gen = ""
for i in chat.stream(input=messages):
    gen += i.content
    print(gen)

rpgeddam commented 6 months ago

> I monkey patch for now until this is fixed.

So there are workarounds, but since it isn't fixed in langchain, shouldn't the issue be left open?

miroslavtushev commented 6 months ago

> > I monkey patch for now until this is fixed.
>
> So there are workarounds, but since it isn't fixed in langchain, shouldn't the issue be left open?

It's been fixed here.