langchain-ai / langchain

Tool Calls with large parameters are blocking between on_chat_model_stream and on_chat_model_end #24021

Open hzeus opened 2 weeks ago

hzeus commented 2 weeks ago

Example Code

import getpass

api_endpoint = getpass.getpass("API Endpoint")
api_key = getpass.getpass("API Key")

from datetime import datetime

from langchain_core.messages import HumanMessage
from langchain_openai import AzureChatOpenAI
from langgraph.graph import END, MessageGraph
from langchain.tools import tool

@tool
def file_saver(text: str) -> str:
    """Persist the given string to disk."""
    pass

model = AzureChatOpenAI(
    deployment_name="cogdep-gpt-4o",
    model_name="gpt-4o",
    azure_endpoint=api_endpoint,
    openai_api_key=api_key,
    openai_api_type="azure",
    openai_api_version="2024-05-01-preview",
    streaming=True,
    temperature=0.1
)

tools = [file_saver]
model = model.bind_tools(tools)

def get_agent_executor():
    def should_continue(messages):
        print(f"{datetime.now()}: Starting should_continue")
        return "end"

    async def call_model(messages):
        response = await model.ainvoke(messages)
        return response

    workflow = MessageGraph()

    workflow.add_node("agent", call_model)

    workflow.set_entry_point("agent")

    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {
            "end": END,
        },
    )
    return workflow.compile()

agent_executor = get_agent_executor()

messages = [HumanMessage(content="Think of a poem with 100 verses and save it to a file. Do not print it to me first.")]

async def run():
    async for event in agent_executor.astream_events(messages, version="v1"):
        kind = event["event"]
        print(f"{datetime.now()}: Received event: {kind}")

await run()  # in a plain script, use `asyncio.run(run())` instead

Error Message and Stack Trace (if applicable)

This is part of the output; in this case there is a 23-second gap between the last `on_chat_model_stream` event and `on_chat_model_end`:

(...)
2024-07-09 05:29:35.705573: Received event: on_chat_model_stream
2024-07-09 05:29:35.713679: Received event: on_chat_model_stream
2024-07-09 05:29:35.724480: Received event: on_chat_model_stream
2024-07-09 05:29:35.753143: Received event: on_chat_model_stream
2024-07-09 05:29:58.571740: Received event: on_chat_model_end
2024-07-09 05:29:58.574671: Received event: on_chain_start
2024-07-09 05:29:58.576026: Received event: on_chain_end
2024-07-09 05:29:58.577963: Received event: on_chain_start
2024-07-09 05:29:58.578214: Starting should_continue

Description

Hi!

When an LLM response leads to a tool call with a large amount of data in one of its parameters, we noticed that our program blocks even though we are using the async API. My guess is that the final message is assembled after the last chunk has been streamed, and that this takes considerable time on the CPU? Also, is there a different approach we could use?
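
A toy sketch of the suspected pattern (illustrative only, not LangChain's actual code): if the accumulated tool-call arguments are re-parsed on every streamed chunk, the total work grows quadratically with the size of the output.

import json
import time

# Simulate a few thousand streamed fragments of one large tool-call argument.
fragments = ['{"text": "'] + [f"word{i} " for i in range(4000)] + ['"}']

start = time.perf_counter()
buffer = ""
for fragment in fragments:
    buffer += fragment
    try:
        json.loads(buffer)  # re-parse everything accumulated so far
    except json.JSONDecodeError:
        pass  # incomplete JSON is expected mid-stream
print(f"re-parsing on every chunk: {time.perf_counter() - start:.2f}s")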

Thank you very much!

System Info

System Information
------------------
> OS:  Linux
> OS Version:  #1 SMP PREEMPT Thu Nov 16 10:49:20 UTC 2023
> Python Version:  3.11.6 | packaged by conda-forge | (main, Oct  3 2023, 11:57:02) [GCC 12.3.0]

Package Information
-------------------
> langchain_core: 0.2.11
> langchain: 0.2.6
> langsmith: 0.1.84
> langchain_openai: 0.1.14
> langchain_text_splitters: 0.2.2
> langgraph: 0.1.5

Packages not installed (Not Necessarily a Problem)
--------------------------------------------------
The following packages were not found:

> langserve
hinthornw commented 2 weeks ago

Thanks for flagging. I am able to reproduce the lag.

Looks like a disproportionate amount of time is being spent just parsing the JSON values, which is something we can improve (I think by not doing redundant work, though I haven't looked at it deeply yet). Thanks for bringing it to our attention.
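
For contrast, a sketch of the non-redundant approach: accumulate the raw fragments and parse a single time when the stream ends (again illustrative, not the actual fix).

import json
import time

fragments = ['{"text": "'] + [f"word{i} " for i in range(4000)] + ['"}']

start = time.perf_counter()
data = json.loads("".join(fragments))  # single parse of the complete payload
print(f"parsing once: {time.perf_counter() - start:.4f}s")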

Statistical profile from a previous run.

[image: pyinstrument statistical profile screenshot]

cProfile output attached as well.

        7    0.001    0.000   39.557    5.651 python3.11/site-packages/langchain_core/language_models/chat_models.py:771(_agenerate_with_cache)
        1    0.011    0.011   39.537   39.537 python3.11/site-packages/langchain_core/language_models/chat_models.py:79(generate_from_stream)
     4013    0.014    0.000   39.525    0.010 langchain_core/outputs/chat_generation.py:83(__add__)
     4013    0.023    0.000   39.413    0.010 langchain_core/messages/ai.py:232(__add__)
12050/8035    0.022    0.000   39.381    0.005 pydantic/v1/main.py:332(__init__)
12050/8035    0.139    0.000   39.360    0.005 pydantic/v1/main.py:1030(validate_model)
     4016    0.008    0.000   39.300    0.010 langchain_core/messages/ai.py:78(__init__)
     4016    0.007    0.000   39.292    0.010 langchain_core/messages/base.py:57(__init__)
     4015    0.014    0.000   38.904    0.010 langchain_core/messages/ai.py:178(init_tool_calls)
     4014   11.357    0.003   38.890    0.010 langchain_core/utils/json.py:44(parse_partial_json)
  2158964    2.544    0.000   27.532    0.000 python3.11/json/__init__.py:299(loads)
  2158964    1.206    0.000   21.431    0.000 python3.11/json/decoder.py:332(decode)
  2158964   16.069    0.000   19.615    0.000 /U

profiles.tgz

Repro script:

import asyncio
import cProfile
import io
import pstats
from datetime import datetime

from langchain.tools import tool
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

from langgraph.graph import END, MessageGraph
from pyinstrument import Profiler

@tool
def file_saver(text: str) -> str:
    """Persist the given string to disk"""
    pass

model = ChatOpenAI(
    model="gpt-4o",
    streaming=True,
    temperature=0.1,
)

tools = [file_saver]
model = model.bind_tools(tools)

def get_agent_executor():
    def should_continue(messages):
        print(f"{datetime.now()}: Starting should_continue")
        return "end"

    async def call_model(messages):
        response = await model.ainvoke(messages)
        return response

    workflow = MessageGraph()

    workflow.add_node("agent", call_model)

    workflow.set_entry_point("agent")

    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {
            "end": END,
        },
    )
    return workflow.compile()

agent_executor = get_agent_executor()

messages = [
    HumanMessage(
        content="Think of a ballad with 300 verses and save it"
        " to a file. Do not print it to me first. At the end of the poem, sign it as 's.b. Anonymous'."
        " You must write every line, do not skip."
    )
]

async def run():
    saw_cme = False
    py_profiler = None
    cprofiler = None
    evs = None
    async for event in agent_executor.astream_events(messages, version="v1"):
        kind = event["event"]
        if saw_cme:
            continue
        if kind == "on_chat_model_stream":
            # Start (or restart) profiling once the signature near the end of
            # the ballad appears in the streamed tool-call arguments, so the
            # profile covers only the gap before on_chat_model_end.
            if tc := event["data"]["chunk"].additional_kwargs.get("tool_calls"):
                if "Anonymous" in str(tc[0]["function"]):
                    if py_profiler:
                        py_profiler.stop()
                    py_profiler = Profiler(interval=0.0001, async_mode="disabled")
                    py_profiler.start()
                    if cprofiler:
                        cprofiler.disable()
                    cprofiler = cProfile.Profile()
                    cprofiler.enable()

        if kind == "on_chat_model_end":
            saw_cme = True
            if py_profiler:
                # Dump both profiles covering the stream-to-end gap.
                py_profiler.stop()
                py_profiler.write_html("profile.html", show_all=True)
                cprofiler.disable()
                s = io.StringIO()
                ps = pstats.Stats(cprofiler, stream=s).sort_stats(
                    pstats.SortKey.CUMULATIVE
                )
                ps.print_stats()
                with open("profile.txt", "w") as f:
                    f.write(s.getvalue())
            else:
                print("No profiling data")
            print(evs)  # timestamp of the last event before on_chat_model_end
        evs = f"{datetime.now()}: Received event: {kind}"
        if kind == "on_chat_model_end":
            print(evs)
asyncio.run(run())
hinthornw commented 2 weeks ago

Going to transfer this to the langchain repo, however, since the underlying issue is there. Will flag it to the team.