Status: Open. HGInfoNancy opened this issue 1 month ago.
Going from
class GraphState(TypedDict):
    # Default is to replace. add_messages says "append"
    messages: Annotated[Sequence[BaseMessage], add_messages]
to
class GraphState(TypedDict):
    # Default is to replace. add_messages says "append"
    messages: Annotated[Sequence[BaseMessage], operator.add]
and wrapping the output of generate in an AIMessage fixes this. But that doesn't explain why add_messages forces the AI output to become a HumanMessage.
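The switch asked about here can be reproduced without LangGraph. LangGraph's `add_messages` runs every update through `langchain_core`'s `convert_to_messages`, and that helper maps a bare `str` (a value with no role attached) to a `HumanMessage`; `operator.add` only concatenates the two lists. The sketch below is a simplified, standard-library-only model of that behaviour: the message classes, `to_message` and `add_messages_sketch` are stand-ins, not the library's code.

```python
import operator
import uuid
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union


# Simplified stand-ins for langchain_core message classes (not the real ones).
@dataclass
class BaseMessage:
    content: str
    id: Optional[str] = None


@dataclass
class HumanMessage(BaseMessage):
    pass


@dataclass
class AIMessage(BaseMessage):
    pass


# Roles accepted in (role, content) tuples, e.g. ("user", "...") in the repro.
ROLE_TO_CLASS = {"human": HumanMessage, "user": HumanMessage,
                 "ai": AIMessage, "assistant": AIMessage}

MessageLike = Union[BaseMessage, str, Tuple[str, str]]


def to_message(value: MessageLike) -> BaseMessage:
    """Coerce one update item to a message object."""
    if isinstance(value, BaseMessage):
        return value
    if isinstance(value, str):
        # A bare string carries no role, so it is treated as human input.
        return HumanMessage(content=value)
    role, content = value
    return ROLE_TO_CLASS[role](content=content)


def add_messages_sketch(left: List[MessageLike],
                        right: List[MessageLike]) -> List[BaseMessage]:
    """Coerce both sides, give every message an id, then merge by id."""
    left = [to_message(m) for m in left]
    right = [to_message(m) for m in right]
    for m in left + right:
        if m.id is None:
            m.id = str(uuid.uuid4())
    index_by_id = {m.id: i for i, m in enumerate(left)}
    merged = list(left)
    for m in right:
        if m.id in index_by_id:
            merged[index_by_id[m.id]] = m  # same id: replace in place
        else:
            merged.append(m)  # new id: append
    return merged


answer = "Adversarial attacks on LLMs involve ..."
state = add_messages_sketch([("user", "Explain adversarial attacks on LLMs.")], [])

bad = add_messages_sketch(state, [answer])                    # plain str, as in the repro
good = add_messages_sketch(state, [AIMessage(content=answer)])  # wrapped in AIMessage
raw = operator.add(state, [answer])                           # no coercion at all

for result in (bad, good, raw):
    print(type(result[-1]).__name__)  # prints HumanMessage, then AIMessage, then str
```

Under that reading, the fix in the repro is to return `{"messages": [AIMessage(content=response)]}` from `generate`, with `AIMessage` imported from `langchain_core.messages`; switching to `operator.add` without wrapping would instead leave a raw `str` in the state.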
Are you able to share code that exactly reproduces this? I'm currently trying to replicate it but don't know what your nodes are doing, etc.
I have the exact same issue. I was following the Agentic RAG example from LangGraph (https://langchain-ai.github.io/langgraph/tutorials/rag/langgraph_agentic_rag/#graph), and I removed the rewrite node. So the flow is either agent->end or agent->retriever->generate->end. The output of generate is appended as a HumanMessage to the list of messages in the state.
@hwchase17, here is a reproducible example:
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter
urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=100, chunk_overlap=50
)
doc_splits = text_splitter.split_documents(docs_list)
# Add to vectorDB
vectorstore = Chroma.from_documents(
    documents=doc_splits,
    # collection_name="rag-chroma",
    embedding=OpenAIEmbeddings()
)
retriever = vectorstore.as_retriever()
from langchain.tools.retriever import create_retriever_tool
retriever_tool = create_retriever_tool(
    retriever,
    "retrieve_blog_posts",
    "Search and return information about Lilian Weng blog posts on LLM agents, prompt engineering, and adversarial attacks on LLMs.",
)
tools = [retriever_tool]
from typing import Annotated, Sequence, TypedDict
from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages
class AgentState(TypedDict):
    # The add_messages function defines how an update should be processed
    # Default is to replace. add_messages says "append"
    messages: Annotated[Sequence[BaseMessage], add_messages]
from typing import Annotated, Literal, Sequence, TypedDict
from langchain import hub
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import tools_condition
### Nodes
def agent(state):
    """
    Invokes the agent model to generate a response based on the current state. Given
    the question, it will decide to retrieve using the retriever tool, or simply end.
    Args:
        state (messages): The current state
    Returns:
        dict: The updated state with the agent response appended to messages
    """
    print("---CALL AGENT---")
    messages = state["messages"]
    model = ChatOpenAI(temperature=0, streaming=True, model="gpt-4-turbo")
    model = model.bind_tools(tools)
    response = model.invoke(messages)
    # We return a list, because this will get added to the existing list
    return {"messages": [response]}
def generate(state):
    """
    Generate answer
    Args:
        state (messages): The current state
    Returns:
        dict: The updated state with the generated answer appended to messages
    """
    print("---GENERATE---")
    messages = state["messages"]
    question = messages[0].content
    last_message = messages[-1]
    docs = last_message.content
    # Prompt
    prompt = hub.pull("rlm/rag-prompt")
    # LLM
    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, streaming=True)
    # Post-processing (unused here: the ToolMessage content is already a string)
    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)
    # Chain
    rag_chain = prompt | llm | StrOutputParser()
    # Run
    response = rag_chain.invoke({"context": docs, "question": question})
    # response is a plain str here; add_messages coerces it to a HumanMessage
    return {"messages": [response]}
print("*" * 20 + "Prompt[rlm/rag-prompt]" + "*" * 20)
hub.pull("rlm/rag-prompt").pretty_print()  # Show what the prompt looks like
from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import ToolNode
# Define a new graph
workflow = StateGraph(AgentState)
# Define the nodes we will cycle between
workflow.add_node("agent", agent) # agent
retrieve = ToolNode([retriever_tool])
workflow.add_node("retrieve", retrieve) # retrieval
workflow.add_node(
    "generate", generate
)  # Generating a response after we know the documents are relevant
# Call agent node to decide to retrieve or not
workflow.add_edge(START, "agent")
# Decide whether to retrieve
workflow.add_conditional_edges(
    "agent",
    # Assess agent decision
    tools_condition,
    {
        # Translate the condition outputs to nodes in our graph
        "tools": "retrieve",
        END: END,
    },
)
workflow.add_edge("generate", END)
workflow.add_edge("retrieve", "generate")
# Compile
from langgraph.checkpoint.sqlite import SqliteSaver
memory = SqliteSaver.from_conn_string(":memory:")
graph = workflow.compile(checkpointer=memory)
config = {"configurable": {"thread_id": "abc"}}
import pprint
inputs = {
    "messages": [
        ("user", "According to Lilian Weng's blog, Explain what adversarial attacks on LLMs is."),
    ]
}
for output in graph.stream(inputs, config, stream_mode="values"):
    # With stream_mode="values", each output is the full state, keyed by state field
    for key, value in output.items():
        pprint.pprint(f"State field '{key}':")
        pprint.pprint("---")
        pprint.pprint(value, indent=2, width=80, depth=None)
    pprint.pprint("\n---\n")
graph.get_state(config).values['messages']
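The architecture: START -> agent, then agent -> retrieve -> generate -> END when the agent calls the retriever tool, or agent -> END otherwise.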
And this is the output:
[HumanMessage(content="According to Lilian Weng's blog, Explain what adversarial attacks on LLMs is.", id='b6b6bddc-2f57-46ac-be14-6df767146a45'),
 AIMessage(content='', additional_kwargs={'tool_calls': [{'index': 0, 'id': 'call_tmVj2XHIP1MKLShpTIYIgtZG', 'function': {'arguments': '{"query":"adversarial attacks on LLMs"}', 'name': 'retrieve_blog_posts'}, 'type': 'function'}]}, response_metadata={'finish_reason': 'tool_calls', 'model_name': 'gpt-4-turbo-2024-04-09', 'system_fingerprint': 'fp_486730399b'}, id='run-3891bb57-ae05-45f4-9344-2d47b6d46c3f-0', tool_calls=[{'name': 'retrieve_blog_posts', 'args': {'query': 'adversarial attacks on LLMs'}, 'id': 'call_tmVj2XHIP1MKLShpTIYIgtZG'}]),
 ToolMessage(content="Citation#\nCited as:\n\nWeng, Lilian. (Oct 2023). “Adversarial Attacks on LLMs”. Lil’Log. https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/.\n\nAdversarial Attacks on LLMs | Lil'Log\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nLil'Log\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nPosts\n\n\n\n\nArchive\n\n\n\n\nSearch\n\n\n\n\nTags\n\n\n\n\nFAQ\n\n\n\n\nemojisearch.app\n\nAdversarial attacks on classifiers have attracted more attention in the research community in the past, many in the image domain. LLMs can be used for classification too. Given an input $\mathbf{x}$ and a classifier $f(.)$, we would like to find an adversarial version of the input, denoted as $\mathbf{x}_\text{adv}$, with imperceptible difference from $\mathbf{x}$, such that\n\nAdversarial Attacks on LLMs\n \nDate: October 25, 2023 | Estimated Reading Time: 33 min | Author: Lilian Weng\n\n\n \n\n\nTable of Contents\n\n\n\nBasics\n\nThreat Model\n\nClassification\n\nText Generation\n\nWhite-box vs Black-box\n\n\n\nTypes of Adversarial Attacks\n\nToken Manipulation", name='retrieve_blog_posts', id='fb7dc47a-10ca-474a-8f48-043daff3a138', tool_call_id='call_tmVj2XHIP1MKLShpTIYIgtZG'),
 HumanMessage(content="Adversarial attacks on LLMs involve finding an imperceptible version of an input that can fool the classifier. These attacks have been a focus in image domains and can also be applied to LLMs used for classification. Lilian Weng's blog discusses various types of adversarial attacks on LLMs.", id='6f5f6b92-e5f1-4648-b451-5a4281ff3d2d')]
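The reduced flow in this comment (agent -> end, or agent -> retriever -> generate -> end) can be modelled in plain Python. `tools_condition_sketch` is a simplified stand-in for `langgraph.prebuilt.tools_condition`, which returns `"tools"` when the last message carries tool calls and END (the string `"__end__"`) otherwise; `Msg`, `ROUTE_MAP`, `STATIC_EDGES` and `path_for` are hypothetical names used only in this sketch.

```python
from dataclasses import dataclass, field

END = "__end__"  # value of LangGraph's END constant


@dataclass
class Msg:
    """Hypothetical stand-in for an AIMessage; only tool_calls matters here."""
    content: str = ""
    tool_calls: list = field(default_factory=list)


def tools_condition_sketch(messages: list) -> str:
    """Return "tools" if the last message requests a tool call, else END."""
    last = messages[-1]
    return "tools" if getattr(last, "tool_calls", None) else END


# Reduced graph (rewrite node removed):
# agent --tools_condition--> retrieve | END, then retrieve -> generate -> END
ROUTE_MAP = {"tools": "retrieve", END: END}
STATIC_EDGES = {"retrieve": "generate", "generate": END}


def path_for(agent_reply: Msg) -> list:
    """List the nodes visited once the agent has replied."""
    path = ["agent"]
    node = ROUTE_MAP[tools_condition_sketch([agent_reply])]
    while node != END:
        path.append(node)
        node = STATIC_EDGES[node]
    return path


print(path_for(Msg(content="Hi!")))
print(path_for(Msg(tool_calls=[{"name": "retrieve_blog_posts", "args": {}}])))
```

On the tool-calling path, `generate` is the only node that returns a bare string (`agent` returns the model's `AIMessage` and the retrieval `ToolNode` returns a `ToolMessage`), which matches the output above, where only the final message is mislabelled as a `HumanMessage`.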
Checked other resources
Example Code
Error Message and Stack Trace (if applicable)
No response
Description
Using a graph, I read that this makes the output a concatenated list of the different steps along the graph path.
My graph path is input -> agent -> retriever -> generate, and I got something like this:
{'messages': [HumanMessage(content='my_input?', id='3c2e2af5-f38e-4471-aee3-07502e3a5494'), AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'bdR7xmCGm', 'function': {'name': 'my_function', 'arguments': '{"query": "new_query"}'}}]}, response_metadata={'token_usage': {'prompt_tokens': 116, 'total_tokens': 154, 'completion_tokens': 38}, 'model': 'open-mixtral-8x22b', 'finish_reason': 'tool_calls'}, id='run-1e8ddb4a-751b-4f46-9ce3-02ddad099549-0', tool_calls=[{'name': 'my_tools', 'args': {'query': 'my_query'}, 'id': 'bdR7xmCGm'}]), ToolMessage(content='blobloblo', tool_call_id='bdR7xmCGm'), HumanMessage(content="blablabla", id='739239a9-b3e6-4ef4-a01a-642d4a992d17')]}
The last object in my list is labeled "HumanMessage", but it was generated by the AI. In LangSmith I do not see this error.
This is very confusing, because I cannot simply look at the event steps from my graph if I want to store them.
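For storing the event steps, one option (a sketch, not a fix confirmed in this thread) is to stream with `stream_mode="updates"`, where LangGraph yields one dict per step keyed by the name of the node that just ran, so each stored message can be attributed to its node whatever class it ends up with in the state. `record_steps` and the sample chunks below are hypothetical; plain strings stand in for the message objects.

```python
from typing import Iterable, List, Tuple


def record_steps(chunks: Iterable[dict]) -> List[Tuple[int, str, object]]:
    """Flatten update-style chunks ({node_name: {"messages": [...]}}) into
    (step, node, message) rows that could be stored, e.g. in a database."""
    rows = []
    for step, chunk in enumerate(chunks):
        for node, update in chunk.items():
            # A node may return None (no update); treat that as no messages.
            for message in (update or {}).get("messages", []):
                rows.append((step, node, message))
    return rows


# Hypothetical chunks shaped like stream_mode="updates" output for the path
# input -> agent -> retriever -> generate; strings stand in for message objects.
sample = [
    {"agent": {"messages": ["AIMessage(tool_calls=[...])"]}},
    {"retriever": {"messages": ["ToolMessage(content='blobloblo')"]}},
    {"generate": {"messages": ["blablabla"]}},
]

for row in record_steps(sample):
    print(row)
```

With a real compiled graph, the same function could be fed `graph.stream(inputs, config, stream_mode="updates")`; that call is not exercised here.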
System Info
langchain==0.1.20
langchain-community==0.0.38
langchain-core==0.2.1
langchain-mistralai==0.1.7
langchain-postgres==0.0.6
langchain-text-splitters==0.0.1
langchainhub==0.1.15
langgraph==0.0.53
langsmith==0.1.57