langchain-ai / langgraph

Build resilient language agents as graphs.
https://langchain-ai.github.io/langgraph/
MIT License

Graph stream labels "HumanMessage" but it is an "AIMessage" #574

Open HGInfoNancy opened 1 month ago

HGInfoNancy commented 1 month ago

Checked other resources

Example Code

workflow = StateGraph(sv.GraphState)
# Initial router
workflow.add_conditional_edges(START, gu.route_user)
workflow.add_node("rdv", gu.rdv_generate)

# Define the nodes we will cycle between
workflow.add_node("agent", gu.agent)  # agent
workflow.add_node("retrieve", gu.tool_node_retrieve())  # retrieval

workflow.add_node(
    "generate", gu.generate
)  # Generating a response after we know the documents are relevant

# Decide whether to retrieve
workflow.add_conditional_edges(
    "agent",
    # Assess agent decision
    tools_condition,
    {
        # Translate the condition outputs to nodes in our graph
        "tools": "retrieve",
        END: END,
    },
)

# Edges taken after the `retrieve` node is called.
workflow.add_edge("retrieve", "generate")
workflow.add_edge("generate", END)
workflow.add_edge("rdv", END)

# Compile
graph = workflow.compile()

return graph

Here is how it looks when displayed:

[image: graph_description]

Error Message and Stack Trace (if applicable)

No response

Description

Using a graph, I read that:

for output in graph.stream(inputs, stream_mode="values"):
    print("__output__", output)

produces as output a concatenated list of the different steps along the graph path.

My graph path is input -> agent -> retrieve -> generate, and I got something like this:

{'messages': [HumanMessage(content='my_input?', id='3c2e2af5-f38e-4471-aee3-07502e3a5494'), AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'bdR7xmCGm', 'function': {'name': 'my_function', 'arguments': '{"query": "new_query"}'}}]}, response_metadata={'token_usage': {'prompt_tokens': 116, 'total_tokens': 154, 'completion_tokens': 38}, 'model': 'open-mixtral-8x22b', 'finish_reason': 'tool_calls'}, id='run-1e8ddb4a-751b-4f46-9ce3-02ddad099549-0', tool_calls=[{'name': 'my_tools', 'args': {'query': 'my_query'}, 'id': 'bdR7xmCGm'}]), ToolMessage(content='blobloblo', tool_call_id='bdR7xmCGm'), HumanMessage(content="blablabla", id='739239a9-b3e6-4ef4-a01a-642d4a992d17')]}

The latest object in my list is labeled as a "HumanMessage", but it was generated by the AI. Using LangSmith, I do not see this error.

This is very confusing, because I cannot simply inspect the event steps from my graph if I want to store them.
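
(For what it's worth, with stream_mode="values" each streamed chunk appears to be the full state after a step, so if the goal is only to store the final state, a minimal sketch like this may be enough:)

final_state = None
for final_state in graph.stream(inputs, stream_mode="values"):
    pass  # each chunk is the complete state so far; keep only the last one
print(final_state["messages"])  # the full accumulated message list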

System Info

langchain==0.1.20
langchain-community==0.0.38
langchain-core==0.2.1
langchain-mistralai==0.1.7
langchain-postgres==0.0.6
langchain-text-splitters==0.0.1
langchainhub==0.1.15
langgraph==0.0.53
langsmith==0.1.57

HGInfoNancy commented 1 month ago

Going from:

class GraphState(TypedDict):
    # The add_messages function defines how an update should be processed.
    # Default is to replace; add_messages says "append".
    messages: Annotated[Sequence[BaseMessage], add_messages]

to:

class GraphState(TypedDict):
    # Default is to replace; operator.add says "append".
    messages: Annotated[Sequence[BaseMessage], operator.add]

and adding an AIMessage in generate can fix this. But that doesn't explain why add_messages forces the Human/AI-Message changeover.
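
A likely explanation (an assumption based on reading langchain_core's conversion rules, not confirmed by the maintainers): add_messages coerces non-message values with convert_to_messages, which treats a bare string as a human turn, while operator.add just concatenates the lists without any coercion. Since generate returns the raw string produced by the chain, add_messages wraps it as a HumanMessage. A minimal sketch of the coercion:

from langchain_core.messages import convert_to_messages

# A bare string is coerced to a human turn (assumption: this mirrors what
# add_messages does to non-message updates in the state)
print(convert_to_messages(["some model output"]))
# -> [HumanMessage(content='some model output')]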
hwchase17 commented 1 month ago

Are you able to share code to exactly reproduce this? I'm currently trying to replicate it, but I don't know what your nodes are doing, etc.

RaminZi commented 3 weeks ago

I have the exact same issue. I was following the Agentic RAG example from langgraph (https://langchain-ai.github.io/langgraph/tutorials/rag/langgraph_agentic_rag/#graph) and removed the rewrite node, so the flow is either agent -> end or agent -> retrieve -> generate -> end. The output of generate is appended as a HumanMessage to the list of messages in the state.

Here is a reproducible example @hwchase17 :

from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter

urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=100, chunk_overlap=50
)
doc_splits = text_splitter.split_documents(docs_list)

# Add to vectorDB
vectorstore = Chroma.from_documents(
    documents=doc_splits,
    # collection_name="rag-chroma",
    embedding=OpenAIEmbeddings()
)
retriever = vectorstore.as_retriever()

from langchain.tools.retriever import create_retriever_tool

retriever_tool = create_retriever_tool(
    retriever,
    "retrieve_blog_posts",
    "Search and return information about Lilian Weng blog posts on LLM agents, prompt engineering, and adversarial attacks on LLMs.",
)

tools = [retriever_tool]

from typing import Annotated, Sequence, TypedDict

from langchain_core.messages import BaseMessage

from langgraph.graph.message import add_messages

class AgentState(TypedDict):
    # The add_messages function defines how an update should be processed
    # Default is to replace. add_messages says "append"
    messages: Annotated[Sequence[BaseMessage], add_messages]

from typing import Annotated, Literal, Sequence, TypedDict

from langchain import hub
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI

from langgraph.prebuilt import tools_condition

### Nodes

def agent(state):
    """
    Invokes the agent model to generate a response based on the current state. Given
    the question, it will decide to retrieve using the retriever tool, or simply end.

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with the agent response appended to messages
    """
    print("---CALL AGENT---")
    messages = state["messages"]
    model = ChatOpenAI(temperature=0, streaming=True, model="gpt-4-turbo")
    model = model.bind_tools(tools)
    response = model.invoke(messages)
    # We return a list, because this will get added to the existing list
    return {"messages": [response]}

def generate(state):
    """
    Generate answer

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with the generated answer appended to messages
    """
    print("---GENERATE---")
    messages = state["messages"]
    question = messages[0].content
    last_message = messages[-1]
    docs = last_message.content

    # Prompt
    prompt = hub.pull("rlm/rag-prompt")

    # LLM
    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, streaming=True)

    # Post-processing
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    # Chain
    rag_chain = prompt | llm | StrOutputParser()

    # Run
    response = rag_chain.invoke({"context": docs, "question": question})
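    # NOTE: `response` is a bare string from StrOutputParser, not an AIMessage,
    # which add_messages will have to coerce when it is appended to the state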
    return {"messages": [response]}

print("*" * 20 + "Prompt[rlm/rag-prompt]" + "*" * 20)
hub.pull("rlm/rag-prompt").pretty_print()  # Show what the prompt looks like

from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import ToolNode

# Define a new graph
workflow = StateGraph(AgentState)

# Define the nodes we will cycle between
workflow.add_node("agent", agent)  # agent
retrieve = ToolNode([retriever_tool])
workflow.add_node("retrieve", retrieve)  # retrieval

workflow.add_node(
    "generate", generate
)  # Generating a response after we know the documents are relevant
# Call agent node to decide to retrieve or not
workflow.add_edge(START, "agent")

# Decide whether to retrieve
workflow.add_conditional_edges(
    "agent",
    # Assess agent decision
    tools_condition,
    {
        # Translate the condition outputs to nodes in our graph
        "tools": "retrieve",
        END: END,
    },
)

workflow.add_edge("generate", END)
workflow.add_edge("retrieve", "generate")

# Compile

from langgraph.checkpoint.sqlite import SqliteSaver

memory = SqliteSaver.from_conn_string(":memory:")

graph = workflow.compile(checkpointer=memory)

config = {"configurable": {"thread_id": "abc"}}
import pprint

inputs = {
    "messages": [
        ("user", "According to Lilian Weng's blog, Explain what adversarial attacks on LLMs is."),
    ]
}
for output in graph.stream(inputs, config, stream_mode="values"):
    for key, value in output.items():
        pprint.pprint(f"Output from node '{key}':")
        pprint.pprint("---")
        pprint.pprint(value, indent=2, width=80, depth=None)
    pprint.pprint("\n---\n")

graph.get_state(config).values['messages']

The architecture: [image: agent_arc]

And this is the output:

[HumanMessage(content="According to Lilian Weng's blog, Explain what adversarial attacks on LLMs is.", id='b6b6bddc-2f57-46ac-be14-6df767146a45'), AIMessage(content='', additional_kwargs={'tool_calls': [{'index': 0, 'id': 'call_tmVj2XHIP1MKLShpTIYIgtZG', 'function': {'arguments': '{"query":"adversarial attacks on LLMs"}', 'name': 'retrieve_blog_posts'}, 'type': 'function'}]}, response_metadata={'finish_reason': 'tool_calls', 'model_name': 'gpt-4-turbo-2024-04-09', 'system_fingerprint': 'fp_486730399b'}, id='run-3891bb57-ae05-45f4-9344-2d47b6d46c3f-0', tool_calls=[{'name': 'retrieve_blog_posts', 'args': {'query': 'adversarial attacks on LLMs'}, 'id': 'call_tmVj2XHIP1MKLShpTIYIgtZG'}]), ToolMessage(content="Citation#\nCited as:\n\nWeng, Lilian. (Oct 2023). “Adversarial Attacks on LLMs”. Lil’Log. https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/.\n\nAdversarial Attacks on LLMs | Lil'Log\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nLil'Log\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nPosts\n\n\n\n\nArchive\n\n\n\n\nSearch\n\n\n\n\nTags\n\n\n\n\nFAQ\n\n\n\n\nemojisearch.app\n\nAdversarial attacks on classifiers have attracted more attention in the research community in the past, many in the image domain. LLMs can be used for classification too. Given an input $\mathbf{x}$ and a classifier $f(.)$, we would like to find an adversarial version of the input, denoted as $\mathbf{x}_\text{adv}$, with imperceptible difference from $\mathbf{x}$, such that\n\nAdversarial Attacks on LLMs\n \nDate: October 25, 2023 | Estimated Reading Time: 33 min | Author: Lilian Weng\n\n\n \n\n\nTable of Contents\n\n\n\nBasics\n\nThreat Model\n\nClassification\n\nText Generation\n\nWhite-box vs Black-box\n\n\n\nTypes of Adversarial Attacks\n\nToken Manipulation", name='retrieve_blog_posts', id='fb7dc47a-10ca-474a-8f48-043daff3a138', tool_call_id='call_tmVj2XHIP1MKLShpTIYIgtZG'), HumanMessage(content="Adversarial attacks on LLMs involve finding an imperceptible version of an input that can fool the classifier. These attacks have been a focus in image domains and can also be applied to LLMs used for classification. Lilian Weng's blog discusses various types of adversarial attacks on LLMs.", id='6f5f6b92-e5f1-4648-b451-5a4281ff3d2d')]
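
A workaround in line with what @HGInfoNancy reported above: have generate return an explicit message object instead of a bare string, so there is nothing for add_messages to coerce. A minimal sketch (untested, keeping the rest of the repro unchanged):

from langchain_core.messages import AIMessage

def generate(state):
    """Generate the answer and return it as an explicit AIMessage."""
    messages = state["messages"]
    question = messages[0].content
    docs = messages[-1].content

    rag_chain = hub.pull("rlm/rag-prompt") | ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0) | StrOutputParser()
    response = rag_chain.invoke({"context": docs, "question": question})

    # Wrapping the string avoids the string -> HumanMessage coercion
    return {"messages": [AIMessage(content=response)]}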