samuelint / langchain-openai-api-bridge

A bridge to use Langchain output as an OpenAI-compatible API
MIT License

I got an error when using a graph object as an agent #38

Closed · Valdanitooooo closed this 1 week ago

Valdanitooooo commented 3 weeks ago

First of all, thank you for your great work in making it so easy to build an agent API! 🫶

I want to serve a graph as an agent, because the prebuilt ReAct agent is not flexible enough.

Here is my code

Server

import operator
from typing import Annotated, Sequence, TypedDict

import uvicorn
from fastapi import FastAPI, Security
from fastapi.security import APIKeyHeader
from langchain_community.tools import DuckDuckGoSearchResults
from langchain_core.messages import BaseMessage
from langchain_core.messages import ToolMessage, AIMessage
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_openai_api_bridge.core import BaseAgentFactory
from langchain_openai_api_bridge.core.create_agent_dto import CreateAgentDto
from langchain_openai_api_bridge.fastapi import LangchainOpenaiApiBridgeFastAPI
from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import ToolExecutor
from langgraph.prebuilt import ToolInvocation
from pydantic import BaseModel
from starlette.middleware.cors import CORSMiddleware

from modules.common.utils import llm

tools = [DuckDuckGoSearchResults()]
tool_executor = ToolExecutor(tools)

class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]

def should_continue(state):
    # Continue to the tool node while the last message requests tool calls.
    messages = state["messages"]
    last_message = messages[-1]
    if not last_message.tool_calls:
        return "end"
    else:
        return "continue"

def call_model(state):
    # Call the LLM with the accumulated messages and append its response.
    messages = state["messages"]
    response = llm.invoke(messages)
    return {"messages": [response]}

def call_tool(state):
    # Execute every tool call requested by the last AI message.
    messages = state["messages"]
    last_message = messages[-1]
    tool_invocations = []
    for tool_call in last_message.tool_calls:
        action = ToolInvocation(
            tool=tool_call["name"],
            tool_input=tool_call["args"],
        )
        tool_invocations.append(action)

    responses = tool_executor.batch(tool_invocations, return_exceptions=True)
    # Wrap each tool result in a ToolMessage tied to its originating tool call.
    tool_messages = [
        ToolMessage(
            content=str(response),
            name=tc["name"],
            tool_call_id=tc["id"],
        )
        for tc, response in zip(last_message.tool_calls, responses)
    ]
    return {"messages": tool_messages}

def first_model(state):
    # Seed the run with a hard-coded "reflection" tool call built from the user's input.
    human_input = state["messages"][-1]["content"]
    return {
        "messages": [
            AIMessage(
                content="",
                tool_calls=[
                    {
                        "name": "reflection",
                        "args": {
                            "query": human_input,
                        },
                        "id": "tool_abcd123",
                    }
                ],
            )
        ]
    }

def create_graph():
    workflow = StateGraph(AgentState)

    workflow.add_node("first_agent", first_model)

    workflow.add_node("agent", call_model)
    workflow.add_node("action", call_tool)

    workflow.add_edge(START, "first_agent")

    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {
            "continue": "action",
            "end": END,
        },
    )

    workflow.add_edge("action", "agent")

    workflow.add_edge("first_agent", "action")
    graph = workflow.compile(debug=True).with_config(
        RunnableConfig(
            configurable={"thread_id": "1"},
            recursion_limit=10,
        )
    )
    return graph

class TestAgent(BaseAgentFactory):
    def create_agent(self, dto: CreateAgentDto) -> Runnable:
        return create_graph()

app = FastAPI(
    title="Langchain Agent OpenAI API Bridge",
    version="1.0",
    description="OpenAI API exposing langchain agent",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)

token_key = APIKeyHeader(name="Authorization")

class Token(BaseModel):
    token: str

def get_current_token(auth_key: str = Security(token_key)):
    return auth_key

rag_bridge = LangchainOpenaiApiBridgeFastAPI(
    app=app, agent_factory_provider=lambda: TestAgent()
)
rag_bridge.bind_openai_chat_completion(prefix="/test-agent")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8888)

Client

from openai import OpenAI

if __name__ == '__main__':
    openai_client = OpenAI(
        base_url="http://localhost:8888/test-agent/openai/v1",
        api_key="xxx"
    )

    chat_completion = openai_client.chat.completions.create(
        model="agent-model",
        messages=[
            {
                "role": "user",
                "content": 'hi',
            }
        ],
    )
    print(chat_completion.choices[0].message.content)

Error

[{'step': -2}:checkpoint] State at the end of step {'step': -2}:
{'messages': []}
[0:tasks] Starting step 0 with 1 task:
- __start__ -> [{'content': 'hi', 'role': 'user'}]
ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/uvicorn/protocols/http/httptools_impl.py", line 426, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/uvicorn/middleware/proxy_headers.py", line 84, in __call__
    return await self.app(scope, receive, send)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/fastapi/applications.py", line 1054, in __call__
    await super().__call__(scope, receive, send)
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/starlette/applications.py", line 123, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/starlette/middleware/errors.py", line 186, in __call__
    raise exc
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/starlette/middleware/errors.py", line 164, in __call__
    await self.app(scope, receive, _send)
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/starlette/middleware/cors.py", line 85, in __call__
    await self.app(scope, receive, send)
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/starlette/middleware/exceptions.py", line 65, in __call__
    await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/starlette/_exception_handler.py", line 64, in wrapped_app
    raise exc
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    await app(scope, receive, sender)
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/starlette/routing.py", line 756, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/starlette/routing.py", line 776, in app
    await route.handle(scope, receive, send)
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/starlette/routing.py", line 297, in handle
    await self.app(scope, receive, send)
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/starlette/routing.py", line 77, in app
    await wrap_app_handling_exceptions(app, request)(scope, receive, send)
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/starlette/_exception_handler.py", line 64, in wrapped_app
    raise exc
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    await app(scope, receive, sender)
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/starlette/routing.py", line 75, in app
    await response(scope, receive, send)
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/starlette/responses.py", line 258, in __call__
    async with anyio.create_task_group() as task_group:
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 597, in __aexit__
    raise exceptions[0]
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/starlette/responses.py", line 261, in wrap
    await func()
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/starlette/responses.py", line 250, in stream_response
    async for chunk in self.body_iterator:
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/langchain_openai_api_bridge/chat_completion/http_stream_response_adapter.py", line 10, in to_str_stream
    async for chunk in chunk_stream:
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/langchain_openai_api_bridge/core/utils/pydantic_async_iterator.py", line 7, in ato_dict
    async for obj in async_iter:
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/langchain_openai_api_bridge/chat_completion/langchain_stream_adapter.py", line 26, in ato_chat_completion_chunk_stream
    async for event in astream_event:
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/langchain_core/runnables/base.py", line 5290, in astream_events
    async for item in self.bound.astream_events(
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/langchain_core/runnables/base.py", line 1247, in astream_events
    async for event in event_stream:
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/langchain_core/tracers/event_stream.py", line 1005, in _astream_events_implementation_v2
    await task
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/langchain_core/tracers/event_stream.py", line 965, in consume_astream
    async for _ in event_streamer.tap_output_aiter(run_id, stream):
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/langchain_core/tracers/event_stream.py", line 181, in tap_output_aiter
    first = await py_anext(output, default=sentinel)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/langchain_core/utils/aiter.py", line 78, in anext_impl
    return await __anext__(iterator)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/valdanito/workspace/original/langgraph/libs/langgraph/langgraph/pregel/__init__.py", line 1277, in astream
    _panic_or_proceed(done, inflight, loop.step, asyncio.TimeoutError)
  File "/Users/valdanito/workspace/original/langgraph/libs/langgraph/langgraph/pregel/__init__.py", line 1456, in _panic_or_proceed
    raise exc
  File "/Users/valdanito/workspace/original/langgraph/libs/langgraph/langgraph/pregel/executor.py", line 123, in done
    task.result()
  File "/Users/valdanito/workspace/original/langgraph/libs/langgraph/langgraph/pregel/retry.py", line 72, in arun_with_retry
    async for _ in task.proc.astream(task.input, task.config):
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/langchain_core/runnables/base.py", line 875, in astream
    yield await self.ainvoke(input, config, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/valdanito/workspace/original/langgraph/libs/langgraph/langgraph/utils.py", line 114, in ainvoke
    ret = await self._acall_with_config(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/valdanito/install/conda/miniconda3/envs/gradio/lib/python3.11/site-packages/langchain_core/runnables/base.py", line 1835, in _acall_with_config
    output: Output = await asyncio.create_task(coro, context=context)  # type: ignore
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/valdanito/workspace/original/langgraph/libs/langgraph/langgraph/pregel/write.py", line 127, in _awrite
    values = await asyncio.gather(
             ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/valdanito/workspace/original/langgraph/libs/langgraph/langgraph/utils.py", line 111, in ainvoke
    return self.invoke(input, config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/valdanito/workspace/original/langgraph/libs/langgraph/langgraph/utils.py", line 102, in invoke
    ret = context.run(self.func, input, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/valdanito/workspace/original/langgraph/libs/langgraph/langgraph/graph/state.py", line 529, in _get_state_key
    raise InvalidUpdateError(f"Expected dict, got {input}")
langgraph.errors.InvalidUpdateError: Expected dict, got [{'role': 'user', 'content': 'hi'}]
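
In other words, the bridge invokes the graph with the raw OpenAI-style message list, while a StateGraph compiled from AgentState expects a state dict keyed by its fields. A minimal illustration of the mismatch (hypothetical calls, using the create_graph defined above):

from langchain_core.messages import HumanMessage

graph = create_graph()
# What the bridge sends: a bare list, which makes _get_state_key raise InvalidUpdateError.
graph.invoke([{"role": "user", "content": "hi"}])
# What the compiled graph expects: a dict keyed by the AgentState fields.
graph.invoke({"messages": [HumanMessage(content="hi")]})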

My solution

I then modified the LangGraph source code, and my application now runs well: https://github.com/Valdanitooooo/langgraph/commit/c2d05edcd04becf3619d025eb90f65d0966077a1

My question

I'm not sure whether this is a problem with LangGraph or with your code; I hope we can find the best solution. I hope I have expressed myself clearly. ❤️

samuelint commented 3 weeks ago

Hi @Valdanitooooo

I think the error occurs because the library only supports a LangGraph state whose properties match the default ReAct agent state.

I see two options:

  1. Adjust your agent state (and input shape) to match the default ReAct agent.
  2. Change the library to support custom inputs. That would require some kind of adapter to map custom inputs onto the ReAct agent standard, so it might be simpler to just make the state match the ReAct agent state. (A rough sketch of option 1 follows below.)
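
For illustration, option 1 could look something like this: wrap the compiled graph so that the OpenAI-style message list the bridge passes as input (visible in the debug log as [{'content': 'hi', 'role': 'user'}]) is converted into the {"messages": [...]} dict the StateGraph expects. This is an untested sketch; coerce_openai_input and ROLE_TO_MESSAGE are hypothetical helpers, not part of the library.

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.runnables import RunnableLambda

# Hypothetical mapping from OpenAI roles to LangChain message classes.
ROLE_TO_MESSAGE = {
    "user": HumanMessage,
    "assistant": AIMessage,
    "system": SystemMessage,
}

def coerce_openai_input(messages: list) -> dict:
    # Convert [{'role': 'user', 'content': 'hi'}] into the state dict
    # that a StateGraph compiled from AgentState expects.
    return {
        "messages": [ROLE_TO_MESSAGE[m["role"]](content=m["content"]) for m in messages]
    }

# Prepend the adapter so the graph always receives {"messages": [...]}.
adapted_graph = RunnableLambda(coerce_openai_input) | create_graph()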

Whichever option solves your problem, it would be very nice of you to add an example of your use case to the « functional test » directory and the README.

I can take a deeper look next week if needed.

Thanks a lot for your interest in this library and your support for it 😀

Valdanitooooo commented 3 weeks ago

it would be very nice of you to add an example of your use case in the « functional test » directory and README.

I would love to do that, but my test code needs the dependency langgraph@git+https://github.com/Valdanitooooo/langgraph.git@dev#subdirectory=libs/langgraph to run well, which is not ideal, so I think it would be better to update the README after this issue is resolved. I can put the test code in the « functional test » directory first. 😉

samuelint commented 2 weeks ago

This seems to be a recurring issue / misunderstanding. I will soon update the library README and add a usage example showing how to use a custom LangGraph agent (one that does not require external deps).

samuelint commented 1 week ago

Thanks for your contribution @Valdanitooooo! Your pull request is now merged into the main repo.