langchain-ai / langserve

LangServe 🦜️🏓
Other
1.94k stars 217 forks source link

Serialization issues with intermediate_steps for AgentExecutor #381

Open ccurme opened 10 months ago

ccurme commented 10 months ago

I experimented with a use case in which I initialize an AgentExecutor whose agent chain wraps a RemoteRunnable, i.e., the client side looks like this:

from langchain.agents import AgentExecutor, tool
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
from langserve import RemoteRunnable

# Client-side copy of the tool: the AgentExecutor below is given TOOLS, so the
# tool runs locally rather than on the LangServe server.
@tool
def get_word_length(word: str) -> int:
    """Returns the length of a word."""
    # NOTE(review): @tool appears to use the docstring above as the tool
    # description, so keep it identical to the server-side definition.
    return len(word)

# Tools executed by the local AgentExecutor; the remote chain only plans
# (prompt + LLM) and never runs them.
TOOLS = [get_word_length]

# Remote planning chain served by the LangServe app at /example.
remote_runnable = RemoteRunnable("http://localhost:8000/example")
# The server returns the raw model message, so parse it into agent
# actions/finishes on the client side.
agent = remote_runnable | OpenAIFunctionsAgentOutputParser()

agent_executor = AgentExecutor(agent=agent, tools=TOOLS, verbose=True)

# NOTE(review): the executor presumably adds "intermediate_steps" to the input
# it forwards to the remote chain; the server schema requires that key.
agent_executor.invoke({"input": "how many characters are in the word quizzical"})

I ended up not needing this pattern, but I thought it could be useful as a way to execute tools outside of the LangServe server.

Server looks like this:

from typing import Optional

from fastapi import FastAPI
from langchain.agents import tool
from langchain.agents.format_scratchpad import format_to_openai_function_messages
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.pydantic_v1 import BaseModel, Field, validator
from langchain.schema import AIMessage
from langchain_community.chat_models import ChatOpenAI
from langchain_community.tools.convert_to_openai import format_tool_to_openai_function
from langchain_core.agents import AgentActionMessageLog
from langserve import add_routes

# FastAPI application that hosts the LangServe routes registered further down.
app = FastAPI(
    title="Example server.",
    version="1.0",
)

# Server-side copy of the tool; only its function schema is used here (bound to
# the model in get_agent_chain), while execution happens on the client.
@tool
def get_word_length(word: str) -> int:
    """Returns the length of a word."""
    # NOTE(review): @tool appears to use the docstring above as the tool
    # description, so keep it identical to the client-side definition.
    return len(word)

# Tools advertised to the model as OpenAI function schemas; this chain does
# not invoke them itself.
TOOLS = [get_word_length]

def get_agent_chain():
    """Assemble the planning chain this server exposes.

    The chain maps the request dict onto prompt variables (``input`` plus an
    ``agent_scratchpad`` rebuilt from ``intermediate_steps``). It then renders
    the chat prompt and calls the chat model with every tool in ``TOOLS``
    bound as an OpenAI function. The model's message is returned unparsed;
    turning it into an agent action is the caller's job.
    """
    system_text = "You are very powerful assistant, but can't count characters in words."
    chat_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_text),
            ("user", "{input}"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )

    function_specs = [format_tool_to_openai_function(each) for each in TOOLS]
    model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0).bind(
        functions=function_specs
    )

    prompt_variables = {
        "input": lambda payload: payload["input"],
        # Replay earlier (action, observation) pairs as function-call messages.
        "agent_scratchpad": lambda payload: format_to_openai_function_messages(
            payload["intermediate_steps"]
        ),
    }
    return prompt_variables | chat_prompt | model

# (action, observation) pairs accumulated by the client-side AgentExecutor.
IntermediateSteps = list[tuple[AgentActionMessageLog, Optional[str]]]


class AgentInput(BaseModel):
    """Request schema for the /example agent chain.

    ``input`` is the user's question. ``intermediate_steps`` holds the
    (action, observation) pairs from previous agent iterations. The chain turns
    them into the agent scratchpad.
    """

    input: str
    intermediate_steps: IntermediateSteps = Field()

    @validator("intermediate_steps")
    def parse_intermediate_steps(
        cls, intermediate_steps: IntermediateSteps
    ) -> IntermediateSteps:
        """Restore ``AIMessage`` instances inside each step's ``message_log``.

        Workaround: when the request body is parsed, ``message_log`` entries
        arrive as plain ``BaseMessage`` objects. The OpenAI message converter
        rejects those with ``TypeError: Got unknown type``. Re-wrapping them as
        ``AIMessage`` keeps the ``content`` and the ``additional_kwargs`` (where
        the ``function_call`` lives) that the scratchpad formatter needs.
        Messages that are already ``AIMessage`` are left untouched, so the
        validator is idempotent.
        """
        for action, _observation in intermediate_steps:
            action.message_log = [
                message
                if isinstance(message, AIMessage)
                else AIMessage(
                    content=message.content,
                    additional_kwargs=message.additional_kwargs,
                )
                for message in action.message_log
            ]
        return intermediate_steps

# Declare the input schema explicitly so request bodies are parsed with
# AgentInput (including its message-log validator).
agent = get_agent_chain().with_types(input_type=AgentInput)

# Serve the chain at /example; the client's RemoteRunnable points at this path.
add_routes(
    app,
    agent,
    path="/example",
)

if __name__ == "__main__":
    import uvicorn

    # Local development server; the client snippet targets localhost:8000.
    uvicorn.run(app, host="localhost", port=8000)

This example works, but it breaks if I remove the custom validator on `intermediate_steps`: without it, the messages in `message_log` are parsed as `BaseMessage` instead of `AIMessage`, and I get

  File "/envs/langserve/lib/python3.9/site-packages/langchain_community/adapters/openai.py", line 145, in convert_message_to_dict
    raise TypeError(f"Got unknown type {message}")
TypeError: Got unknown type content='' additional_kwargs={'function_call': {'name': 'get_word_length', 'arguments': '{\n  "word": "quizzical"\n}'}} type='ai' example=False

on the server.

eyurtsev commented 10 months ago

Thanks for the detailed report!