langchain-ai / langchain-aws

Build LangChain Applications on AWS
MIT License

Bedrock Mistral prompts need to be updated to match AWS docs #45

Open skehlet opened 1 month ago

skehlet commented 1 month ago

Hello! I am using Mistral through Bedrock with LangChain and noticed that, after using the model for a while with ChatMessageHistory, it would begin to hallucinate and have trouble stopping. I dug into it, and, similar to langchain-aws#31, the prompt being sent to the model doesn't quite match the AWS documentation. Notably, <s> and </s> are never used, and the system prompt isn't placed inside the first human message.

I updated langchain_aws/chat_models/bedrock.py as follows:

from typing import List

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage


def convert_messages_to_prompt_mistral(messages: List[BaseMessage]) -> str:
    """
    Convert a list of messages to a prompt for mistral. The format is:
    <s>[INST] Instruction [/INST] Model answer</s>[INST] Follow-up instruction [/INST]
    Any system messages are simply prepended to the first HumanMessage.
    """
    # Separate system messages so they can be folded into the first human turn.
    system_messages = [msg for msg in messages if isinstance(msg, SystemMessage)]
    non_system_messages = [msg for msg in messages if not isinstance(msg, SystemMessage)]
    prompt = "<s>"  # single BOS marker at the start of the conversation
    for idx, message in enumerate(non_system_messages):
        if isinstance(message, HumanMessage):
            prompt += "[INST] "
            if idx == 0 and len(system_messages) > 0:
                system_prompt = "\n".join([msg.content for msg in system_messages])
                prompt += f"{system_prompt}\n\n"
            prompt += f"{message.content.strip()} [/INST]"
        elif isinstance(message, AIMessage):
            # Close each completed assistant answer with </s>.
            prompt += f"{message.content.strip()}</s>"
        else:
            raise ValueError(f"Got unknown type {message}")
    return prompt
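If you want to check the template without AWS credentials or LangChain installed, here is a minimal sketch that mirrors the function above using plain (role, content) tuples instead of LangChain message classes. The tuple representation and the name mistral_prompt are hypothetical, for illustration only; the formatting logic follows the proposed convert_messages_to_prompt_mistral.

```python
from typing import List, Tuple

# Hypothetical stand-in for LangChain messages: (role, content) pairs,
# where role is "system", "human" or "ai".
Message = Tuple[str, str]


def mistral_prompt(messages: List[Message]) -> str:
    """Mirror of the proposed convert_messages_to_prompt_mistral, minus LangChain."""
    system = [content for role, content in messages if role == "system"]
    turns = [(role, content) for role, content in messages if role != "system"]
    prompt = "<s>"  # single BOS marker at the very start
    for idx, (role, content) in enumerate(turns):
        if role == "human":
            prompt += "[INST] "
            if idx == 0 and system:
                # System text is prepended to the first human turn.
                prompt += "\n".join(system) + "\n\n"
            prompt += f"{content.strip()} [/INST]"
        elif role == "ai":
            # Each completed assistant answer is closed with </s>.
            prompt += f"{content.strip()}</s>"
        else:
            raise ValueError(f"Got unknown role {role!r}")
    return prompt


if __name__ == "__main__":
    print(mistral_prompt([
        ("system", "You're a helpful assistant that answers with poems."),
        ("human", "What is the purpose of model regularization?"),
    ]))
```

Feeding it the two messages from the program below produces the same prompt string as the "After" request body.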

Here is a simple program:

from langchain_aws import ChatBedrock
from langchain_core.messages import HumanMessage, SystemMessage

messages = [
    SystemMessage(content="You're a helpful assistant that answers with poems."),
    HumanMessage(content="What is the purpose of model regularization?"),
]

chat = ChatBedrock(
    model_id="mistral.mixtral-8x7b-instruct-v0:1",
    model_kwargs={"max_tokens": 1000},
    region_name="us-west-2",
)

for chunk in chat.stream(messages):
    print(chunk.content, end="", flush=True)

Prompt before (obtained by adding some print statements to the _stream function):

{"max_tokens": 1000, "prompt": "<<SYS>> You're a helpful assistant that answers with poems. <</SYS>>\n[INST] What is the purpose of model regularization? [/INST]"}

After:

{"max_tokens": 1000, "prompt": "<s>[INST] You're a helpful assistant that answers with poems.\n\nWhat is the purpose of model regularization? [/INST]"}
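The before/after lines above are JSON request bodies, so each \n in them is JSON's escape for a real newline inside the prompt string. A minimal sketch, assuming the body is built with Python's standard json module (the exact serialization inside langchain-aws may differ), reproduces the "After" line:

```python
import json

# Prompt as built by the proposed Mistral converter (contains real newlines).
prompt = (
    "<s>[INST] You're a helpful assistant that answers with poems.\n\n"
    "What is the purpose of model regularization? [/INST]"
)

# Model parameters and prompt combined into one request body.
body = json.dumps({"max_tokens": 1000, "prompt": prompt})
print(body)
```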

Sorry, this simple program doesn't demonstrate the hallucination; the program where I see it is much longer and for work, so unfortunately I can't share it. It does, however, demonstrate the prompt changes.

With this update to the prompt, the model seems to be working great and I no longer have the issues with it hallucinating or not stopping.

Note about wrapping the system prompt in <<SYS>>/<</SYS>>: I reviewed the Mistral docs, llama_index's implementation, and chujiezheng's chat_templates. Mistral doesn't seem to advertise any special handling for system prompts. The llama_index implementation just follows Llama's use of <<SYS>>, but chujiezheng's doesn't. I tried both ways and it didn't seem to matter, so I left it out of the code above.
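To make the two options concrete, here is a sketch of a hypothetical helper, first_instruction, that builds the opening turn either with the system text prepended plainly (as in the proposed fix) or wrapped in <<SYS>> tags. The <<SYS>>\n...\n<</SYS>>\n\n layout is an assumption modeled on Llama 2's chat format; Mistral itself does not document it.

```python
def first_instruction(system: str, user: str, wrap_sys: bool = False) -> str:
    """Build the first [INST] block with the system prompt folded in.

    Hypothetical helper for comparison only; not part of langchain-aws.
    """
    if wrap_sys:
        # Llama-2-style wrapping (assumed layout).
        system_part = f"<<SYS>>\n{system}\n<</SYS>>\n\n"
    else:
        # Plain prepend, as in the proposed convert_messages_to_prompt_mistral.
        system_part = f"{system}\n\n"
    return f"<s>[INST] {system_part}{user.strip()} [/INST]"


plain = first_instruction("Answer with poems.", "What is regularization?")
wrapped = first_instruction("Answer with poems.", "What is regularization?", wrap_sys=True)
print(repr(plain))
print(repr(wrapped))
```

Either string can be sent as the prompt to compare model behavior; only the system-prompt framing differs.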

I hope this is helpful. If my understanding of how <s> and </s> work is incorrect, I apologize, but they really seem necessary according to the docs and this other really helpful read (How to Prompt Mistral AI models, and Why), and I'm pretty sure nothing else is adding them. Thank you for considering this issue.

skehlet commented 1 month ago

A longer series of messages might be more helpful, since it also shows the </s> tokens:

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

messages = [
    SystemMessage(content="You're a helpful assistant that answers with super short poems. "),
    HumanMessage(content="What is your name?"),
    AIMessage(content="I am Haiku, your poetic guide."),
    HumanMessage(content="What do you do?"),
    AIMessage(content="I distill thoughts, in verse, concise,\nTo bring you wisdom, nice and precise."),
    HumanMessage(content="What is your favorite color?"),
]

yields:

{"max_tokens": 1000, "prompt": "<s>[INST] You're a helpful assistant that answers with super short poems. \n\nWhat is your name? [/INST]I am Haiku, your poetic guide.</s>[INST] What do you do? [/INST]I distill thoughts, in verse, concise,\nTo bring you wisdom, nice and precise.</s>[INST] What is your favorite color? [/INST]"}
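A quick way to sanity-check a multi-turn prompt like this one is to confirm that <s> appears exactly once at the start, that every finished assistant answer is closed by </s>, and that the prompt ends with an open [/INST] for the model to complete. The checker below is a hypothetical helper, assuming a strictly alternating human/AI conversation that ends with a human turn:

```python
def looks_like_mistral_prompt(prompt: str, n_ai_turns: int) -> bool:
    """Hypothetical structural check for an alternating human/AI prompt."""
    return (
        prompt.startswith("<s>")
        and prompt.count("<s>") == 1              # single BOS marker
        and prompt.count("</s>") == n_ai_turns    # each finished answer closed
        and prompt.count("[INST]") == n_ai_turns + 1
        and prompt.count("[/INST]") == n_ai_turns + 1
        and prompt.endswith("[/INST]")            # the model answers next
    )


after = (
    "<s>[INST] You're a helpful assistant that answers with super short poems. \n\n"
    "What is your name? [/INST]I am Haiku, your poetic guide.</s>"
    "[INST] What do you do? [/INST]I distill thoughts, in verse, concise,\n"
    "To bring you wisdom, nice and precise.</s>"
    "[INST] What is your favorite color? [/INST]"
)
before = (
    "<<SYS>> You're a helpful assistant that answers with poems. <</SYS>>\n"
    "[INST] What is the purpose of model regularization? [/INST]"
)
print(looks_like_mistral_prompt(after, n_ai_turns=2))   # True
print(looks_like_mistral_prompt(before, n_ai_turns=0))  # False: no leading <s>
```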